From d466718799a688a0907e080abc7e21c87ab63f14 Mon Sep 17 00:00:00 2001 From: Dipam Chakraborty <dipam@aicrowd.com> Date: Mon, 13 Sep 2021 23:40:58 +0530 Subject: [PATCH] Fix for stopped to moving in fractional speeds --- flatland/envs/agent_utils.py | 2 +- flatland/envs/rail_env.py | 12 +++++--- flatland/envs/step_utils/action_saver.py | 9 +++++- flatland/envs/step_utils/speed_counter.py | 4 ++- tests/test_multi_speed.py | 34 +++++++++++------------ 5 files changed, 37 insertions(+), 24 deletions(-) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index ac1ef626..20dc0325 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -224,7 +224,7 @@ class EnvAgent: old_position: {self.old_position} old_direction {self.old_direction} \n \ earliest_departure: {self.earliest_departure} latest_arrival: {self.latest_arrival} \n \ state: {str(self.state)} \n \ - malfunction_data: {self.malfunction_data} \n \ + malfunction_handler: {self.malfunction_handler} \n \ action_saver: {self.action_saver} \n \ speed_counter: {self.speed_counter}" diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6859497f..a6a15531 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -373,7 +373,7 @@ class RailEnv(Environment): st_signals = StateTransitionSignals() # Malfunction starts when in_malfunction is set to true - st_signals.malfunction_onset = agent.malfunction_handler.in_malfunction + st_signals.in_malfunction = agent.malfunction_handler.in_malfunction # Malfunction counter complete - Malfunction ends next timestep st_signals.malfunction_counter_complete = agent.malfunction_handler.malfunction_counter_complete @@ -519,8 +519,8 @@ class RailEnv(Environment): new_position = agent.initial_position new_direction = agent.initial_direction - # When cell exit occurs apply saved action independent of other agents - elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved: + # If movement is allowed apply saved action independent of other agents + elif agent.action_saver.is_action_saved: saved_action = agent.action_saver.saved_action # Apply action independent of other agents and get temporary new position and direction new_position, new_direction = self.apply_action_independent(saved_action, @@ -551,7 +551,10 @@ class RailEnv(Environment): else: movement_allowed = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck - if movement_allowed: + # Position can be changed only if other cell is empty + # And either the speed counter completes or agent is being added to map + if movement_allowed and \ + (agent.speed_counter.is_cell_exit or agent.position is None): agent.position = agent_transition_data.position agent.direction = agent_transition_data.direction @@ -576,6 +579,7 @@ class RailEnv(Environment): ## Update counters (malfunction and speed) agent.speed_counter.update_counter(agent.state, agent.old_position) + # agent.state_machine.previous_state) agent.malfunction_handler.update_counter() # Clear old action when starting in new cell diff --git a/flatland/envs/step_utils/action_saver.py b/flatland/envs/step_utils/action_saver.py index 5e6c8a8c..bf61076e 100644 --- a/flatland/envs/step_utils/action_saver.py +++ b/flatland/envs/step_utils/action_saver.py @@ -10,10 +10,17 @@ class ActionSaver: return self.saved_action is not None def __repr__(self): - return f"is_action_saved: {self.is_action_saved}, saved_action: {self.saved_action}" + return f"is_action_saved: {self.is_action_saved}, saved_action: {str(self.saved_action)}" def save_action_if_allowed(self, action, state): + """ + Save the action if all conditions are met + 1. It is a movement based action -> Forward, Left, Right + 2. Action is not already saved + 3. Not in a malfunction state + 4. Agent is not already done + """ if action.is_moving_action() and \ not self.is_action_saved and \ not state.is_malfunction_state() and \ diff --git a/flatland/envs/step_utils/speed_counter.py b/flatland/envs/step_utils/speed_counter.py index 1c2c7279..f4a37ebe 100644 --- a/flatland/envs/step_utils/speed_counter.py +++ b/flatland/envs/step_utils/speed_counter.py @@ -8,11 +8,13 @@ class SpeedCounter: self.reset_counter() def update_counter(self, state, old_position): - # When coming onto the map, do no update speed counter + # Can't start counting when adding train to the map if state == TrainState.MOVING and old_position is not None: self.counter += 1 self.counter = self.counter % (self.max_count + 1) + + def __repr__(self): return f"speed: {self.speed} \ max_count: {self.max_count} \ diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 6455e573..56a3a33f 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -406,26 +406,26 @@ def test_multispeed_actions_malfunction_no_blocking(): set_penalties_for_replay(env) test_config = ReplayConfig( replay=[ - Replay( + Replay( # 0 position=(3, 9), # east dead-end direction=Grid4TransitionsEnum.EAST, action=RailEnvActions.MOVE_FORWARD, reward=env.start_penalty + env.step_penalty * 0.5 # starting and running at speed 0.5 ), - Replay( + Replay( # 1 position=(3, 9), direction=Grid4TransitionsEnum.EAST, action=None, reward=env.step_penalty * 0.5 # running at speed 0.5 ), - Replay( + Replay( # 2 position=(3, 8), direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.MOVE_FORWARD, reward=env.step_penalty * 0.5 # running at speed 0.5 ), # add additional step in the cell - Replay( + Replay( # 3 position=(3, 8), direction=Grid4TransitionsEnum.WEST, action=None, @@ -434,26 +434,26 @@ def test_multispeed_actions_malfunction_no_blocking(): reward=env.step_penalty * 0.5 # step penalty for speed 0.5 when malfunctioning ), # agent recovers in this step - Replay( + Replay( # 4 position=(3, 8), direction=Grid4TransitionsEnum.WEST, action=None, malfunction=1, reward=env.step_penalty * 0.5 # recovered: running at speed 0.5 ), - Replay( + Replay( # 5 position=(3, 8), direction=Grid4TransitionsEnum.WEST, action=None, reward=env.step_penalty * 0.5 # running at speed 0.5 ), - Replay( + Replay( # 6 position=(3, 7), direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.MOVE_FORWARD, reward=env.step_penalty * 0.5 # running at speed 0.5 ), - Replay( + Replay( # 7 position=(3, 7), direction=Grid4TransitionsEnum.WEST, action=None, @@ -462,57 +462,57 @@ def test_multispeed_actions_malfunction_no_blocking(): reward=env.step_penalty * 0.5 # step penalty for speed 0.5 when malfunctioning ), # agent recovers in this step; since we're at the beginning, we provide a different action although we're broken! - Replay( + Replay( # 8 position=(3, 7), direction=Grid4TransitionsEnum.WEST, action=None, malfunction=1, reward=env.step_penalty * 0.5 # running at speed 0.5 ), - Replay( + Replay( # 9 position=(3, 7), direction=Grid4TransitionsEnum.WEST, action=None, reward=env.step_penalty * 0.5 # running at speed 0.5 ), - Replay( + Replay( # 10 position=(3, 6), direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.STOP_MOVING, reward=env.stop_penalty + env.step_penalty * 0.5 # stopping and step penalty for speed 0.5 ), - Replay( + Replay( # 11 position=(3, 6), direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.STOP_MOVING, reward=env.step_penalty * 0.5 # step penalty for speed 0.5 while stopped ), - Replay( + Replay( # 12 position=(3, 6), direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.MOVE_FORWARD, reward=env.start_penalty + env.step_penalty * 0.5 # starting and running at speed 0.5 ), - Replay( + Replay( # 13 position=(3, 6), direction=Grid4TransitionsEnum.WEST, action=None, reward=env.step_penalty * 0.5 # running at speed 0.5 ), # DO_NOTHING keeps moving! - Replay( + Replay( # 14 position=(3, 5), direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.DO_NOTHING, reward=env.step_penalty * 0.5 # running at speed 0.5 ), - Replay( + Replay( # 15 position=(3, 5), direction=Grid4TransitionsEnum.WEST, action=None, reward=env.step_penalty * 0.5 # running at speed 0.5 ), - Replay( + Replay( # 16 position=(3, 4), direction=Grid4TransitionsEnum.WEST, action=RailEnvActions.MOVE_FORWARD, -- GitLab