diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 58df3a1418a63e07cdc9617af85673c9cdaf1fa4..8cf6d52f383ec8f4e271eb0765d32bc0c763307a 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -191,16 +191,10 @@ class RailEnv(Environment): # for i in range(len(self.agents_handles)): for iAgent in range(self.get_num_agents()): agent = self.agents[iAgent] - print(agent.speed_data['speed']) if self.dones[iAgent]: # this agent has already completed... continue - if np.equal(agent.position, agent.target).all(): - self.dones[iAgent] = True - else: - self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed'] - if iAgent not in action_dict: # no action has been supplied for this agent action_dict[iAgent] = RailEnvActions.DO_NOTHING @@ -288,6 +282,11 @@ class RailEnv(Environment): agent.position = new_position agent.direction = new_direction + if np.equal(agent.position, agent.target).all(): + self.dones[iAgent] = True + else: + self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed'] + # Check for end of episode + add global reward to all rewards! if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]): self.dones["__all__"] = True diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_environments.py b/tests/test_environments.py index 11f0acba2fd54df63c62047f8559897e7d222e72..aa24467dd1d548a2b68a408f300089ee8135c639 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -3,7 +3,7 @@ import numpy as np from flatland.core.transition_map import GridTransitionMap -from flatland.core.transitions import Grid4Transitions +from flatland.core.transitions import Grid4Transitions, RailEnvTransitions from flatland.envs.agent_utils import EnvAgent from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import rail_from_GridTransitionMap_generator @@ -53,7 +53,7 @@ def test_rail_environment_single_agent(): # | | | # \_/\_/ - transitions = Grid4Transitions([]) + transitions = RailEnvTransitions() vertical_line = cells[1] south_symmetrical_switch = cells[6] north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) @@ -107,6 +107,7 @@ def test_rail_environment_single_agent(): if prev_pos != pos: valid_active_actions_done += 1 + # After 6 movements on this railway network, the train should be back # to its original height on the map. assert (initial_pos[0] == agent.position[0]) @@ -121,9 +122,9 @@ def test_rail_environment_single_agent(): action = np.random.randint(4) _, _, dones, _ = rail_env.step({0: action}) - done = dones['__all__'] +test_rail_environment_single_agent() def test_dead_end(): transitions = Grid4Transitions([])