diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index fffe7ff786a32a6796af9667f1dfb9a3eb92ce9c..632caeea7e416895d36ce845e19917c2cc94d76d 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -2,21 +2,19 @@ from flatland.envs.rail_trainrun_data_structures import Waypoint import numpy as np from enum import IntEnum +from flatland.envs.step_utils.states import TrainState from itertools import starmap from typing import Tuple, Optional, NamedTuple, List from attr import attr, attrs, attrib, Factory from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.envs.timetable_utils import Line - -class RailAgentStatus(IntEnum): - WAITING = 0 - READY_TO_DEPART = 1 # not in grid yet (position is None) -> prediction as if it were at initial position - ACTIVE = 2 # in grid (position is not None), not done -> prediction is remaining path - DONE = 3 # in grid (position is not None), but done -> prediction is stay at target forever - DONE_REMOVED = 4 # removed from grid (position is None) -> prediction is None +from flatland.envs.schedule_utils import Schedule +from flatland.envs.step_utils.action_saver import ActionSaver +from flatland.envs.step_utils.speed_counter import SpeedCounter +from flatland.envs.step_utils.state_machine import TrainStateMachine +from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]), ('initial_direction', Grid4TransitionsEnum), @@ -28,11 +26,16 @@ Agent = NamedTuple('Agent', [('initial_position', Tuple[int, int]), ('speed_data', dict), ('malfunction_data', dict), ('handle', int), - ('status', RailAgentStatus), ('position', Tuple[int, int]), ('arrival_time', int), ('old_direction', Grid4TransitionsEnum), - ('old_position', Tuple[int, int])]) + ('old_position', Tuple[int, int]), + ('speed_counter', SpeedCounter), + ('action_saver', ActionSaver), + ('state', TrainState), + ('state_machine', TrainStateMachine), + ('malfunction_handler', MalfunctionHandler), + ]) @attrs @@ -65,7 +68,15 @@ class EnvAgent: handle = attrib(default=None) # INIT TILL HERE IN _from_line() - status = attrib(default=RailAgentStatus.WAITING, type=RailAgentStatus) + # Env step facelift + speed_counter = attrib(default = None, type=SpeedCounter) + action_saver = attrib(default = Factory(lambda: ActionSaver()), type=ActionSaver) + state_machine = attrib(default= Factory(lambda: TrainStateMachine(initial_state=TrainState.WAITING)) , + type=TrainStateMachine) + malfunction_handler = attrib(default = Factory(lambda: MalfunctionHandler()), type=MalfunctionHandler) + + state = attrib(default=TrainState.WAITING, type=TrainState) + position = attrib(default=None, type=Optional[Tuple[int, int]]) # NEW : EnvAgent Reward Handling @@ -75,6 +86,7 @@ class EnvAgent: old_direction = attrib(default=None) old_position = attrib(default=None) + def reset(self): """ Resets the agents to their initial values of the episode. Called after ScheduleTime generation. 
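For orientation: the hunk above swaps EnvAgent's single RailAgentStatus field for four per-agent helpers (speed_counter, action_saver, state_machine, malfunction_handler), wired in through the attrib Factory defaults shown. Below is a minimal sketch of how those helpers are constructed and queried in isolation; the speed value 0.5 is an arbitrary assumption and the snippet is illustrative only, not part of the patch.

    from flatland.envs.step_utils.action_saver import ActionSaver
    from flatland.envs.step_utils.malfunction_handler import MalfunctionHandler
    from flatland.envs.step_utils.speed_counter import SpeedCounter
    from flatland.envs.step_utils.state_machine import TrainStateMachine
    from flatland.envs.step_utils.states import TrainState

    # Speed counter: counts env steps spent inside a cell; a speed-0.5 train needs 2 steps per cell.
    speed_counter = SpeedCounter(speed=0.5)        # speed 0.5 is an assumed example value
    speed_counter.reset_counter()                  # counter is initialised by reset(), as EnvAgent.reset() does
    assert speed_counter.is_cell_entry             # counter == 0
    speed_counter.update_counter(TrainState.MOVING)
    assert speed_counter.is_cell_exit              # counter == max_count - 1 == 1

    # Action saver: remembers the first moving action given inside a cell.
    action_saver = ActionSaver()
    assert not action_saver.is_action_saved

    # Malfunction handler: a simple down-counter ticked once per env step.
    malfunction_handler = MalfunctionHandler()
    assert not malfunction_handler.in_malfunction

    # State machine: starts in WAITING and is stepped once per env step (see rail_env.step below).
    state_machine = TrainStateMachine(initial_state=TrainState.WAITING)
    assert state_machine.state == TrainState.WAITING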
@@ -82,14 +94,6 @@ class EnvAgent: self.position = None # TODO: set direction to None: https://gitlab.aicrowd.com/flatland/flatland/issues/280 self.direction = self.initial_direction - - if (self.earliest_departure == 0): - self.status = RailAgentStatus.READY_TO_DEPART - else: - self.status = RailAgentStatus.WAITING - - self.arrival_time = None - self.old_position = None self.old_direction = None self.moving = False @@ -103,48 +107,42 @@ class EnvAgent: self.malfunction_data['nr_malfunctions'] = 0 self.malfunction_data['moving_before_malfunction'] = False - # NEW : Callables - def get_shortest_path(self, distance_map) -> List[Waypoint]: - from flatland.envs.rail_env_shortest_paths import get_shortest_paths # Circular dep fix - return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle] - - def get_travel_time_on_shortest_path(self, distance_map) -> int: - shortest_path = self.get_shortest_path(distance_map) - if shortest_path is not None: - distance = len(shortest_path) - else: - distance = 0 - speed = self.speed_data['speed'] - return int(np.ceil(distance / speed)) - - def get_time_remaining_until_latest_arrival(self, elapsed_steps: int) -> int: - return self.latest_arrival - elapsed_steps - - def get_current_delay(self, elapsed_steps: int, distance_map) -> int: - ''' - +ve if arrival time is projected before latest arrival - -ve if arrival time is projected after latest arrival - ''' - return self.get_time_remaining_until_latest_arrival(elapsed_steps) - \ - self.get_travel_time_on_shortest_path(distance_map) + self.action_saver.clear_saved_action() + self.speed_counter.reset_counter() + self.state_machine.reset() def to_agent(self) -> Agent: - return Agent(initial_position=self.initial_position, initial_direction=self.initial_direction, - direction=self.direction, target=self.target, moving=self.moving, earliest_departure=self.earliest_departure, - latest_arrival=self.latest_arrival, speed_data=self.speed_data, malfunction_data=self.malfunction_data, - handle=self.handle, status=self.status, position=self.position, arrival_time=self.arrival_time, - old_direction=self.old_direction, old_position=self.old_position) + return Agent(initial_position=self.initial_position, + initial_direction=self.initial_direction, + direction=self.direction, + target=self.target, + moving=self.moving, + earliest_departure=self.earliest_departure, + latest_arrival=self.latest_arrival, + speed_data=self.speed_data, + malfunction_data=self.malfunction_data, + handle=self.handle, + state=self.state, + position=self.position, + old_direction=self.old_direction, + old_position=self.old_position, + speed_counter=self.speed_counter, + action_saver=self.action_saver, + state_machine=self.state_machine, + malfunction_handler=self.malfunction_handler) @classmethod def from_line(cls, line: Line): """ Create a list of EnvAgent from lists of positions, directions and targets """ speed_datas = [] - - for i in range(len(line.agent_positions)): + speed_counters = [] + for i in range(len(schedule.agent_positions)): + speed = schedule.agent_speeds[i] if schedule.agent_speeds is not None else 1.0 speed_datas.append({'position_fraction': 0.0, - 'speed': line.agent_speeds[i] if line.agent_speeds is not None else 1.0, + 'speed': speed, 'transition_action_on_cellexit': 0}) + speed_counters.append( SpeedCounter(speed=speed) ) malfunction_datas = [] for i in range(len(line.agent_positions)): @@ -153,17 +151,19 @@ class EnvAgent: i] if line.agent_malfunction_rates is not None else 0., 
'next_malfunction': 0, 'nr_malfunctions': 0}) - - return list(starmap(EnvAgent, zip(line.agent_positions, - line.agent_directions, - line.agent_directions, - line.agent_targets, - [False] * len(line.agent_positions), - [None] * len(line.agent_positions), # earliest_departure - [None] * len(line.agent_positions), # latest_arrival + + return list(starmap(EnvAgent, zip(schedule.agent_positions, # TODO : Dipam - Really want to change this way of loading agents + schedule.agent_directions, + schedule.agent_directions, + schedule.agent_targets, + [False] * len(schedule.agent_positions), + [None] * len(schedule.agent_positions), # earliest_departure + [None] * len(schedule.agent_positions), # latest_arrival speed_datas, malfunction_datas, - range(len(line.agent_positions))))) + range(len(schedule.agent_positions)), + speed_counters, + ))) @classmethod def load_legacy_static_agent(cls, static_agents_data: Tuple): diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index 0d27913d6f27fb5df301960655d90baa42ef1ac0..2fecddf1a0abb954637992c371c1fc1053417a78 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -18,7 +18,7 @@ MalfunctionProcessData = NamedTuple('MalfunctionProcessData', Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)]) # Why is the return value Optional? We always return a Malfunction. -MalfunctionGenerator = Callable[[EnvAgent, RandomState, bool], Optional[Malfunction]] +MalfunctionGenerator = Callable[[RandomState, bool], Malfunction] def _malfunction_prob(rate: float) -> float: """ @@ -42,21 +42,14 @@ class ParamMalfunctionGen(object): #self.max_number_of_steps_broken = parameters.max_duration self.MFP = parameters - def generate(self, - agent: EnvAgent = None, - np_random: RandomState = None, - reset=False) -> Optional[Malfunction]: - - # Dummy reset function as we don't implement specific seeding here - if reset: - return Malfunction(0) + def generate(self, np_random: RandomState) -> Malfunction: - if agent.malfunction_data['malfunction'] < 1: - if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate): - num_broken_steps = np_random.randint(self.MFP.min_duration, - self.MFP.max_duration + 1) + 1 - return Malfunction(num_broken_steps) - return Malfunction(0) + if np_random.rand() < _malfunction_prob(self.MFP.malfunction_rate): + num_broken_steps = np_random.randint(self.MFP.min_duration, + self.MFP.max_duration + 1) + 1 + else: + num_broken_steps = 0 + return Malfunction(num_broken_steps) def get_process_data(self): return MalfunctionProcessData(*self.MFP) @@ -103,7 +96,7 @@ def no_malfunction_generator() -> Tuple[MalfunctionGenerator, MalfunctionProcess min_number_of_steps_broken = 0 max_number_of_steps_broken = 0 - def generator(agent: EnvAgent = None, np_random: RandomState = None, reset=False) -> Optional[Malfunction]: + def generator(np_random: RandomState = None) -> Malfunction: return Malfunction(0) return generator, MalfunctionProcessData(mean_malfunction_rate, min_number_of_steps_broken, @@ -162,7 +155,7 @@ def single_malfunction_generator(earlierst_malfunction: int, malfunction_duratio malfunction_calls[agent.handle] = 1 # Break an agent that is active at the time of the malfunction - if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction: + if agent.status == RailAgentStatus.ACTIVE and malfunction_calls[agent.handle] >= earlierst_malfunction: #TODO : Dipam : Is this needed? 
global_nr_malfunctions += 1 return Malfunction(malfunction_duration) else: diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 3d4da7b93da1705678e310b6c81c9076542db089..1dc332d9480298020ff2d63c9677f5dd0631bf6b 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -7,13 +7,15 @@ from enum import IntEnum from typing import List, NamedTuple, Optional, Dict, Tuple import numpy as np +from numpy.lib.shape_base import vsplit +from numpy.testing._private.utils import import_nose from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions from flatland.core.grid.grid4_utils import get_new_position -from flatland.core.grid.grid_utils import IntVector2D +from flatland.core.grid.grid_utils import IntVector2D, position_to_coordinate from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import Agent, EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap @@ -37,37 +39,23 @@ from gym.utils import seeding # from flatland.envs.line_generators import random_line_generator, LineGenerator +# NEW : Imports +from flatland.envs.schedule_time_generators import schedule_time_generator +from flatland.envs.step_utils.states import TrainState +from flatland.envs.step_utils.transition_utils import check_action + +# Env Step Facelift imports +from flatland.envs.step_utils.action_preprocessing import preprocess_raw_action, preprocess_moving_action, preprocess_action_when_waiting # Adrian Egli performance fix (the fast methods brings more than 50%) def fast_isclose(a, b, rtol): return (a < (b + rtol)) or (a < (b - rtol)) - -def fast_clip(position: (int, int), min_value: (int, int), max_value: (int, int)) -> bool: - return ( - max(min_value[0], min(position[0], max_value[0])), - max(min_value[1], min(position[1], max_value[1])) - ) - - -def fast_argmax(possible_transitions: (int, int, int, int)) -> bool: - if possible_transitions[0] == 1: - return 0 - if possible_transitions[1] == 1: - return 1 - if possible_transitions[2] == 1: - return 2 - return 3 - - def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool: - return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1] - - -def fast_count_nonzero(possible_transitions: (int, int, int, int)): - return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3] - - + if pos_1 is None: # TODO: Dipam - Consider making default of agent.position as (-1, -1) instead of None + return False + else: + return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1] class RailEnv(Environment): """ @@ -255,6 +243,8 @@ class RailEnv(Environment): self.close_following = close_following # use close following logic self.motionCheck = ac.MotionCheck() + self.agent_helpers = {} + def _seed(self, seed=None): self.np_random, seed = seeding.np_random(seed) random.seed(seed) @@ -379,15 +369,18 @@ class RailEnv(Environment): # Reset agents to initial states self.reset_agents() - for agent in self.agents: - # Induce malfunctions - self._break_agent(agent) + # for agent in self.agents: + # # Induce malfunctions + # if activate_agents: + # self.set_agent_active(agent) - if agent.malfunction_data["malfunction"] > 0: - agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING + # self._break_agent(agent) - # Fix agents that finished their malfunction - self._fix_agent_after_malfunction(agent) + # if 
agent.malfunction_data["malfunction"] > 0: + # agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING + + # # Fix agents that finished their malfunction + # self._fix_agent_after_malfunction(agent) self.num_resets += 1 self._elapsed_steps = 0 @@ -398,12 +391,6 @@ class RailEnv(Environment): # Reset the state of the observation builder with the new environment self.obs_builder.reset() - # Reset the malfunction generator - if "generate" in dir(self.malfunction_generator): - self.malfunction_generator.generate(reset=True) - else: - self.malfunction_generator(reset=True) - # Empty the episode store of agent positions self.cur_episode = [] @@ -418,52 +405,25 @@ class RailEnv(Environment): # Return the new observation vectors for each agent observation_dict: Dict = self._get_observations() return observation_dict, info_dict - - def _fix_agent_after_malfunction(self, agent: EnvAgent): - """ - Updates agent malfunction variables and fixes broken agents - - Parameters - ---------- - agent - """ - - # Ignore agents that are OK - if self._is_agent_ok(agent): - return - - # Reduce number of malfunction steps left - if agent.malfunction_data['malfunction'] > 1: - agent.malfunction_data['malfunction'] -= 1 - return - - # Restart agents at the end of their malfunction - agent.malfunction_data['malfunction'] -= 1 - if 'moving_before_malfunction' in agent.malfunction_data: - agent.moving = agent.malfunction_data['moving_before_malfunction'] - return - - def _break_agent(self, agent: EnvAgent): - """ - Malfunction generator that breaks agents at a given rate. - - Parameters - ---------- - agent - - """ - - if "generate" in dir(self.malfunction_generator): - malfunction: mal_gen.Malfunction = self.malfunction_generator.generate(agent, self.np_random) + + def apply_action_independent(self, action, rail, position, direction): + if RailEnvActions.is_moving_action(action): + new_direction, _ = check_action(action, position, direction, rail) + new_position = get_new_position(position, new_direction) else: - malfunction: mal_gen.Malfunction = self.malfunction_generator(agent, self.np_random) - - if malfunction.num_broken_steps > 0: - agent.malfunction_data['malfunction'] = malfunction.num_broken_steps - agent.malfunction_data['moving_before_malfunction'] = agent.moving - agent.malfunction_data['nr_malfunctions'] += 1 - - return + new_position, new_direction = position, direction + return new_position, direction + + def generate_state_transition_signals(self, agent, preprocessed_action, movement_allowed): + st_signals = {} + + st_signals['malfunction_onset'] = agent.malfunction_handler.in_malfunction + st_signals['malfunction_counter_complete'] = agent.malfunction_handler.malfunction_counter_complete + st_signals['earliest_departure_reached'] = self._elapsed_steps >= agent.earliest_departure + st_signals['stop_action_given'] = (preprocessed_action == RailEnvActions.STOP_MOVING) + st_signals['valid_movement_action_given'] = RailEnvActions.is_moving_action(preprocessed_action) + st_signals['target_reached'] = fast_position_equal(agent.position, agent.target) + st_signals['movement_conflict'] = (not movement_allowed) and agent.speed_counter.is_cell_exit # TODO: Modify motion check to provide proper conflict information def _handle_end_reward(self, agent: EnvAgent) -> int: ''' @@ -497,16 +457,26 @@ class RailEnv(Environment): """ Updates rewards for the agents at a step. 
- Parameters - ---------- - action_dict_ : Dict[int,RailEnvActions] - - """ + def step(self, action_dict): self._elapsed_steps += 1 # If we're done, set reward and info_dict and step() is done. - if self.dones["__all__"]: - raise Exception("Episode is done, cannot call step()") + if self.dones["__all__"]: # TODO: Move boilerplate to different function + self.rewards_dict = {} + info_dict = { + "action_required": {}, + "malfunction": {}, + "speed": {}, + "status": {}, + } + for i_agent, agent in enumerate(self.agents): + self.rewards_dict[i_agent] = self.global_reward + info_dict["action_required"][i_agent] = False + info_dict["malfunction"][i_agent] = 0 + info_dict["speed"][i_agent] = 0 + info_dict["status"][i_agent] = agent.status + + return self._get_observations(), self.rewards_dict, self.dones, info_dict # Reset the step rewards self.rewards_dict = dict() @@ -520,407 +490,96 @@ class RailEnv(Environment): self.motionCheck = ac.MotionCheck() # reset the motion check - if not self.close_following: - for i_agent, agent in enumerate(self.agents): - # Reset the step rewards - self.rewards_dict[i_agent] = 0 - - # Induce malfunction before we do a step, thus a broken agent can't move in this step - self._break_agent(agent) - - # Perform step on the agent - self._step_agent(i_agent, action_dict_.get(i_agent)) - - # manage the boolean flag to check if all agents are indeed done (or done_removed) - have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]) - - # Build info dict - info_dict["action_required"][i_agent] = self.action_required(agent) - info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction'] - info_dict["speed"][i_agent] = agent.speed_data['speed'] - info_dict["status"][i_agent] = agent.status - - # Fix agents that finished their malfunction such that they can perform an action in the next step - self._fix_agent_after_malfunction(agent) - - - else: - for i_agent, agent in enumerate(self.agents): - # Reset the step rewards - self.rewards_dict[i_agent] = 0 - - # Induce malfunction before we do a step, thus a broken agent can't move in this step - self._break_agent(agent) - - # Perform step on the agent - self._step_agent_cf(i_agent, action_dict_.get(i_agent)) - - # second loop: check for collisions / conflicts - self.motionCheck.find_conflicts() - - # third loop: update positions - for i_agent, agent in enumerate(self.agents): - self._step_agent2_cf(i_agent) - - # manage the boolean flag to check if all agents are indeed done (or done_removed) - have_all_agents_ended &= (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]) - - # Build info dict - info_dict["action_required"][i_agent] = self.action_required(agent) - info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction'] - info_dict["speed"][i_agent] = agent.speed_data['speed'] - info_dict["status"][i_agent] = agent.status - - # Fix agents that finished their malfunction such that they can perform an action in the next step - self._fix_agent_after_malfunction(agent) - + temp_saved_data = {} # TODO : Change name - # NEW : REW: (END) - if ((self._max_episode_steps is not None) and (self._elapsed_steps >= self._max_episode_steps)) \ - or have_all_agents_ended : - - for i_agent, agent in enumerate(self.agents): - - reward = self._handle_end_reward(agent) - self.rewards_dict[i_agent] += reward + for i_agent, agent in enumerate(self.agents): # TODO: Important - Do not use i_agent like this, use agent.handle if needed + # Generate malfunction + 
agent.malfunction_handler.generate_malfunction(self.malfunction_generator, self.np_random) + + # Get action for the agent + action = action_dict.get(i_agent, RailEnvActions.DO_NOTHING) + # TODO: Add the bottom stuff to separate function(s) + + # Preprocess action + action = preprocess_raw_action(action, agent.state) + action = preprocess_action_when_waiting(action, agent.state) + + # Try moving actions on current position + current_position, current_direction = agent.position, agent.direction + agent_not_on_map = current_position is None + if agent_not_on_map: # Agent not added on map yet + current_position, current_direction = agent.initial_position, agent.initial_direction + action = preprocess_moving_action(action, agent.state, self.rail, current_position, current_direction) + + # Save moving actions in not already saved + agent.action_saver.save_action_if_allowed(action, agent.state) + + # Calculate new position + # Add agent to the map if not on it yet + if agent_not_on_map and agent.action_saver.is_action_saved: + temp_new_position = agent.initial_position + temp_new_direction = agent.initial_direction - self.dones[i_agent] = True + # When cell exit occurs apply saved action independent of other agents + elif agent.speed_counter.is_cell_exit and agent.action_saver.is_action_saved: + saved_action = agent.action_saver.saved_action + # Apply action independent of other agents and get temporary new position and direction + temp_pd = self.apply_action_independent(saved_action, self.rail, agent.position, agent.direction) + temp_new_position, temp_new_direction = temp_pd + else: + temp_new_position, temp_new_direction = agent.position, agent.direction + + # TODO: Saving temporary positon shouldn't be needed if recheck of position is not needed later (see TAG#1) + temp_saved_data[i_agent] = temp_new_position, temp_new_direction, action + self.motionCheck.addAgent(i_agent, agent.position, temp_new_position) - self.dones["__all__"] = True + # Find conflicts + # TODO : Important - Modify conflicted positions and select one of them randomly to go to new position + self.motionCheck.find_conflicts() + for agent in self.agents: + i_agent = agent.handle - if self.record_steps: - self.record_timestep(action_dict_) + ## Update positions + movement_allowed, _ = self.motionCheck.check_motion(i_agent, agent.position) # TODO: Remove final_new_postion from motioncheck + # TODO : Important : Original code rechecks the next position here again - not sure why? TAG#1 + preprocessed_action = temp_saved_data[i_agent][2] # TODO : Important : Make this namedtuple or class - return self._get_observations(), self.rewards_dict, self.dones, info_dict + # TODO : Looks like a hacky conditionm, improve the handling + if agent.malfunction_handler.in_malfunction: + movement_allowed = False - def _step_agent(self, i_agent, action: Optional[RailEnvActions] = None): - """ - Performs a step and step, start and stop penalty on a single agent in the following sub steps: - - malfunction - - action handling if at the beginning of cell - - movement - - Parameters - ---------- - i_agent : int - action_dict_ : Dict[int,RailEnvActions] - - """ - agent = self.agents[i_agent] - if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]: # this agent has already completed... 
- return - - # agent gets active by a MOVE_* action and if c - if agent.status == RailAgentStatus.READY_TO_DEPART: - initial_cell_free = self.cell_free(agent.initial_position) - is_action_starting = action in [ - RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD] - - if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, - RailEnvActions.MOVE_FORWARD] and self.cell_free(agent.initial_position): - agent.status = RailAgentStatus.ACTIVE - self._set_agent_to_initial_position(agent, agent.initial_position) - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - return + if movement_allowed: + final_new_position, final_new_direction = temp_saved_data[i_agent][:2] # TODO : Important : Make this namedtuple or class else: - # TODO: Here we need to check for the departure time in future releases with full schedules - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - return - - agent.old_direction = agent.direction - agent.old_position = agent.position - - # if agent is broken, actions are ignored and agent does not move. - # full step penalty in this case - if agent.malfunction_data['malfunction'] > 0: - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - return - - # Is the agent at the beginning of the cell? Then, it can take an action. - # As long as the agent is malfunctioning or stopped at the beginning of the cell, - # different actions may be taken! - if fast_isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03): - # No action has been supplied for this agent -> set DO_NOTHING as default - if action is None: - action = RailEnvActions.DO_NOTHING - - if action < 0 or action > len(RailEnvActions): - print('ERROR: illegal action=', action, - 'for agent with index=', i_agent, - '"DO NOTHING" will be executed instead') - action = RailEnvActions.DO_NOTHING - - if action == RailEnvActions.DO_NOTHING and agent.moving: - # Keep moving - action = RailEnvActions.MOVE_FORWARD - - if action == RailEnvActions.STOP_MOVING and agent.moving: - # Only allow halting an agent on entering new cells. - agent.moving = False - self.rewards_dict[i_agent] += self.stop_penalty - - if not agent.moving and not ( - action == RailEnvActions.DO_NOTHING or - action == RailEnvActions.STOP_MOVING): - # Allow agent to start with any forward or direction action - agent.moving = True - self.rewards_dict[i_agent] += self.start_penalty - - # Store the action if action is moving - # If not moving, the action will be stored when the agent starts moving again. - if agent.moving: - _action_stored = False - _, new_cell_valid, new_direction, new_position, transition_valid = \ - self._check_action_on_agent(action, agent) - - if all([new_cell_valid, transition_valid]): - agent.speed_data['transition_action_on_cellexit'] = action - _action_stored = True - else: - # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, - # try to keep moving forward! 
- if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT): - _, new_cell_valid, new_direction, new_position, transition_valid = \ - self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) - - if all([new_cell_valid, transition_valid]): - agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD - _action_stored = True - - if not _action_stored: - # If the agent cannot move due to an invalid transition, we set its state to not moving - self.rewards_dict[i_agent] += self.invalid_action_penalty - self.rewards_dict[i_agent] += self.stop_penalty - agent.moving = False - - # Now perform a movement. - # If agent.moving, increment the position_fraction by the speed of the agent - # If the new position fraction is >= 1, reset to 0, and perform the stored - # transition_action_on_cellexit if the cell is free. - if agent.moving: - agent.speed_data['position_fraction'] += agent.speed_data['speed'] - if agent.speed_data['position_fraction'] > 1.0 or fast_isclose(agent.speed_data['position_fraction'], 1.0, - rtol=1e-03): - # Perform stored action to transition to the next cell as soon as cell is free - # Notice that we've already checked new_cell_valid and transition valid when we stored the action, - # so we only have to check cell_free now! - - # Traditional check that next cell is free - # cell and transition validity was checked when we stored transition_action_on_cellexit! - cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( - agent.speed_data['transition_action_on_cellexit'], agent) - - # N.B. validity of new_cell and transition should have been verified before the action was stored! - assert new_cell_valid - assert transition_valid - if cell_free: - self._move_agent_to_new_position(agent, new_position) - agent.direction = new_direction - agent.speed_data['position_fraction'] = 0.0 - - # has the agent reached its target? - if np.equal(agent.position, agent.target).all(): - agent.status = RailAgentStatus.DONE - self.dones[i_agent] = True - self.active_agents.remove(i_agent) - agent.moving = False - self._remove_agent_from_scene(agent) - else: - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - else: - # step penalty if not moving (stopped now or before) - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - - def _step_agent_cf(self, i_agent, action: Optional[RailEnvActions] = None): - """ "close following" version of step_agent. - """ - agent = self.agents[i_agent] - if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]: # this agent has already completed... 
- return - - # NEW : STEP: WAITING > WAITING or WAITING > READY_TO_DEPART - if (agent.status == RailAgentStatus.WAITING): - if ( self._elapsed_steps >= agent.earliest_departure ): - agent.status = RailAgentStatus.READY_TO_DEPART - self.motionCheck.addAgent(i_agent, None, None) - return - - # agent gets active by a MOVE_* action and if c - if agent.status == RailAgentStatus.READY_TO_DEPART: - is_action_starting = action in [ - RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT, RailEnvActions.MOVE_FORWARD] - - if is_action_starting: # agent is trying to start - self.motionCheck.addAgent(i_agent, None, agent.initial_position) - else: # agent wants to remain unstarted - self.motionCheck.addAgent(i_agent, None, None) - return - - agent.old_direction = agent.direction - agent.old_position = agent.position - - # if agent is broken, actions are ignored and agent does not move. - # full step penalty in this case - # TODO: this means that deadlocked agents which suffer a malfunction are marked as - # stopped rather than deadlocked. - if agent.malfunction_data['malfunction'] > 0: - self.motionCheck.addAgent(i_agent, agent.position, agent.position) - # agent will get penalty in step_agent2_cf - # self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - return - - # Is the agent at the beginning of the cell? Then, it can take an action. - # As long as the agent is malfunctioning or stopped at the beginning of the cell, - # different actions may be taken! - if np.isclose(agent.speed_data['position_fraction'], 0.0, rtol=1e-03): - # No action has been supplied for this agent -> set DO_NOTHING as default - if action is None: - action = RailEnvActions.DO_NOTHING - - if action < 0 or action > len(RailEnvActions): - print('ERROR: illegal action=', action, - 'for agent with index=', i_agent, - '"DO NOTHING" will be executed instead') - action = RailEnvActions.DO_NOTHING - - if action == RailEnvActions.DO_NOTHING and agent.moving: - # Keep moving - action = RailEnvActions.MOVE_FORWARD - - if action == RailEnvActions.STOP_MOVING and agent.moving: - # Only allow halting an agent on entering new cells. - agent.moving = False - self.rewards_dict[i_agent] += self.stop_penalty - - if not agent.moving and not ( - action == RailEnvActions.DO_NOTHING or - action == RailEnvActions.STOP_MOVING): - # Allow agent to start with any forward or direction action - agent.moving = True - self.rewards_dict[i_agent] += self.start_penalty - - # Store the action if action is moving - # If not moving, the action will be stored when the agent starts moving again. - new_position = None - if agent.moving: - _action_stored = False - _, new_cell_valid, new_direction, new_position, transition_valid = \ - self._check_action_on_agent(action, agent) - - if all([new_cell_valid, transition_valid]): - agent.speed_data['transition_action_on_cellexit'] = action - _action_stored = True - else: - # But, if the chosen invalid action was LEFT/RIGHT, and the agent is moving, - # try to keep moving forward! 
- if (action == RailEnvActions.MOVE_LEFT or action == RailEnvActions.MOVE_RIGHT): - _, new_cell_valid, new_direction, new_position, transition_valid = \ - self._check_action_on_agent(RailEnvActions.MOVE_FORWARD, agent) - - if all([new_cell_valid, transition_valid]): - agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.MOVE_FORWARD - _action_stored = True - - if not _action_stored: - # If the agent cannot move due to an invalid transition, we set its state to not moving - self.rewards_dict[i_agent] += self.invalid_action_penalty - self.rewards_dict[i_agent] += self.stop_penalty - agent.moving = False - self.motionCheck.addAgent(i_agent, agent.position, agent.position) - return - - if new_position is None: - self.motionCheck.addAgent(i_agent, agent.position, agent.position) - if agent.moving: - print("Agent", i_agent, "new_pos none, but moving") - - # Check the pos_frac position fraction - if agent.moving: - agent.speed_data['position_fraction'] += agent.speed_data['speed'] - if agent.speed_data['position_fraction'] > 0.999: - stored_action = agent.speed_data["transition_action_on_cellexit"] - - # find the next cell using the stored action - _, new_cell_valid, new_direction, new_position, transition_valid = \ - self._check_action_on_agent(stored_action, agent) - - # if it's valid, record it as the new position - if all([new_cell_valid, transition_valid]): - self.motionCheck.addAgent(i_agent, agent.position, new_position) - else: # if the action wasn't valid then record the agent as stationary - self.motionCheck.addAgent(i_agent, agent.position, agent.position) - else: # This agent hasn't yet crossed the cell - self.motionCheck.addAgent(i_agent, agent.position, agent.position) - - def _step_agent2_cf(self, i_agent): - agent = self.agents[i_agent] - - # NEW : REW: (WAITING) no reward during WAITING... 
- if agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED, RailAgentStatus.WAITING]: - return - - (move, rc_next) = self.motionCheck.check_motion(i_agent, agent.position) - - if agent.position is not None: - sbTrans = format(self.rail.grid[agent.position], "016b") - trans_block = sbTrans[agent.direction * 4: agent.direction * 4 + 4] - if (trans_block == "0000"): - print (i_agent, agent.position, agent.direction, sbTrans, trans_block) - - # if agent cannot enter env, then we should have move=False - - if move: - if agent.position is None: # agent is entering the env - # print(i_agent, "writing new pos ", rc_next, " into agent position (None)") - agent.position = rc_next - agent.status = RailAgentStatus.ACTIVE - agent.speed_data['position_fraction'] = 0.0 - - else: # normal agent move - cell_free, new_cell_valid, new_direction, new_position, transition_valid = self._check_action_on_agent( - agent.speed_data['transition_action_on_cellexit'], agent) - - if not all([transition_valid, new_cell_valid]): - print(f"ERRROR: step_agent2 invalid transition ag {i_agent} dir {new_direction} pos {agent.position} next {rc_next}") - - if new_position != rc_next: - print(f"ERROR: agent {i_agent} new_pos {new_position} != rc_next {rc_next} " + - f"pos {agent.position} dir {agent.direction} new_dir {new_direction}" + - f"stored action: {agent.speed_data['transition_action_on_cellexit']}") - - sbTrans = format(self.rail.grid[agent.position], "016b") - trans_block = sbTrans[agent.direction * 4: agent.direction * 4 + 4] - if (trans_block == "0000"): - print ("ERROR: ", i_agent, agent.position, agent.direction, sbTrans, trans_block) - - agent.position = rc_next - agent.direction = new_direction - agent.speed_data['position_fraction'] = 0.0 - - # NEW : STEP: Check DONE before / after LA & Check if RUNNING before / after LA - # has the agent reached its target? - if np.equal(agent.position, agent.target).all(): - # arrived before or after Latest Arrival - agent.status = RailAgentStatus.DONE - self.dones[i_agent] = True - self.active_agents.remove(i_agent) - agent.moving = False - agent.arrival_time = self._elapsed_steps - self._remove_agent_from_scene(agent) - - else: # not reached its target and moving - # running before Latest Arrival - if (self._elapsed_steps <= agent.latest_arrival): - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - else: # running after Latest Arrival - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] # + # NEGATIVE REWARD? per step? - else: - # stopped (!move) before Latest Arrival - if (self._elapsed_steps <= agent.latest_arrival): - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] - else: # stopped (!move) after Latest Arrival - self.rewards_dict[i_agent] += self.step_penalty * agent.speed_data['speed'] # + # NEGATIVE REWARD? per step? + final_new_position = agent.position + final_new_direction = agent.direction + agent.position = final_new_position + agent.direction = final_new_direction + + ## Update states + state_transition_signals = self.generate_state_transition_signals(agent, preprocessed_action, movement_allowed) + agent.state_machine.set_transition_signals(state_transition_signals) + agent.state_machine.step() + agent.state = agent.state_machine.state # TODO : Make this a property instead? 
+ + # TODO : Important : Looks like a hacky condiition, improve the handling + if agent.state == TrainState.DONE: + agent.position = None + + ## Update rewards + # self.update_rewards(i_agent, agent, rail) + + ## Update counters (malfunction and speed) + agent.speed_counter.update_counter(agent.state) + agent.malfunction_handler.update_counter() + + # Clear old action when starting in new cell + if agent.speed_counter.is_cell_entry: + agent.action_saver.clear_saved_action() + + self.rewards_dict = {i_agent: 0 for i_agent in range(len(self.agents))} # TODO : Remove this + return self._get_observations(), self.rewards_dict, self.dones, info_dict # TODO : Will need changes? def _set_agent_to_initial_position(self, agent: EnvAgent, new_position: IntVector2D): """ @@ -965,52 +624,6 @@ class RailEnv(Environment): agent.old_position = None agent.status = RailAgentStatus.DONE_REMOVED - def _check_action_on_agent(self, action: RailEnvActions, agent: EnvAgent): - """ - - Parameters - ---------- - action : RailEnvActions - agent : EnvAgent - - Returns - ------- - bool - Is it a legal move? - 1) transition allows the new_direction in the cell, - 2) the new cell is not empty (case 0), - 3) the cell is free, i.e., no agent is currently in that cell - - - """ - # compute number of possible transitions in the current - # cell used to check for invalid actions - new_direction, transition_valid = self.check_action(agent, action) - new_position = get_new_position(agent.position, new_direction) - - new_cell_valid = ( - fast_position_equal( # Check the new position is still in the grid - new_position, - fast_clip(new_position, [0, 0], [self.height - 1, self.width - 1])) - and # check the new position has some transitions (ie is not an empty cell) - self.rail.get_full_transitions(*new_position) > 0) - - # If transition validity hasn't been checked yet. - if transition_valid is None: - transition_valid = self.rail.get_transition( - (*agent.position, agent.direction), - new_direction) - - # only call cell_free() if new cell is inside the scene - if new_cell_valid: - # Check the new position is not the same as any of the existing agent positions - # (including itself, for simplicity, since it is moving) - cell_free = self.cell_free(new_position) - else: - # if new cell is outside of scene -> cell_free is False - cell_free = False - return cell_free, new_cell_valid, new_direction, new_position, transition_valid - def record_timestep(self, dActions): ''' Record the positions and orientations of all agents in memory, in the cur_episode ''' @@ -1034,62 +647,6 @@ class RailEnv(Environment): self.cur_episode.append(list_agents_state) self.list_actions.append(dActions) - def cell_free(self, position: IntVector2D) -> bool: - """ - Utility to check if a cell is free - - Parameters: - -------- - position : Tuple[int, int] - - Returns - ------- - bool - is the cell free or not? 
- - """ - return self.agent_positions[position] == -1 - - def check_action(self, agent: EnvAgent, action: RailEnvActions): - """ - - Parameters - ---------- - agent : EnvAgent - action : RailEnvActions - - Returns - ------- - Tuple[Grid4TransitionsEnum,Tuple[int,int]] - - - - """ - transition_valid = None - possible_transitions = self.rail.get_transitions(*agent.position, agent.direction) - num_transitions = fast_count_nonzero(possible_transitions) - - new_direction = agent.direction - if action == RailEnvActions.MOVE_LEFT: - new_direction = agent.direction - 1 - if num_transitions <= 1: - transition_valid = False - - elif action == RailEnvActions.MOVE_RIGHT: - new_direction = agent.direction + 1 - if num_transitions <= 1: - transition_valid = False - - new_direction %= 4 - - if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1: - # - dead-end, straight line or curved line; - # new_direction will be the only valid transition - # - take only available transition - new_direction = fast_argmax(possible_transitions) - transition_valid = True - return new_direction, transition_valid - def _get_observations(self): """ Utility which returns the observations for an agent with respect to environment @@ -1140,7 +697,7 @@ class RailEnv(Environment): True if agent is ok, False otherwise """ - return agent.malfunction_data['malfunction'] < 1 + return agent.malfunction_handler.in_malfunction def save(self, filename): print("deprecated call to env.save() - pls call RailEnvPersister.save()") diff --git a/flatland/envs/rail_env_action.py b/flatland/envs/rail_env_action.py index 6fcc175e7f7f63653153f8841ec3ba398876d4a1..a25cc8f0f37233f76b921ffc62c83818e8e7bb9b 100644 --- a/flatland/envs/rail_env_action.py +++ b/flatland/envs/rail_env_action.py @@ -19,6 +19,10 @@ class RailEnvActions(IntEnum): 4: 'S', }[a] + @staticmethod + def is_moving_action(action): + return action in [1,2,3] + RailEnvGridPos = NamedTuple('RailEnvGridPos', [('r', int), ('c', int)]) RailEnvNextAction = NamedTuple('RailEnvNextAction', [('action', RailEnvActions), ('next_position', RailEnvGridPos), diff --git a/flatland/envs/step_utils/action_preprocessing.py b/flatland/envs/step_utils/action_preprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..e8ad1d797d47dc9495089c576fc16fc507548adf --- /dev/null +++ b/flatland/envs/step_utils/action_preprocessing.py @@ -0,0 +1,61 @@ +from flatland.core.grid.grid_utils import position_to_coordinate +from flatland.envs.agent_utils import TrainState +from flatland.envs.rail_env_action import RailEnvActions +from flatland.envs.step_utils.transition_utils import check_valid_action + + +def process_illegal_action(action: RailEnvActions): + # TODO - Dipam : This check is kind of weird, change this + if action is None or action not in RailEnvActions._value2member_map_: + return RailEnvActions.DO_NOTHING + else: + return action + + +def process_do_nothing(state: TrainState): + if state == TrainState.MOVING: + action = RailEnvActions.MOVE_FORWARD + else: + action = RailEnvActions.STOP_MOVING + return action + + +def process_left_right(action, state, rail, position, direction): + if not check_valid_action(action, state, rail, position, direction): + action = RailEnvActions.MOVE_FORWARD + return action + + +def preprocess_action_when_waiting(action, state): + """ + Set action to DO_NOTHING if in waiting state + """ + if state == TrainState.WAITING: + action = RailEnvActions.DO_NOTHING + return action + + +def preprocess_raw_action(action, state): + """ + Preprocesses 
actions to handle different situations of usage of action based on context + - DO_NOTHING is converted to FORWARD if train is moving + - DO_NOTHING is converted to STOP_MOVING if train is moving + """ + action = process_illegal_action(action) + + if action == RailEnvActions.DO_NOTHING: + action = process_do_nothing(state) + + return action + +def preprocess_moving_action(action, state, rail, position, direction): + """ + LEFT/RIGHT is converted to FORWARD if left/right is not available and train is moving + FORWARD is converted to STOP_MOVING if leading to dead end? + """ + if action in [RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]: + action = process_left_right(action, rail, position, direction) + + if not check_valid_action(action, rail, position, direction): # TODO: Dipam - Not sure if this is needed + action = RailEnvActions.STOP_MOVING + return action \ No newline at end of file diff --git a/flatland/envs/step_utils/action_saver.py b/flatland/envs/step_utils/action_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..56f7465af77de4a88ce6d010593bca92c8280759 --- /dev/null +++ b/flatland/envs/step_utils/action_saver.py @@ -0,0 +1,25 @@ +from flatland.envs.rail_env_action import RailEnvActions +from flatland.envs.step_utils.states import TrainState + +class ActionSaver: + def __init__(self): + self.saved_action = None + + @property + def is_action_saved(self): + return self.saved_action is not None + + def __repr__(self): + return f"is_action_saved: {self.is_action_saved}, saved_action: {self.saved_action}" + + + def save_action_if_allowed(self, action, state): + if not self.is_action_saved and \ + RailEnvActions.is_moving_action(action) and \ + not TrainState.is_malfunction_state(state): + self.saved_action = action + + def clear_saved_action(self): + self.saved_action = None + + diff --git a/flatland/envs/step_utils/malfunction_handler.py b/flatland/envs/step_utils/malfunction_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..3d2d4169e0b0f46b172b358f84a26e5832749969 --- /dev/null +++ b/flatland/envs/step_utils/malfunction_handler.py @@ -0,0 +1,47 @@ + +def get_number_of_steps_to_break(malfunction_generator, np_random): + if hasattr(malfunction_generator, "generate"): + malfunction = malfunction_generator.generate(np_random) + else: + malfunction = malfunction_generator(np_random) + + return malfunction.num_broken_steps + +class MalfunctionHandler: + def __init__(self): + self._malfunction_down_counter = 0 + + @property + def in_malfunction(self): + return self._malfunction_down_counter > 0 + + @property + def malfunction_counter_complete(self): + return self._malfunction_down_counter == 0 + + @property + def malfunction_down_counter(self): + return self._malfunction_down_counter + + @malfunction_down_counter.setter + def malfunction_down_counter(self, val): + self._set_malfunction_down_counter(val) + + def _set_malfunction_down_counter(self, val): + if val < 0: + raise ValueError("Cannot set a negative value to malfunction down counter") + self._malfunction_down_counter = val + + def generate_malfunction(self, malfunction_generator, np_random): + num_broken_steps = get_number_of_steps_to_break(malfunction_generator, np_random) + self._set_malfunction_down_counter(num_broken_steps) + + def update_counter(self): + if self._malfunction_down_counter > 0: + self._malfunction_down_counter -= 1 + + + + + + diff --git a/flatland/envs/step_utils/speed_counter.py b/flatland/envs/step_utils/speed_counter.py new file mode 100644 index 
0000000000000000000000000000000000000000..5bde9c20f98b1b7ed26ad4a8ba3d5791786bd84f --- /dev/null +++ b/flatland/envs/step_utils/speed_counter.py @@ -0,0 +1,31 @@ +import numpy as np +from flatland.envs.step_utils.states import TrainState + +class SpeedCounter: + def __init__(self, speed): + self.speed = speed + self.max_count = int(1/speed) + + def update_counter(self, state): + if state == TrainState.MOVING: + self.counter += 1 + self.counter = self.counter % self.max_count + + def __repr__(self): + return f"speed: {self.speed} \ + max_count: {self.max_count} \ + is_cell_entry: {self.is_cell_entry} \ + is_cell_exit: {self.is_cell_exit} \ + counter: {self.counter}" + + def reset_counter(self): + self.counter = 0 + + @property + def is_cell_entry(self): + return self.counter == 0 + + @property + def is_cell_exit(self): + return self.counter == self.max_count - 1 + diff --git a/flatland/envs/step_utils/state_machine.py b/flatland/envs/step_utils/state_machine.py new file mode 100644 index 0000000000000000000000000000000000000000..e42a829d2018c3c540ddd0f0e8c249530333abef --- /dev/null +++ b/flatland/envs/step_utils/state_machine.py @@ -0,0 +1,140 @@ +from attr import s +from flatland.envs.step_utils.states import TrainState + +class TrainStateMachine: + def __init__(self, initial_state=TrainState.WAITING): + self._initial_state = initial_state + self._state = initial_state + self.st_signals = {} # State Transition Signals # TODO: Make this namedtuple + self.next_state = None + + def _handle_waiting(self): + """" Waiting state goes to ready to depart when earliest departure is reached""" + # TODO: Important - The malfunction handling is not like proper state machine + # Both transition signals can happen at the same time + # Atleast mention it in the diagram + if self.st_signals['malfunction_onset']: + self.next_state = TrainState.MALFUNCTION_OFF_MAP + elif self.st_signals['earliest_departure_reached']: + self.next_state = TrainState.READY_TO_DEPART + else: + self.next_state = TrainState.WAITING + + def _handle_ready_to_depart(self): + """ Can only go to MOVING if a valid action is provided """ + if self.st_signals['malfunction_onset']: + self.next_state = TrainState.MALFUNCTION_OFF_MAP + elif self.st_signals['valid_movement_action_given']: + self.next_state = TrainState.MOVING + else: + self.next_state = TrainState.READY_TO_DEPART + + def _handle_malfunction_off_map(self): + if self.st_signals['malfunction_counter_complete']: + if self.st_signals['earliest_departure_reached']: + self.next_state = TrainState.READY_TO_DEPART + else: + self.next_state = TrainState.STOPPED + else: + self.next_state = TrainState.WAITING + + def _handle_moving(self): + if self.st_signals['malfunction_onset']: + self.next_state = TrainState.MALFUNCTION + elif self.st_signals['target_reached']: + self.next_state = TrainState.DONE + elif self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']: + self.next_state = TrainState.STOPPED + else: + self.next_state = TrainState.MOVING + + def _handle_stopped(self): + if self.st_signals['malfunction_onset']: + self.next_state = TrainState.MALFUNCTION + elif self.st_signals['valid_movement_action_given']: + self.next_state = TrainState.MOVING + else: + self.next_state = TrainState.STOPPED + + def _handle_malfunction(self): + if self.st_signals['malfunction_counter_complete'] and \ + self.st_signals['valid_movement_action_given']: + self.next_state = TrainState.MOVING + elif self.st_signals['malfunction_counter_complete'] and \ + 
(self.st_signals['stop_action_given'] or self.st_signals['movement_conflict']): + self.next_state = TrainState.STOPPED + else: + self.next_state = TrainState.MALFUNCTION + + def _handle_done(self): + """" Done state is terminal """ + self.next_state = TrainState.DONE + + def calculate_next_state(self, current_state): + + # _Handle the current state + if current_state == TrainState.WAITING: + self._handle_waiting() + + elif current_state == TrainState.READY_TO_DEPART: + self._handle_ready_to_depart() + + elif current_state == TrainState.MALFUNCTION_OFF_MAP: + self._handle_malfunction_off_map() + + elif current_state == TrainState.MOVING: + self._handle_moving() + + elif current_state == TrainState.STOPPED: + self._handle_stopped() + + elif current_state == TrainState.MALFUNCTION: + self._handle_malfunction() + + elif current_state == TrainState.DONE: + self._handle_done() + + else: + raise ValueError(f"Got unexpected state {current_state}") + + def step(self): + """ Steps the state machine to the next state """ + + current_state = self._state + + # Clear next state + self.clear_next_state() + + # Handle current state to get next_state + self.calculate_next_state(current_state) + + # Set next state + self.set_state(self.next_state) + + + def clear_next_state(self): + self.next_state = None + + def set_state(self, state): + if not TrainState.check_valid_state(state): + raise ValueError(f"Cannot set invalid state {state}") + self._state = state + + def reset(self): + self._state = self._initial_state + self.st_signals = {} + self.clear_next_state() + + @property + def state(self): + return self._state + + @property + def state_transition_signals(self): + return self.st_signals + + def set_transition_signals(self, state_transition_signals): + self.st_signals = state_transition_signals # TODO: Important: Check all keys are present and if not raise error + + + diff --git a/flatland/envs/step_utils/states.py b/flatland/envs/step_utils/states.py new file mode 100644 index 0000000000000000000000000000000000000000..4c991b667fdfcc90086492fe80d43ff4d45ddce1 --- /dev/null +++ b/flatland/envs/step_utils/states.py @@ -0,0 +1,21 @@ +from enum import IntEnum + +class TrainState(IntEnum): + WAITING = 0 + READY_TO_DEPART = 1 + MALFUNCTION_OFF_MAP = 2 + MOVING = 3 + STOPPED = 4 + MALFUNCTION = 5 + DONE = 6 + + @classmethod + def check_valid_state(cls, state): + return state in cls._value2member_map_ + + @staticmethod + def is_malfunction_state(state): + return state in [2, 5] # TODO: Can this be done with names instead? 
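To make the state flow above concrete, here is a minimal sketch of driving TrainStateMachine by hand with the same transition-signal keys that RailEnv.generate_state_transition_signals fills in each step. The signal values and the small signals() helper are made up for illustration; the snippet is not part of the patch.

    from flatland.envs.step_utils.state_machine import TrainStateMachine
    from flatland.envs.step_utils.states import TrainState

    def signals(**overrides):
        # All signals default to False; generate_state_transition_signals() in rail_env.py
        # produces the same keys from the live env state every step.
        base = {'malfunction_onset': False,
                'malfunction_counter_complete': False,
                'earliest_departure_reached': False,
                'stop_action_given': False,
                'valid_movement_action_given': False,
                'target_reached': False,
                'movement_conflict': False}
        base.update(overrides)
        return base

    sm = TrainStateMachine(initial_state=TrainState.WAITING)

    # WAITING -> READY_TO_DEPART once the earliest departure time is reached
    sm.set_transition_signals(signals(earliest_departure_reached=True))
    sm.step()
    assert sm.state == TrainState.READY_TO_DEPART

    # READY_TO_DEPART -> MOVING when a valid movement action is given
    sm.set_transition_signals(signals(valid_movement_action_given=True))
    sm.step()
    assert sm.state == TrainState.MOVING

    # MOVING -> STOPPED on a stop action (or a movement conflict)
    sm.set_transition_signals(signals(stop_action_given=True))
    sm.step()
    assert sm.state == TrainState.STOPPED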
+ + + diff --git a/flatland/envs/step_utils/transition_utils.py b/flatland/envs/step_utils/transition_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2d58d21e2b023087bd432d0df40703f33e69e797 --- /dev/null +++ b/flatland/envs/step_utils/transition_utils.py @@ -0,0 +1,101 @@ +from typing import Tuple +from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.rail_env_action import RailEnvActions + + +def check_action(action, position, direction, rail): + """ + + Parameters + ---------- + agent : EnvAgent + action : RailEnvActions + + Returns + ------- + Tuple[Grid4TransitionsEnum,Tuple[int,int]] + + + + """ + transition_valid = None + possible_transitions = rail.get_transitions(*position, direction) + num_transitions = fast_count_nonzero(possible_transitions) + + new_direction = direction + if action == RailEnvActions.MOVE_LEFT: + new_direction = direction - 1 + if num_transitions <= 1: + transition_valid = False + + elif action == RailEnvActions.MOVE_RIGHT: + new_direction = direction + 1 + if num_transitions <= 1: + transition_valid = False + + new_direction %= 4 # Dipam : Why? + + if action == RailEnvActions.MOVE_FORWARD and num_transitions == 1: + # - dead-end, straight line or curved line; + # new_direction will be the only valid transition + # - take only available transition + new_direction = fast_argmax(possible_transitions) + transition_valid = True + return new_direction, transition_valid + + +def check_action_on_agent(action, rail, position, direction): + """ + Parameters + ---------- + action : RailEnvActions + agent : EnvAgent + + Returns + ------- + bool + Is it a legal move? + 1) transition allows the new_direction in the cell, + 2) the new cell is not empty (case 0), + 3) the cell is free, i.e., no agent is currently in that cell + + + """ + # compute number of possible transitions in the current + # cell used to check for invalid actions + new_direction, transition_valid = check_action(action, position, direction, rail) + new_position = get_new_position(position, new_direction) + + cell_inside_grid = check_bounds(new_position, rail.height, rail.width) + cell_not_empty = rail.get_full_transitions(*new_position) > 0 + new_cell_valid = cell_inside_grid and cell_not_empty + + # If transition validity hasn't been checked yet. 
+ if transition_valid is None: + transition_valid = rail.get_transition( # TODO: Dipam - Read this one + (*position, direction), + new_direction) + + return new_cell_valid, new_direction, new_position, transition_valid + + +def check_valid_action(action, rail, position, direction): + new_cell_valid, _, _, transition_valid = check_action_on_agent(action, rail, position, direction) + action_is_valid = new_cell_valid and transition_valid + return action_is_valid + +def fast_argmax(possible_transitions: Tuple[int, int, int, int]) -> bool: + if possible_transitions[0] == 1: + return 0 + if possible_transitions[1] == 1: + return 1 + if possible_transitions[2] == 1: + return 2 + return 3 + +def fast_count_nonzero(possible_transitions: Tuple[int, int, int, int]): + return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3] + +def check_bounds(position, height, width): + return position[0] >= 0 and position[1] >= 0 and position[0] < height and position[1] < width + \ No newline at end of file diff --git a/tests/test_env_step_utils.py b/tests/test_env_step_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..739d3d06e7271d2ce54ded07de54957d96c08022 --- /dev/null +++ b/tests/test_env_step_utils.py @@ -0,0 +1,61 @@ +import numpy as np +import numpy as np +import os + +from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen + +from flatland.envs.observations import GlobalObsForRailEnv +# First of all we import the Flatland rail environment +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_env import RailEnvActions +from flatland.envs.rail_generators import sparse_rail_generator +#from flatland.envs.sparse_rail_gen import SparseRailGen +from flatland.envs.schedule_generators import sparse_schedule_generator + + +def get_small_two_agent_env(): + """Generates a simple 2 city 2 train env returns it after reset""" + width = 30 # With of map + height = 15 # Height of map + nr_trains = 2 # Number of trains that have an assigned task in the env + cities_in_map = 2 # Number of cities where agents can start or end + seed = 42 # Random seed + grid_distribution_of_cities = False # Type of city distribution, if False cities are randomly placed + max_rails_between_cities = 2 # Max number of tracks allowed between cities. This is number of entry point to a city + max_rail_in_cities = 6 # Max number of parallel tracks within a city, representing a realistic trainstation + + rail_generator = sparse_rail_generator(max_num_cities=cities_in_map, + seed=seed, + grid_mode=grid_distribution_of_cities, + max_rails_between_cities=max_rails_between_cities, + max_rail_pairs_in_city=max_rail_in_cities//2, + ) + speed_ration_map = {1.: 0.25, # Fast passenger train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train + 1. 
+                        1. / 4.: 0.25}  # Slow freight train
+
+    schedule_generator = sparse_schedule_generator(speed_ration_map)
+
+
+    stochastic_data = MalfunctionParameters(malfunction_rate=1/10000,  # Rate of malfunction occurrence
+                                            min_duration=15,  # Minimal duration of malfunction
+                                            max_duration=50  # Max duration of malfunction
+                                            )
+
+    observation_builder = GlobalObsForRailEnv()
+
+    env = RailEnv(width=width,
+                  height=height,
+                  rail_generator=rail_generator,
+                  schedule_generator=schedule_generator,
+                  number_of_agents=nr_trains,
+                  obs_builder_object=observation_builder,
+                  #malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                  malfunction_generator=ParamMalfunctionGen(stochastic_data),
+                  remove_agents_at_target=True,
+                  random_seed=seed)
+
+    env.reset()
+
+    return env
\ No newline at end of file
diff --git a/tests/test_state_machine.py b/tests/test_state_machine.py
new file mode 100644
index 0000000000000000000000000000000000000000..266a8f86589b6033ea67523cab0b31b72ac9d32d
--- /dev/null
+++ b/tests/test_state_machine.py
@@ -0,0 +1,115 @@
+from test_env_step_utils import get_small_two_agent_env
+from flatland.envs.rail_env_action import RailEnvActions
+from flatland.envs.step_utils.states import TrainState
+from flatland.envs.malfunction_generators import Malfunction
+
+
+class NoMalfunctionGenerator:
+    def generate(self, np_random):
+        return Malfunction(0)
+
+
+class AlwaysThreeStepMalfunction:
+    def generate(self, np_random):
+        return Malfunction(3)
+
+
+def test_waiting_no_transition():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 0
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed - 1):
+        env.step({i_agent: RailEnvActions.MOVE_FORWARD})
+    assert env.agents[i_agent].state == TrainState.WAITING
+
+
+def test_waiting_to_ready_to_depart():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 0
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed):
+        env.step({i_agent: RailEnvActions.DO_NOTHING})
+    assert env.agents[i_agent].state == TrainState.READY_TO_DEPART
+
+
+def test_ready_to_depart_to_moving():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 0
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed):
+        env.step({i_agent: RailEnvActions.DO_NOTHING})
+
+    env.step({i_agent: RailEnvActions.MOVE_FORWARD})
+    assert env.agents[i_agent].state == TrainState.MOVING
+
+
+def test_moving_to_stopped():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 0
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed):
+        env.step({i_agent: RailEnvActions.DO_NOTHING})
+
+    env.step({i_agent: RailEnvActions.MOVE_FORWARD})
+    env.step({i_agent: RailEnvActions.STOP_MOVING})
+    assert env.agents[i_agent].state == TrainState.STOPPED
+
+
+def test_stopped_to_moving():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 0
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed):
+        env.step({i_agent: RailEnvActions.DO_NOTHING})
+
+    env.step({i_agent: RailEnvActions.MOVE_FORWARD})
+    env.step({i_agent: RailEnvActions.STOP_MOVING})
+    env.step({i_agent: RailEnvActions.MOVE_FORWARD})
+    assert env.agents[i_agent].state == TrainState.MOVING
+
+
+def test_moving_to_done():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 1
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed):
+        env.step({i_agent: RailEnvActions.DO_NOTHING})
+
+    for _ in range(50):
+        env.step({i_agent: RailEnvActions.MOVE_FORWARD})
+    assert env.agents[i_agent].state == TrainState.DONE
+
+
+def test_waiting_to_malfunction():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = AlwaysThreeStepMalfunction()
+    i_agent = 1
+    env.step({i_agent: RailEnvActions.DO_NOTHING})
+    assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP
+
+
+def test_ready_to_depart_to_malfunction_off_map():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 1
+    env.step({i_agent: RailEnvActions.DO_NOTHING})
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed):
+        env.step({i_agent: RailEnvActions.DO_NOTHING})  # This should get into ready to depart
+
+    env.malfunction_generator = AlwaysThreeStepMalfunction()
+    env.step({i_agent: RailEnvActions.DO_NOTHING})
+    assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP
+
+
+def test_malfunction_off_map_to_waiting():
+    env = get_small_two_agent_env()
+    env.malfunction_generator = NoMalfunctionGenerator()
+    i_agent = 1
+    env.step({i_agent: RailEnvActions.DO_NOTHING})
+    ed = env.agents[i_agent].earliest_departure
+    for _ in range(ed):
+        env.step({i_agent: RailEnvActions.DO_NOTHING})  # This should get into ready to depart
+
+    env.malfunction_generator = AlwaysThreeStepMalfunction()
+    env.step({i_agent: RailEnvActions.DO_NOTHING})
+    assert env.agents[i_agent].state == TrainState.MALFUNCTION_OFF_MAP
\ No newline at end of file
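
A minimal usage sketch (not part of the patch) of how the new transition_utils helpers could be probed interactively against the small test environment; it assumes the test module is importable the same way tests/test_state_machine.py imports it, and uses only attributes already present on RailEnv and EnvAgent (agents, rail, initial_position, initial_direction).

# Usage sketch, not part of the patch. Assumes it is run from the tests/ directory
# so that test_env_step_utils is importable, as in tests/test_state_machine.py.
from test_env_step_utils import get_small_two_agent_env
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.step_utils.transition_utils import check_action_on_agent, check_valid_action

env = get_small_two_agent_env()
agent = env.agents[0]

# Agents start off-map, so probe transitions from the agent's initial cell and orientation.
position, direction = agent.initial_position, agent.initial_direction

for action in (RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_RIGHT):
    # check_action_on_agent returns the would-be cell and direction plus the validity flags;
    # check_valid_action collapses them into a single legality answer.
    new_cell_valid, new_direction, new_position, transition_valid = check_action_on_agent(
        action, env.rail, position, direction)
    print(action, "->", new_position, "legal:", check_valid_action(action, env.rail, position, direction))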