diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 2f19bddd8aeda0e8fe8bdc886fcf21db646a22c7..60860d65dddf84f42ee6225f67ff585213a7171f 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -2,25 +2,22 @@ Definition of the RailEnv environment. """ import random -# TODO: _ this is a global method --> utils or remove later -from typing import List, NamedTuple, Optional, Dict, Tuple -import numpy as np -from numpy.lib.shape_base import vsplit -from numpy.testing._private.utils import import_nose +from typing import List, Optional, Dict, Tuple +import numpy as np +from gym.utils import seeding +from dataclasses import dataclass from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions +from flatland.core.grid.grid4 import Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position -from flatland.core.grid.grid_utils import IntVector2D, position_to_coordinate from flatland.core.transition_map import GridTransitionMap -from flatland.envs.agent_utils import Agent, EnvAgent +from flatland.envs.agent_utils import EnvAgent from flatland.envs.distance_map import DistanceMap from flatland.envs.rail_env_action import RailEnvActions -# Need to use circular imports for persistence. from flatland.envs import malfunction_generators as mal_gen from flatland.envs import rail_generators as rail_gen from flatland.envs import line_generators as line_gen @@ -29,31 +26,11 @@ from flatland.envs import persistence from flatland.envs import agent_chains as ac from flatland.envs.observations import GlobalObsForRailEnv -from gym.utils import seeding - -# Direct import of objects / classes does not work with circular imports. -# from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData -# from flatland.envs.observations import GlobalObsForRailEnv -# from flatland.envs.rail_generators import random_rail_generator, RailGenerator -# from flatland.envs.line_generators import random_line_generator, LineGenerator - from flatland.envs.timetable_generators import timetable_generator -from flatland.envs.step_utils.states import TrainState -from flatland.envs.step_utils.transition_utils import check_action - -# Env Step Facelift imports -from flatland.envs.step_utils.action_preprocessing import preprocess_raw_action, preprocess_moving_action, preprocess_action_when_waiting - -# Adrian Egli performance fix (the fast methods brings more than 50%) -def fast_isclose(a, b, rtol): - return (a < (b + rtol)) or (a < (b - rtol)) - -def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool: - if pos_1 is None: # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None - return False - else: - return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1] +from flatland.envs.step_utils.states import TrainState, StateTransitionSignals +from flatland.envs.step_utils import transition_utils +from flatland.envs.step_utils import action_preprocessing class RailEnv(Environment): """ @@ -406,22 +383,35 @@ class RailEnv(Environment): def apply_action_independent(self, action, rail, position, direction): if action.is_moving_action(): - new_direction, _ = check_action(action, position, direction, rail) + new_direction, _ = transition_utils.check_action(action, position, direction, rail) new_position = get_new_position(position, new_direction) else: new_position, new_direction = position, direction return new_position, direction def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed): - st_signals = {} + st_signals = StateTransitionSignals() - st_signals['malfunction_onset'] = agent.malfunction_handler.in_malfunction - st_signals['malfunction_counter_complete'] = agent.malfunction_handler.malfunction_counter_complete - st_signals['earliest_departure_reached'] = self._elapsed_steps >= agent.earliest_departure - st_signals['stop_action_given'] = (preprocessed_action == RailEnvActions.STOP_MOVING) - st_signals['valid_movement_action_given'] = preprocessed_action.is_moving_action() - st_signals['target_reached'] = fast_position_equal(agent.position, agent.target) - st_signals['movement_conflict'] = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information + # Malfunction onset - Malfunction starts + st_signals.malfunction_onset = agent.malfunction_handler.in_malfunction + + # Malfunction counter complete - Malfunction ends next timestep + st_signals.malfunction_counter_complete = agent.malfunction_handler.malfunction_counter_complete + + # Earliest departure reached - Train is allowed to move now + st_signals.earliest_departure_reached = self._elapsed_steps >= agent.earliest_departure + + # Stop Action Given + st_signals.stop_action_given = (preprocessed_action == RailEnvActions.STOP_MOVING) + + # Valid Movement action Given + st_signals.valid_movement_action_given = preprocessed_action.is_moving_action() + + # Target Reached + st_signals.target_reached = fast_position_equal(agent.position, agent.target) + + # Movement conflict - Multiple trains trying to move into same cell + st_signals.movement_conflict = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information return st_signals @@ -489,7 +479,7 @@ class RailEnv(Environment): self.motionCheck = ac.MotionCheck() # reset the motion check - temp_saved_data = {} # TODO : Change name + temp_transition_data = {} for i_agent, agent in enumerate(self.agents): # TODO: Important - Do not use i_agent like this, use agent.handle if needed # Generate malfunction @@ -500,15 +490,15 @@ class RailEnv(Environment): # TODO: Add the bottom stuff to separate function(s) # Preprocess action - action = preprocess_raw_action(action, agent.state) - action = preprocess_action_when_waiting(action, agent.state) + action = action_preprocessing.preprocess_raw_action(action, agent.state) + action = action_preprocessing.preprocess_action_when_waiting(action, agent.state) # Try moving actions on current position current_position, current_direction = agent.position, agent.direction agent_not_on_map = current_position is None if agent_not_on_map: # Agent not added on map yet current_position, current_direction = agent.initial_position, agent.initial_direction - action = preprocess_moving_action(action, self.rail, current_position, current_direction) + action = action_preprocessing.preprocess_moving_action(action, self.rail, current_position, current_direction) # Save moving actions in not already saved agent.action_saver.save_action_if_allowed(action, agent.state) @@ -516,24 +506,25 @@ class RailEnv(Environment): # Calculate new position # Add agent to the map if not on it yet if agent_not_on_map and agent.action_saver.is_action_saved: - temp_new_position = agent.initial_position - temp_new_direction = agent.initial_direction + new_position = agent.initial_position + new_direction = agent.initial_direction preprocessed_action = action # When cell exit occurs apply saved action independent of other agents elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved: saved_action = agent.action_saver.saved_action # Apply action independent of other agents and get temporary new position and direction - temp_pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction) - temp_new_position, temp_new_direction = temp_pd + pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction) + new_position, new_direction = pd preprocessed_action = saved_action else: - temp_new_position, temp_new_direction = agent.position, agent.direction + new_position, new_direction = agent.position, agent.direction preprocessed_action = action - # TODO: Saving temporary positon shouldn't be needed if recheck of position is not needed later (see TAG#1) - temp_saved_data[i_agent] = temp_new_position, temp_new_direction, preprocessed_action - self.motionCheck.addAgent(i_agent, agent.position, temp_new_position) + temp_transition_data[i_agent] = AgentTransitionData(position=new_position, + direction=new_direction, + preprocessed_action=preprocessed_action) + self.motionCheck.addAgent(i_agent, agent.position, new_position) # Find conflicts # TODO : Important - Modify conflicted positions and select one of them randomly to go to new position @@ -541,23 +532,19 @@ class RailEnv(Environment): for agent in self.agents: i_agent = agent.handle + agent_transition_data = temp_transition_data[i_agent] ## Update positions - movement_allowed, _ = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck - # TODO : Important : Original code rechecks the next position here again - not sure why? TAG#1 - preprocessed_action = temp_saved_data[i_agent][2] # TODO : Important : Make this namedtuple or class - - # TODO : Looks like a hacky conditionm, improve the handling if agent.malfunction_handler.in_malfunction: movement_allowed = False + else: + movement_allowed, _ = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck if movement_allowed: - final_new_position, final_new_direction = temp_saved_data[i_agent][:2] # TODO : Important : Make this namedtuple or class - else: - final_new_position = agent.position - final_new_direction = agent.direction - agent.position = final_new_position - agent.direction = final_new_direction + agent.position = agent_transition_data.position + agent.direction = agent_transition_data.direction + + preprocessed_action = agent_transition_data.preprocessed_action ## Update states state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed) @@ -565,8 +552,8 @@ class RailEnv(Environment): agent.state_machine.step() agent.state = agent.state_machine.state # TODO : Make this a property instead? - # TODO : Important : Looks like a hacky condiition, improve the handling - if agent.state == TrainState.DONE: + # Remove agent is required + if self.remove_agents_at_target and agent.state == TrainState.DONE: agent.position = None ## Update rewards @@ -661,3 +648,21 @@ class RailEnv(Environment): def save(self, filename): print("deprecated call to env.save() - pls call RailEnvPersister.save()") persistence.RailEnvPersister.save(self, filename) + +@dataclass(repr=True) +class AgentTransitionData: + """ Class for keeping track of temporary agent data for position update """ + position : Tuple[int, int] + direction : Grid4Transitions + preprocessed_action : RailEnvActions + + +# Adrian Egli performance fix (the fast methods brings more than 50%) +def fast_isclose(a, b, rtol): + return (a < (b + rtol)) or (a < (b - rtol)) + +def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool: + if pos_1 is None: # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None + return False + else: + return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1] diff --git a/flatland/envs/rail_env_action.py b/flatland/envs/rail_env_action.py index f583eb72568f375e3a066cbf355771bb9b60e038..8665897f949294a9a1bf50fdc624de7907eca714 100644 --- a/flatland/envs/rail_env_action.py +++ b/flatland/envs/rail_env_action.py @@ -20,7 +20,7 @@ class RailEnvActions(IntEnum): }[a] @classmethod - def check_valid_action(cls, action): + def is_action_valid(cls, action): return action in cls._value2member_map_ def is_moving_action(self): diff --git a/flatland/envs/step_utils/action_preprocessing.py b/flatland/envs/step_utils/action_preprocessing.py index 98c42b15961b515e8c07eb89b521453a581e82f2..4da43c1695607785d3e779f1fe119064545fd575 100644 --- a/flatland/envs/step_utils/action_preprocessing.py +++ b/flatland/envs/step_utils/action_preprocessing.py @@ -5,7 +5,7 @@ from flatland.envs.step_utils.transition_utils import check_valid_action def process_illegal_action(action: RailEnvActions): - if not RailEnvActions.check_valid_action(action): + if not RailEnvActions.is_action_valid(action): return RailEnvActions.DO_NOTHING else: return RailEnvActions(action) diff --git a/flatland/envs/step_utils/state_machine.py b/flatland/envs/step_utils/state_machine.py index e42a829d2018c3c540ddd0f0e8c249530333abef..6d0b9f406e437c5215e00d983ae9484458ee4455 100644 --- a/flatland/envs/step_utils/state_machine.py +++ b/flatland/envs/step_utils/state_machine.py @@ -1,11 +1,11 @@ from attr import s -from flatland.envs.step_utils.states import TrainState +from flatland.envs.step_utils.states import TrainState, StateTransitionSignals class TrainStateMachine: def __init__(self, initial_state=TrainState.WAITING): self._initial_state = initial_state self._state = initial_state - self.st_signals = {} # State Transition Signals # TODO: Make this namedtuple + self.st_signals = StateTransitionSignals() self.next_state = None def _handle_waiting(self): @@ -13,25 +13,25 @@ class TrainStateMachine: # TODO: Important - The malfunction handling is not like proper state machine # Both transition signals can happen at the same time # Atleast mention it in the diagram - if self.st_signals['malfunction_onset']: + if self.st_signals.malfunction_onset: self.next_state = TrainState.MALFUNCTION_OFF_MAP - elif self.st_signals['earliest_departure_reached']: + elif self.st_signals.earliest_departure_reached: self.next_state = TrainState.READY_TO_DEPART else: self.next_state = TrainState.WAITING def _handle_ready_to_depart(self): """ Can only go to MOVING if a valid action is provided """ - if self.st_signals['malfunction_onset']: + if self.st_signals.malfunction_onset: self.next_state = TrainState.MALFUNCTION_OFF_MAP - elif self.st_signals['valid_movement_action_given']: + elif self.st_signals.valid_movement_action_given: self.next_state = TrainState.MOVING else: self.next_state = TrainState.READY_TO_DEPART def _handle_malfunction_off_map(self): - if self.st_signals['malfunction_counter_complete']: - if self.st_signals['earliest_departure_reached']: + if self.st_signals.malfunction_counter_complete: + if self.st_signals.earliest_departure_reached: self.next_state = TrainState.READY_TO_DEPART else: self.next_state = TrainState.STOPPED @@ -39,29 +39,29 @@ class TrainStateMachine: self.next_state = TrainState.WAITING def _handle_moving(self): - if self.st_signals['malfunction_onset']: + if self.st_signals.malfunction_onset: self.next_state = TrainState.MALFUNCTION - elif self.st_signals['target_reached']: + elif self.st_signals.target_reached: self.next_state = TrainState.DONE - elif self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']: + elif self.st_signals.stop_action_given or self.st_signals.movement_conflict: self.next_state = TrainState.STOPPED else: self.next_state = TrainState.MOVING def _handle_stopped(self): - if self.st_signals['malfunction_onset']: + if self.st_signals.malfunction_onset: self.next_state = TrainState.MALFUNCTION - elif self.st_signals['valid_movement_action_given']: + elif self.st_signals.valid_movement_action_given: self.next_state = TrainState.MOVING else: self.next_state = TrainState.STOPPED def _handle_malfunction(self): - if self.st_signals['malfunction_counter_complete'] and \ - self.st_signals['valid_movement_action_given']: + if self.st_signals.malfunction_counter_complete and \ + self.st_signals.valid_movement_action_given: self.next_state = TrainState.MOVING - elif self.st_signals['malfunction_counter_complete'] and \ - (self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']): + elif self.st_signals.malfunction_counter_complete and \ + (self.st_signals.stop_action_given or self.st_signals.movement_conflict): self.next_state = TrainState.STOPPED else: self.next_state = TrainState.MALFUNCTION @@ -134,7 +134,7 @@ class TrainStateMachine: return self.st_signals def set_transition_signals(self, state_transition_signals): - self.st_signals = state_transition_signals # TODO: Important: Check all keys are present and if not raise error + self.st_signals = state_transition_signals diff --git a/flatland/envs/step_utils/states.py b/flatland/envs/step_utils/states.py index 4e612ae842fd9da2e206f010a72bca2e0b9f5608..d7cfd2b24a41ccfffc1178b0d0625b5e1d617752 100644 --- a/flatland/envs/step_utils/states.py +++ b/flatland/envs/step_utils/states.py @@ -1,5 +1,5 @@ from enum import IntEnum - +from dataclasses import dataclass class TrainState(IntEnum): WAITING = 0 READY_TO_DEPART = 1 @@ -22,6 +22,13 @@ class TrainState(IntEnum): def is_on_map_state(self): return self.value in [self.MOVING, self.STOPPED, self.MALFUNCTION] - - +@dataclass(repr=True) +class StateTransitionSignals: + malfunction_onset : bool = False + malfunction_counter_complete : bool = False + earliest_departure_reached : bool = False + stop_action_given : bool = False + valid_movement_action_given : bool = False + target_reached : bool = False + movement_conflict : bool = False