From 47a23b3c2d8a7c244f857448ccf860d05b23c1e1 Mon Sep 17 00:00:00 2001 From: Giacomo Spigler <spiglerg@gmail.com> Date: Wed, 19 Jun 2019 19:42:32 +0200 Subject: [PATCH] fixed bad tests bug --- flatland/envs/rail_env.py | 11 +++++------ tests/__init__.py | 0 tests/test_environments.py | 7 ++++--- 3 files changed, 9 insertions(+), 9 deletions(-) create mode 100644 tests/__init__.py diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 58df3a14..8cf6d52f 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -191,16 +191,10 @@ class RailEnv(Environment): # for i in range(len(self.agents_handles)): for iAgent in range(self.get_num_agents()): agent = self.agents[iAgent] - print(agent.speed_data['speed']) if self.dones[iAgent]: # this agent has already completed... continue - if np.equal(agent.position, agent.target).all(): - self.dones[iAgent] = True - else: - self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed'] - if iAgent not in action_dict: # no action has been supplied for this agent action_dict[iAgent] = RailEnvActions.DO_NOTHING @@ -288,6 +282,11 @@ class RailEnv(Environment): agent.position = new_position agent.direction = new_direction + if np.equal(agent.position, agent.target).all(): + self.dones[iAgent] = True + else: + self.rewards_dict[iAgent] += step_penalty * agent.speed_data['speed'] + # Check for end of episode + add global reward to all rewards! if np.all([np.array_equal(agent2.position, agent2.target) for agent2 in self.agents]): self.dones["__all__"] = True diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_environments.py b/tests/test_environments.py index 11f0acba..aa24467d 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -3,7 +3,7 @@ import numpy as np from flatland.core.transition_map import GridTransitionMap -from flatland.core.transitions import Grid4Transitions +from flatland.core.transitions import Grid4Transitions, RailEnvTransitions from flatland.envs.agent_utils import EnvAgent from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import rail_from_GridTransitionMap_generator @@ -53,7 +53,7 @@ def test_rail_environment_single_agent(): # | | | # \_/\_/ - transitions = Grid4Transitions([]) + transitions = RailEnvTransitions() vertical_line = cells[1] south_symmetrical_switch = cells[6] north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) @@ -107,6 +107,7 @@ def test_rail_environment_single_agent(): if prev_pos != pos: valid_active_actions_done += 1 + # After 6 movements on this railway network, the train should be back # to its original height on the map. assert (initial_pos[0] == agent.position[0]) @@ -121,9 +122,9 @@ def test_rail_environment_single_agent(): action = np.random.randint(4) _, _, dones, _ = rail_env.step({0: action}) - done = dones['__all__'] +test_rail_environment_single_agent() def test_dead_end(): transitions = Grid4Transitions([]) -- GitLab