diff --git a/examples/training_navigation.py b/examples/training_navigation.py index 456a4a03b3a1716f4c69516b8132aa85293bdf91..a7b2fbd391d5991d0d5a01da6eba425df14abcd9 100644 --- a/examples/training_navigation.py +++ b/examples/training_navigation.py @@ -20,10 +20,10 @@ transition_probability = [5, # empty cell - Case 0 0] # Case 7 - dead end # Example generate a random rail -env = RailEnv(width=15, - height=15, +env = RailEnv(width=10, + height=10, rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=3) + number_of_agents=5) env_renderer = RenderTool(env, gl="QT") handle = env.get_agent_handles() diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6750b6a8b0c4854762066dda5fedd8872683e1ac..0e150c62a2c8ef0e5d5623abf1407c9570606e3b 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -550,6 +550,8 @@ class RailEnv(Environment): if handle not in action_dict: continue + if self.dones[handle]: + continue action = action_dict[handle] if action < 0 or action > 3: @@ -637,8 +639,7 @@ class RailEnv(Environment): # if agent is not in target position, add step penalty if self.agents_position[i][0] == self.agents_target[i][0] and \ - self.agents_position[i][1] == self.agents_target[i][1] and \ - action_dict[handle] == 0: + self.agents_position[i][1] == self.agents_target[i][1]: self.dones[handle] = True else: self.rewards_dict[handle] += step_penalty diff --git a/tests/test_environments.py b/tests/test_environments.py index 198104fc2c56db2dac0573939bbad2d071d5618b..b46bb38828285401a85dee9000fd5935819c9342 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -140,10 +140,7 @@ def test_dead_end(): _ = rail_env.step({0: 1}) _ = rail_env.step({0: 3}) assert (rail_env.agents_position[0] == prev_pos) - if rail_env.agents_position[0] != rail_env.agents_target[0]: - _, _, dones, _ = rail_env.step({0: 2}) - else: - _, _, dones, _ = rail_env.step({0: 0}) + _, _, dones, _ = rail_env.step({0: 2}) if i < 5: assert (not dones[0] and not dones['__all__'])