diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 60860d65dddf84f42ee6225f67ff585213a7171f..4181482bee3fac8c078500f46ed98df83c328fc4 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -443,6 +443,25 @@ class RailEnv(Environment): return reward + def preprocess_action(self, action, agent): + """ + Preprocess the provided action + * Change to DO_NOTHING if illegal action + * Block all actions when in waiting state + * Check MOVE_LEFT/MOVE_RIGHT actions on current position else try MOVE_FORWARD + """ + action = action_preprocessing.preprocess_raw_action(action, agent.state) + action = action_preprocessing.preprocess_action_when_waiting(action, agent.state) + + # Try moving actions on current position + current_position, current_direction = agent.position, agent.direction + if current_position is None: # Agent not added on map yet + current_position, current_direction = agent.initial_position, agent.initial_direction + + action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction) + + return action + def step(self, action_dict_: Dict[int, RailEnvActions]): """ Updates rewards for the agents at a step. @@ -475,59 +494,51 @@ class RailEnv(Environment): "speed": {}, "status": {}, } - have_all_agents_ended = True # boolean flag to check if all agents are done self.motionCheck = ac.MotionCheck() # reset the motion check temp_transition_data = {} - for i_agent, agent in enumerate(self.agents): # TODO: Important - Do not use i_agent like this, use agent.handle if needed + for agent in self.agents: + i_agent = agent.handle # Generate malfunction agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random) # Get action for the agent action = action_dict_.get(i_agent, RailEnvActions.DO_NOTHING) - # TODO: Add the bottom stuff to separate function(s) - - # Preprocess action - action = action_preprocessing.preprocess_raw_action(action, agent.state) - action = action_preprocessing.preprocess_action_when_waiting(action, agent.state) - # Try moving actions on current position - current_position, current_direction = agent.position, agent.direction - agent_not_on_map = current_position is None - if agent_not_on_map: # Agent not added on map yet - current_position, current_direction = agent.initial_position, agent.initial_direction - action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction) + preprocessed_action = self.preprocess_action(action, agent) # Save moving actions in not already saved - agent.action_saver.save_action_if_allowed(action, agent.state) + agent.action_saver.save_action_if_allowed(preprocessed_action, agent.state) # Calculate new position # Add agent to the map if not on it yet - if agent_not_on_map and agent.action_saver.is_action_saved: + if agent.position is None and agent.action_saver.is_action_saved: new_position = agent.initial_position new_direction = agent.initial_direction - preprocessed_action = action # When cell exit occurs apply saved action independent of other agents elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved: saved_action = agent.action_saver.saved_action # Apply action independent of other agents and get temporary new position and direction - pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction) - new_position, new_direction = pd + new_position, new_direction = self.apply_action_independent(saved_action, + self.rail, + agent.position, + agent.direction) preprocessed_action = saved_action else: new_position, new_direction = agent.position, agent.direction - preprocessed_action = action temp_transition_data[i_agent] = AgentTransitionData(position=new_position, direction=new_direction, preprocessed_action=preprocessed_action) + + # This is for checking conflicts of agents trying to occupy same cell self.motionCheck.addAgent(i_agent, agent.position, new_position) # Find conflicts - # TODO : Important - Modify conflicted positions and select one of them randomly to go to new position + self.motionCheck.find_conflicts() for agent in self.agents: diff --git a/flatland/envs/step_utils/action_preprocessing.py b/flatland/envs/step_utils/action_preprocessing.py index 4da43c1695607785d3e779f1fe119064545fd575..a397054c6c55c854b0846d0d70c7c7209fcfa7af 100644 --- a/flatland/envs/step_utils/action_preprocessing.py +++ b/flatland/envs/step_utils/action_preprocessing.py @@ -55,6 +55,6 @@ def preprocess_moving_action(action, rail, position, direction): if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]: action = process_left_right(action, rail, position, direction) - if not check_valid_action(action, rail, position, direction): # TODO: Dipam - Not sure if this is needed + if not check_valid_action(action, rail, position, direction): action = RailEnvActions.STOP_MOVING return action \ No newline at end of file diff --git a/flatland/envs/step_utils/state_machine.py b/flatland/envs/step_utils/state_machine.py index 6d0b9f406e437c5215e00d983ae9484458ee4455..47b553a8b07e61fbcc30531b62e6c788b8cfc5b5 100644 --- a/flatland/envs/step_utils/state_machine.py +++ b/flatland/envs/step_utils/state_machine.py @@ -1,4 +1,3 @@ -from attr import s from flatland.envs.step_utils.states import TrainState, StateTransitionSignals class TrainStateMachine: diff --git a/flatland/envs/step_utils/transition_utils.py b/flatland/envs/step_utils/transition_utils.py index 157db5aca96b25560dc8c1bd4fe7fc2df2e30c37..c84d6c59cd59b6f8366d28f3d0ad51bbcfc7602a 100644 --- a/flatland/envs/step_utils/transition_utils.py +++ b/flatland/envs/step_utils/transition_utils.py @@ -71,9 +71,7 @@ def check_action_on_agent(action, rail, position, direction): # If transition validity hasn't been checked yet. if transition_valid is None: - transition_valid = rail.get_transition( # TODO: Dipam - Read this one - (*position, direction), - new_direction) + transition_valid = rail.get_transition( (*position, direction), new_direction) return new_cell_valid, new_direction, new_position, transition_valid