From 94feb03955515de836cbb75e90b690ac4c2ca7bd Mon Sep 17 00:00:00 2001 From: Dipam Chakraborty <dipam@aicrowd.com> Date: Tue, 14 Sep 2021 17:31:26 +0530 Subject: [PATCH] position update allowed on cell exit and stopped state --- flatland/envs/rail_env.py | 25 +++++++++++++------ .../envs/step_utils/action_preprocessing.py | 10 ++++---- .../envs/step_utils/malfunction_handler.py | 12 ++++++--- flatland/envs/step_utils/state_machine.py | 17 ++++++++++--- 4 files changed, 44 insertions(+), 20 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index e7f736f5..ce5a09b3 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -431,7 +431,7 @@ class RailEnv(Environment): * Block all actions when in waiting state * Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD """ - action = action_preprocessing.preprocess_raw_action(action, agent.state) + action = action_preprocessing.preprocess_raw_action(action, agent.state, agent.action_saver.saved_action) action = action_preprocessing.preprocess_action_when_waiting(action, agent.state) # Try moving actions on current position @@ -440,7 +440,6 @@ class RailEnv(Environment): current_position, current_direction = agent.initial_position, agent.initial_direction action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction) - return action def clear_rewards_dict(self): @@ -513,6 +512,9 @@ class RailEnv(Environment): # Save moving actions in not already saved agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state) + # Train's next position can change if current stopped in a fractional speed or train is at cell's exit + position_update_allowed = (agent.speed_counter.is_cell_exit or agent.state == TrainState.STOPPED) + # Calculate new position # Add agent to the map if not on it yet if agent.position is None and agent.action_saver.is_action_saved: @@ -520,7 +522,7 @@ class RailEnv(Environment): new_direction = agent.initial_direction # If movement is allowed apply saved action independent of other agents - elif agent.action_saver.is_action_saved: + elif agent.action_saver.is_action_saved and position_update_allowed: saved_action = agent.action_saver.saved_action # Apply action independent of other agents and get temporary new position and direction new_position, new_direction = self.apply_action_independent(saved_action, @@ -557,7 +559,7 @@ class RailEnv(Environment): (agent.speed_counter.is_cell_exit or agent.position is None): agent.position = agent_transition_data.position agent.direction = agent_transition_data.direction - + preprocessed_action = agent_transition_data.preprocessed_action ## Update states @@ -565,9 +567,8 @@ class RailEnv(Environment): agent.state_machine.set_transition_signals(state_transition_signals) agent.state_machine.step() - if agent.state.is_on_map_state() and agent.position is None: - raise ValueError("Agent ID {} Agent State {} not matching with Agent Position {} ".format( - agent.handle, str(agent.state), str(agent.position) )) + # Off map or on map state and position should match + state_position_sync_check(agent.state, agent.position, agent.handle) # Handle done state actions, optionally remove agents self.handle_done_state(agent) @@ -583,7 +584,7 @@ class RailEnv(Environment): agent.malfunction_handler.update_counter() # Clear old action when starting in new cell - if agent.speed_counter.is_cell_entry: + if agent.speed_counter.is_cell_entry and agent.position is not None: agent.action_saver.clear_saved_action() # Check if episode has ended and update rewards and dones @@ -687,3 +688,11 @@ def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool: return False else: return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1] + +def state_position_sync_check(state, position, i_agent): + if state.is_on_map_state() and position is None: + raise ValueError("Agent ID {} Agent State {} is on map Agent Position {} if off map ".format( + i_agent, str(state), str(position) )) + elif state.is_off_map_state() and position is not None: + raise ValueError("Agent ID {} Agent State {} is off map Agent Position {} if on map ".format( + i_agent, str(state), str(position) )) diff --git a/flatland/envs/step_utils/action_preprocessing.py b/flatland/envs/step_utils/action_preprocessing.py index a397054c..47f06e2c 100644 --- a/flatland/envs/step_utils/action_preprocessing.py +++ b/flatland/envs/step_utils/action_preprocessing.py @@ -11,9 +11,11 @@ def process_illegal_action(action: RailEnvActions): return RailEnvActions(action) -def process_do_nothing(state: TrainState): +def process_do_nothing(state: TrainState, saved_action: RailEnvActions): if state == TrainState.MOVING: action = RailEnvActions.MOVE_FORWARD + elif saved_action: + action = saved_action else: action = RailEnvActions.STOP_MOVING return action @@ -34,7 +36,7 @@ def preprocess_action_when_waiting(action, state): return action -def preprocess_raw_action(action, state): +def preprocess_raw_action(action, state, saved_action): """ Preprocesses actions to handle different situations of usage of action based on context - DO_NOTHING is converted to FORWARD if train is moving @@ -43,7 +45,7 @@ def preprocess_raw_action(action, state): action = process_illegal_action(action) if action == RailEnvActions.DO_NOTHING: - action = process_do_nothing(state) + action = process_do_nothing(state, saved_action) return action @@ -55,6 +57,4 @@ def preprocess_moving_action(action, rail, position, direction): if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]: action = process_left_right(action, rail, position, direction) - if not check_valid_action(action, rail, position, direction): - action = RailEnvActions.STOP_MOVING return action \ No newline at end of file diff --git a/flatland/envs/step_utils/malfunction_handler.py b/flatland/envs/step_utils/malfunction_handler.py index 2ba72643..bf1f188f 100644 --- a/flatland/envs/step_utils/malfunction_handler.py +++ b/flatland/envs/step_utils/malfunction_handler.py @@ -10,6 +10,7 @@ def get_number_of_steps_to_break(malfunction_generator, np_random): class MalfunctionHandler: def __init__(self): self._malfunction_down_counter = 0 + self.num_malfunctions = 0 @property def in_malfunction(self): @@ -33,6 +34,7 @@ class MalfunctionHandler: # Only set new malfunction value if old malfunction is completed if self._malfunction_down_counter == 0: self._malfunction_down_counter = val + self.num_malfunctions += 1 def generate_malfunction(self, malfunction_generator, np_random): num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random) @@ -44,16 +46,20 @@ class MalfunctionHandler: def __repr__(self): return f"malfunction_down_counter: {self._malfunction_down_counter} \ - in_malfunction: {self.in_malfunction}" + in_malfunction: {self.in_malfunction} \ + num_malfunctions: {self.num_malfunctions}" def to_dict(self): - return {"malfunction_down_counter": self._malfunction_down_counter} + return {"malfunction_down_counter": self._malfunction_down_counter, + "num_malfunctions": self.num_malfunctions} def from_dict(self, load_dict): self._malfunction_down_counter = load_dict['malfunction_down_counter'] + self.num_malfunctions = load_dict['num_malfunctions'] def __eq__(self, other): - return self._malfunction_down_counter == other._malfunction_down_counter + return self._malfunction_down_counter == other._malfunction_down_counter and \ + self.num_malfunctions == other.num_malfunctions diff --git a/flatland/envs/step_utils/state_machine.py b/flatland/envs/step_utils/state_machine.py index 78c9883f..58b028b6 100644 --- a/flatland/envs/step_utils/state_machine.py +++ b/flatland/envs/step_utils/state_machine.py @@ -31,12 +31,21 @@ class TrainStateMachine: def _handle_malfunction_off_map(self): if self.st_signals.malfunction_counter_complete: + if self.st_signals.earliest_departure_reached: - self.next_state = TrainState.READY_TO_DEPART + + if self.st_signals.valid_movement_action_given: + self.next_state = TrainState.MOVING + elif self.st_signals.stop_action_given: + self.next_state = TrainState.STOPPED + else: + self.next_state = TrainState.READY_TO_DEPART + else: - self.next_state = TrainState.STOPPED + self.next_state = TrainState.WAITING + else: - self.next_state = TrainState.WAITING + self.next_state = TrainState.MALFUNCTION_OFF_MAP def _handle_moving(self): if self.st_signals.in_malfunction: @@ -61,7 +70,7 @@ class TrainStateMachine: self.st_signals.valid_movement_action_given: self.next_state = TrainState.MOVING elif self.st_signals.malfunction_counter_complete and \ - (self.st_signals.stop_action_given or self.st_signals.movement_conflict): + (self.st_signals.stop_action_given or self.st_signals.movement_conflict): self.next_state = TrainState.STOPPED else: self.next_state = TrainState.MALFUNCTION -- GitLab