From c69ceb5e9627a5ce098f1dec5c91d31cef4716ae Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Fri, 26 Apr 2019 10:02:55 +0200 Subject: [PATCH] fixed bug introduced by me in env.step --- examples/training_navigation.py | 6 +++--- flatland/envs/rail_env.py | 5 +++-- tests/test_environments.py | 5 +---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 456a4a03..a7b2fbd3 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -20,10 +20,10 @@ transition_probability = [5, # empty cell - Case 0 0] # Case 7 - dead end # Example generate a random rail -env = RailEnv(width=15, - height=15, +env = RailEnv(width=10, + height=10, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=3) + number_of_agents=5) env_renderer = RenderTool(env, gl="QT") handle = env.get_agent_handles() diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6750b6a8..0e150c62 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -550,6 +550,8 @@ class RailEnv(Environment): if handle not in action_dict: continue + if self.dones[handle]: + continue action = action_dict[handle] if action < 0 or action > 3: @@ -637,8 +639,7 @@ class RailEnv(Environment): # if agent is not in target position, add step penalty if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1] and \ - action_dict[handle] == 0: + self.agents_position[i][1] == self.agents_target[i][1]: self.dones[handle] = True else: self.rewards_dict[handle] += step_penalty diff --git a/tests/test_environments.py b/tests/test_environments.py index 198104fc..b46bb388 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -140,10 +140,7 @@ def test_dead_end(): _ = rail_env.step({0: 1}) _ = rail_env.step({0: 3}) assert (rail_env.agents_position[0] == prev_pos) - if rail_env.agents_position[0] != rail_env.agents_target[0]: - _, _, dones, _ = rail_env.step({0: 2}) - else: - _, _, dones, _ = rail_env.step({0: 0}) + _, _, dones, _ = rail_env.step({0: 2}) if i < 5: assert (not dones[0] and not dones['__all__']) -- GitLab