Skip to content
Snippets Groups Projects
Commit 47a23b3c authored by spiglerg's avatar spiglerg
Browse files

fixed bad tests bug

parent 340e965b
No related branches found
No related tags found
No related merge requests found
......@@ -191,16 +191,10 @@ class RailEnv(Environment):
# for i in range(len(self.agents_handles)):
for iAgent in range(self.get_num_agents()):
agent = self.agents[iAgent]
print(agent.speed_data['speed'])
if self.dones[iAgent]: # this agent has already completed...
continue
if np.equal(agent.position, agent.target).all():
self.dones[iAgent] = True
else:
self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
if iAgent not in action_dict: # no action has been supplied for this agent
action_dict[iAgent] = RailEnvActions.DO_NOTHING
......@@ -288,6 +282,11 @@ class RailEnv(Environment):
agent.position = new_position
agent.direction = new_direction
if np.equal(agent.position, agent.target).all():
self.dones[iAgent] = True
else:
self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed']
# Check for end of episode + add global reward to all rewards!
if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]):
self.dones["__all__"] = True
......
......@@ -3,7 +3,7 @@
import numpy as np
from flatland.core.transition_map import GridTransitionMap
from flatland.core.transitions import Grid4Transitions
from flatland.core.transitions import Grid4Transitions, RailEnvTransitions
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.generators import complex_rail_generator
from flatland.envs.generators import rail_from_GridTransitionMap_generator
......@@ -53,7 +53,7 @@ def test_rail_environment_single_agent():
# | | |
# \_/\_/
transitions = Grid4Transitions([])
transitions = RailEnvTransitions()
vertical_line = cells[1]
south_symmetrical_switch = cells[6]
north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180)
......@@ -107,6 +107,7 @@ def test_rail_environment_single_agent():
if prev_pos != pos:
valid_active_actions_done += 1
# After 6 movements on this railway network, the train should be back
# to its original height on the map.
assert (initial_pos[0] == agent.position[0])
......@@ -121,9 +122,9 @@ def test_rail_environment_single_agent():
action = np.random.randint(4)
_, _, dones, _ = rail_env.step({0: action})
done = dones['__all__']
test_rail_environment_single_agent()
def test_dead_end():
transitions = Grid4Transitions([])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment