From 4b960cbbd8bc4ee891064968e741d9aaea04aafa Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipam@aicrowd.com>
Date: Wed, 15 Sep 2021 14:03:20 +0530
Subject: [PATCH] check bounds on position on map

---
 flatland/envs/rail_env.py                | 21 ++++++++++++++-------
 tests/test_flatland_envs_observations.py | 11 +++++++----
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cec0542c..0c1e3f84 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -28,6 +28,7 @@ from flatland.envs.observations import GlobalObsForRailEnv
 
 from flatland.envs.timetable_generators import timetable_generator
 from flatland.envs.step_utils.states import TrainState, StateTransitionSignals
+from flatland.envs.step_utils.transition_utils import check_valid_action
 from flatland.envs.step_utils import action_preprocessing
 from flatland.envs.step_utils import env_utils
 
@@ -437,6 +438,11 @@ class RailEnv(Environment):
             current_position, current_direction = agent.initial_position, agent.initial_direction
         
         action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction)
+
+        # Check transitions and bounds for executing the action in the given position and direction
+        if not check_valid_action(action, self.rail, current_position, current_direction):
+            action = RailEnvActions.STOP_MOVING
+
         return action
     
     def clear_rewards_dict(self):
@@ -579,14 +585,15 @@ class RailEnv(Environment):
             movement_allowed = movement_allowed and agent.state != TrainState.DONE
 
             # Agent is being added to map
-            if agent.state.is_on_map_state() and agent.state_machine.previous_state.is_off_map_state():
-                agent.position = agent.initial_position
-                agent.direction = agent.initial_direction
+            if agent.state.is_on_map_state():
+                if agent.state_machine.previous_state.is_off_map_state():
+                    agent.position = agent.initial_position
+                    agent.direction = agent.initial_direction
             # Speed counter completes
-            elif movement_allowed and (agent.speed_counter.is_cell_exit):
-                agent.position = agent_transition_data.position
-                agent.direction = agent_transition_data.direction
-                agent.state_machine.update_if_reached(agent.position, agent.target)
+                elif movement_allowed and (agent.speed_counter.is_cell_exit):
+                    agent.position = agent_transition_data.position
+                    agent.direction = agent_transition_data.direction
+                    agent.state_machine.update_if_reached(agent.position, agent.target)
 
             # Off map or on map state and position should match
             env_utils.state_position_sync_check(agent.state, agent.position, agent.handle)
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index 628298f0..a23bf4c6 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -86,7 +86,7 @@ def _step_along_shortest_path(env, obs_builder, rail):
                             expected_next_position[agent.handle] = neighbour
                             print("   improved (action) -> {}".format(actions[agent.handle]))
     _, rewards, dones, _ = env.step(actions)
-    return rewards
+    return rewards, dones
 
 
 def test_reward_function_conflict(rendering=False):
@@ -162,8 +162,9 @@ def test_reward_function_conflict(rendering=False):
         },
     }
     while iteration < 5:
-        rewards = _step_along_shortest_path(env, obs_builder, rail)
-
+        rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
+        if dones["__all__"]:
+            break
         for agent in env.agents:
             # assert rewards[agent.handle] == 0
             expected_position = expected_positions[iteration + 1][agent.handle]
@@ -289,7 +290,9 @@ def test_reward_function_waiting(rendering=False):
     }
     while iteration < 7:
 
-        rewards = _step_along_shortest_path(env, obs_builder, rail)
+        rewards, dones = _step_along_shortest_path(env, obs_builder, rail)
+        if dones["__all__"]:
+            break
 
         if rendering:
             renderer.render_env(show=True, show_observations=True)
-- 
GitLab