From fe5f6f463c67af3d8b57ef7fd1512b5611d624ef Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 12 Nov 2019 19:26:46 -0500 Subject: [PATCH] SIM-119 refactoring ActionPlan; TODO extract Agent from ActionPlanReplayer --- flatland/action_plan/action_plan.py | 142 +++----- flatland/envs/rail_env_shortest_paths.py | 44 +-- .../envs/rail_train_run_data_structures.py | 13 + tests/test_action_plan.py | 86 +++-- tests/test_flatland_envs_predictions.py | 15 +- ...t_flatland_envs_rail_env_shortest_paths.py | 342 ++++++++---------- 6 files changed, 287 insertions(+), 355 deletions(-) create mode 100644 flatland/envs/rail_train_run_data_structures.py diff --git a/flatland/action_plan/action_plan.py b/flatland/action_plan/action_plan.py index 017b9551..441a2ca0 100644 --- a/flatland/action_plan/action_plan.py +++ b/flatland/action_plan/action_plan.py @@ -2,38 +2,25 @@ import pprint from typing import Dict, List, Optional, NamedTuple import numpy as np + from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.envs.rail_env import RailEnv, RailEnvActions -from flatland.envs.rail_env_shortest_paths import WalkingElement, get_action_for_move +from flatland.envs.rail_env_shortest_paths import get_action_for_move +from flatland.envs.rail_train_run_data_structures import WayPoint, TrainRun, TrainRunWayPoint from flatland.utils.rendertools import RenderTool, AgentRenderVariant - -#---- Input Data Structures (graph representation) --------------------------------------------- -# A cell pin represents the one of the four pins in which the cell at row,column may be entered. -CellPin = NamedTuple('CellPin', [('r', int), ('c', int), ('d', int)]) - -# A path schedule element represents the entry time of agent at a cell pin. -PathScheduleElement = NamedTuple('PathScheduleElement', [ - ('scheduled_at', int), - ('cell_pin', CellPin) -]) -# A path schedule is the list of an agent's cell pin entries -PathSchedule = List[PathScheduleElement] - - -#---- Output Data Structures (FLATland representation) --------------------------------------------- +# ---- Output Data Structures (FLATland representation) --------------------------------------------- # An action plan element represents the actions to be taken by an agent at deterministic time steps # plus the position before the action ActionPlanElement = NamedTuple('ActionPlanElement', [ ('scheduled_at', int), - ('walking_element', WalkingElement) + ('action', RailEnvActions) ]) # An action plan deterministically represents all the actions to be taken by an agent # plus its position before the actions are taken ActionPlan = Dict[int, List[ActionPlanElement]] - class ActionPlanReplayer(): """Allows to deduce an `ActionPlan` from the agents' `PathSchedule` and to be replayed/verified in a FLATland env without malfunction.""" @@ -42,15 +29,17 @@ class ActionPlanReplayer(): def __init__(self, env: RailEnv, - chosen_path_dict: Dict[int, PathSchedule]): + train_run_dict: Dict[int, TrainRun]): self.env = env + self.train_run_dict: Dict[int, TrainRun] = train_run_dict + print(train_run_dict) self.action_plan = [[] for _ in range(self.env.get_num_agents())] - for agent_id, chosen_path in chosen_path_dict.items(): + for agent_id, chosen_path in train_run_dict.items(): self._add_aggent_to_action_plan(self.action_plan, agent_id, chosen_path) - def get_walking_element_before_or_at_step(self, agent_id: int, step: int) -> WalkingElement: + def get_way_point_before_or_at_step(self, agent_id: int, step: int) -> WayPoint: """ Get the walking element from which the current position can be extracted. @@ -64,14 +53,26 @@ class ActionPlanReplayer(): WalkingElement """ - walking_element = None - for action in self.action_plan[agent_id]: - if step < action.scheduled_at: - return walking_element - if step >= action.scheduled_at: - walking_element = action.walking_element - assert walking_element is not None - return walking_element + train_run = self.train_run_dict[agent_id] + entry_time_step = train_run[0].scheduled_at + # the agent has no position before and at choosing to enter the grid (one tick elapses before the agent enters the grid) + if step <= entry_time_step: + return WayPoint(position=None, direction=self.env.agents[agent_id].initial_direction) + + # the agent has no position as soon as the target is reached + exit_time_step = train_run[-1].scheduled_at + if step >= exit_time_step: + # agent loses position as soon as target cell is reached + return WayPoint(position=None, direction=train_run[-1].way_point.direction) + + way_point = None + for train_run_way_point in train_run: + if step < train_run_way_point.scheduled_at: + return way_point + if step >= train_run_way_point.scheduled_at: + way_point = train_run_way_point.way_point + assert way_point is not None + return way_point def get_action_at_step(self, agent_id: int, current_step: int) -> Optional[RailEnvActions]: """ @@ -90,11 +91,10 @@ class ActionPlanReplayer(): for action_plan_step in self.action_plan[agent_id]: action_plan_step: ActionPlanElement = action_plan_step scheduled_at = action_plan_step.scheduled_at - walking_element: WalkingElement = action_plan_step.walking_element if scheduled_at > current_step: return None elif np.isclose(current_step, scheduled_at): - return walking_element.next_action + return action_plan_step.action return None def get_action_dict_for_step_replay(self, current_step: int) -> Dict[int, RailEnvActions]: @@ -130,10 +130,10 @@ class ActionPlanReplayer(): i = 0 while not env.dones['__all__'] and i <= MAX_EPISODE_STEPS: for agent_id, agent in enumerate(env.agents): - walking_element: WalkingElement = self.get_walking_element_before_or_at_step(agent_id, i) - assert agent.position == walking_element.position, \ + way_point: WayPoint = self.get_way_point_before_or_at_step(agent_id, i) + assert agent.position == way_point.position, \ "before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position, - walking_element.position) + way_point.position) actions = self.get_action_dict_for_step_replay(i) print("actions for {}: {}".format(i, actions)) @@ -172,15 +172,15 @@ class ActionPlanReplayer(): agent = self.env.agents[agent_id] minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed'])) for path_loop, path_schedule_element in enumerate(agent_path_new): - path_schedule_element: PathScheduleElement = path_schedule_element + path_schedule_element: TrainRunWayPoint = path_schedule_element - position = (path_schedule_element.cell_pin.r, path_schedule_element.cell_pin.c) + position = path_schedule_element.way_point.position if Vec2d.is_equal(agent.target, position): break - next_path_schedule_element: PathScheduleElement = agent_path_new[path_loop + 1] - next_position = (next_path_schedule_element.cell_pin.r, next_path_schedule_element.cell_pin.c) + next_path_schedule_element: TrainRunWayPoint = agent_path_new[path_loop + 1] + next_position = next_path_schedule_element.way_point.position if path_loop == 0: self._create_action_plan_for_first_path_element_of_agent( @@ -212,81 +212,63 @@ class ActionPlanReplayer(): action_plan: ActionPlan, agent_id: int, minimum_cell_time: int, - path_schedule_element: PathScheduleElement, - next_path_schedule_element: PathScheduleElement): + path_schedule_element: TrainRunWayPoint, + next_path_schedule_element: TrainRunWayPoint): scheduled_at = path_schedule_element.scheduled_at next_entry_value = next_path_schedule_element.scheduled_at - position = (path_schedule_element.cell_pin.r, path_schedule_element.cell_pin.c) - direction = path_schedule_element.cell_pin.d - next_position = next_path_schedule_element.cell_pin.r, next_path_schedule_element.cell_pin.c - next_direction = next_path_schedule_element.cell_pin.d + position = path_schedule_element.way_point.position + direction = path_schedule_element.way_point.direction + next_position = next_path_schedule_element.way_point.position + next_direction = next_path_schedule_element.way_point.direction next_action = get_action_for_move(position, direction, next_position, next_direction, self.env.rail) - walking_element = WalkingElement(position, direction, next_action) - # if the next entry is later than minimum_cell_time, then stop here and # move minimum_cell_time before the exit # we have to do this since agents in the RailEnv are processed in the step() in the order of their handle if next_entry_value > scheduled_at + minimum_cell_time: - action = ActionPlanElement(scheduled_at, - WalkingElement( - position=position, - direction=direction, - next_action=RailEnvActions.STOP_MOVING)) + action = ActionPlanElement(scheduled_at, RailEnvActions.STOP_MOVING) action_plan[agent_id].append(action) - action = ActionPlanElement(next_entry_value - minimum_cell_time, walking_element) + action = ActionPlanElement(next_entry_value - minimum_cell_time, next_action) action_plan[agent_id].append(action) else: - action = ActionPlanElement(scheduled_at, walking_element) + action = ActionPlanElement(scheduled_at, next_action) action_plan[agent_id].append(action) def _create_action_plan_for_target_at_path_element_just_before_target(self, action_plan: ActionPlan, agent_id: int, minimum_cell_time: int, - path_schedule_element: PathScheduleElement, - next_path_schedule_element: PathScheduleElement): + path_schedule_element: TrainRunWayPoint, + next_path_schedule_element: TrainRunWayPoint): scheduled_at = path_schedule_element.scheduled_at - next_path_schedule_element.cell_pin + next_path_schedule_element.way_point - action = ActionPlanElement(scheduled_at + minimum_cell_time, - WalkingElement( - position=None, - direction=next_path_schedule_element.cell_pin.d, - next_action=RailEnvActions.STOP_MOVING)) + action = ActionPlanElement(scheduled_at + minimum_cell_time, RailEnvActions.STOP_MOVING) action_plan[agent_id].append(action) def _create_action_plan_for_first_path_element_of_agent(self, action_plan: ActionPlan, agent_id: int, - path_schedule_element: PathScheduleElement, - next_path_schedule_element: PathScheduleElement): + path_schedule_element: TrainRunWayPoint, + next_path_schedule_element: TrainRunWayPoint): scheduled_at = path_schedule_element.scheduled_at - position = (path_schedule_element.cell_pin.r, path_schedule_element.cell_pin.c) - direction = path_schedule_element.cell_pin.d - next_position = next_path_schedule_element.cell_pin.r, next_path_schedule_element.cell_pin.c - next_direction = next_path_schedule_element.cell_pin.d + position = path_schedule_element.way_point.position + direction = path_schedule_element.way_point.direction + next_position = next_path_schedule_element.way_point.position + next_direction = next_path_schedule_element.way_point.direction # add intial do nothing if we do not enter immediately if scheduled_at > 0: - action = ActionPlanElement(0, - WalkingElement( - position=None, - direction=direction, - next_action=RailEnvActions.DO_NOTHING)) + action = ActionPlanElement(0, RailEnvActions.DO_NOTHING) action_plan[agent_id].append(action) # add action to enter the grid - action = ActionPlanElement(scheduled_at, - WalkingElement( - position=None, - direction=direction, - next_action=RailEnvActions.MOVE_FORWARD)) + action = ActionPlanElement(scheduled_at, RailEnvActions.MOVE_FORWARD) action_plan[agent_id].append(action) next_action = get_action_for_move(position, @@ -296,9 +278,5 @@ class ActionPlanReplayer(): self.env.rail) # now, we have a position need to perform the action - action = ActionPlanElement(scheduled_at + 1, - WalkingElement( - position=position, - direction=direction, - next_action=next_action)) + action = ActionPlanElement(scheduled_at + 1, next_action) action_plan[agent_id].append(action) diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py index bc7afee9..ffbd9e81 100644 --- a/flatland/envs/rail_env_shortest_paths.py +++ b/flatland/envs/rail_env_shortest_paths.py @@ -1,5 +1,5 @@ import math -from typing import Dict, List, Optional, NamedTuple, Tuple, Set +from typing import Dict, List, Optional, Tuple, Set import matplotlib.pyplot as plt import numpy as np @@ -10,12 +10,9 @@ from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.distance_map import DistanceMap from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions, RailEnv +from flatland.envs.rail_train_run_data_structures import WayPoint from flatland.utils.ordered_set import OrderedSet -WalkingElement = \ - NamedTuple('WalkingElement', - [('position', Tuple[int, int]), ('direction', int), ('next_action', Optional[RailEnvActions])]) - def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum, agent_position: Tuple[int, int], @@ -195,7 +192,7 @@ def get_action_for_move( # N.B. get_shortest_paths is not part of distance_map since it refers to RailEnvActions (would lead to circularity!) def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = None, agent_handle: Optional[int] = None) \ - -> Dict[int, Optional[List[WalkingElement]]]: + -> Dict[int, Optional[List[WayPoint]]]: """ Computes the shortest path for each agent to its target and the action to be taken to do so. The paths are derived from a `DistanceMap`. @@ -245,8 +242,7 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non best_next_action = next_action distance = next_action_distance - shortest_paths[agent.handle].append( - WalkingElement(position, direction, best_next_action.action if best_next_action is not None else None)) + shortest_paths[agent.handle].append(WayPoint(position, direction)) depth += 1 # if there is no way to continue, the rail must be disconnected! @@ -258,7 +254,7 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non position = best_next_action.next_position direction = best_next_action.next_direction if max_depth is None or depth < max_depth: - shortest_paths[agent.handle].append(WalkingElement(position, direction, RailEnvActions.STOP_MOVING)) + shortest_paths[agent.handle].append(WayPoint(position, direction)) if agent_handle is not None: _shortest_path_for_agent(distance_map.agents[agent_handle]) @@ -273,7 +269,7 @@ def get_k_shortest_paths(env: RailEnv, source_position: Tuple[int, int], source_direction: int, target_position=Tuple[int, int], - k: int = 1, debug=False) -> List[Tuple[WalkingElement]]: + k: int = 1, debug=False) -> List[Tuple[WayPoint]]: """ Computes the k shortest paths using modified Dijkstra following pseudo-code https://en.wikipedia.org/wiki/K_shortest_path_routing @@ -300,17 +296,17 @@ def get_k_shortest_paths(env: RailEnv, # P: set of shortest paths from s to t # P =empty, - shortest_paths: List[Tuple[WalkingElement]] = [] + shortest_paths: List[Tuple[WayPoint]] = [] # countu: number of shortest paths found to node u # countu = 0, for all u in V count = {(r, c, d): 0 for r in range(env.height) for c in range(env.width) for d in range(4)} # B is a heap data structure containing paths - heap: Set[Tuple[WalkingElement]] = set() + heap: Set[Tuple[WayPoint]] = set() # insert path Ps = {s} into B with cost 0 - heap.add((WalkingElement(source_position, source_direction, None),)) + heap.add((WayPoint(source_position, source_direction),)) # while B is not empty and countt < K: while len(heap) > 0 and len(shortest_paths) < k: @@ -323,7 +319,7 @@ def get_k_shortest_paths(env: RailEnv, if len(path) < c: pu = path c = len(path) - u: WalkingElement = pu[-1] + u: WayPoint = pu[-1] if debug: print(" looking at pu={}".format(pu)) @@ -355,7 +351,7 @@ def get_k_shortest_paths(env: RailEnv, if debug: print(" looking at neighbor v={}".format((*new_position, new_direction))) - v = WalkingElement(position=new_position, direction=new_direction, next_action=None) + v = WayPoint(position=new_position, direction=new_direction) # CAVEAT: do not allow for loopy paths if v in pu: continue @@ -364,25 +360,9 @@ def get_k_shortest_paths(env: RailEnv, pv = pu + (v,) # – insert Pv into B heap.add(pv) - # add actions to shortest paths - shortest_paths_with_action = [] - for p in shortest_paths: - p_with_action = tuple( - WalkingElement(position=el.position, - direction=el.direction, - next_action=int(get_action_for_move(el.position, - el.direction, - p[i + 1].position, - p[i + 1].direction, - env.rail))) for i, el in - enumerate(p[:-1])) - target_walking_element = WalkingElement(position=p[-1].position, - direction=p[-1].direction, - next_action=int(RailEnvActions.DO_NOTHING)) - shortest_paths_with_action.append(p_with_action + (target_walking_element,)) # return P - return shortest_paths_with_action + return shortest_paths def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0): diff --git a/flatland/envs/rail_train_run_data_structures.py b/flatland/envs/rail_train_run_data_structures.py new file mode 100644 index 00000000..775b183b --- /dev/null +++ b/flatland/envs/rail_train_run_data_structures.py @@ -0,0 +1,13 @@ +from typing import NamedTuple, Tuple, List + +WayPoint = NamedTuple('WayPoint', [('position', Tuple[int, int]), ('direction', int)]) + + +# A train run is represented by the waypoints traversed and the times of traversal +# The terminology follows https://github.com/crowdAI/train-schedule-optimisation-challenge-starter-kit/blob/master/documentation/output_data_model.md +TrainRunWayPoint = NamedTuple('TrainRunWayPoint', [ + ('scheduled_at', int), + ('way_point', WayPoint) +]) +# A path schedule is the list of an agent's cell pin entries +TrainRun = List[TrainRunWayPoint] diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py index 876a7a87..4eeb5a49 100644 --- a/tests/test_action_plan.py +++ b/tests/test_action_plan.py @@ -1,15 +1,14 @@ -from flatland.action_plan.action_plan import PathScheduleElement, CellPin, ActionPlanReplayer +from flatland.action_plan.action_plan import TrainRunWayPoint, ActionPlanReplayer, ActionPlanElement from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions -from flatland.envs.rail_env_shortest_paths import WalkingElement from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.rail_train_run_data_structures import WayPoint from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.simple_rail import make_simple_rail - def test_action_plan(rendering: bool = False): """Tests ActionPlanReplayer: does action plan generation and replay work as expected.""" rail, rail_map = make_simple_rail() @@ -33,52 +32,51 @@ def test_action_plan(rendering: bool = False): for handle, agent in enumerate(env.agents): print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target)) - chosen_path_dict = {0: [PathScheduleElement(scheduled_at=0, cell_pin=CellPin(r=3, c=0, d=3)), - PathScheduleElement(scheduled_at=2, cell_pin=CellPin(r=3, c=1, d=1)), - PathScheduleElement(scheduled_at=3, cell_pin=CellPin(r=3, c=2, d=1)), - PathScheduleElement(scheduled_at=14, cell_pin=CellPin(r=3, c=3, d=1)), - PathScheduleElement(scheduled_at=15, cell_pin=CellPin(r=3, c=4, d=1)), - PathScheduleElement(scheduled_at=16, cell_pin=CellPin(r=3, c=5, d=1)), - PathScheduleElement(scheduled_at=17, cell_pin=CellPin(r=3, c=6, d=1)), - PathScheduleElement(scheduled_at=18, cell_pin=CellPin(r=3, c=7, d=1)), - PathScheduleElement(scheduled_at=19, cell_pin=CellPin(r=3, c=8, d=1)), - PathScheduleElement(scheduled_at=20, cell_pin=CellPin(r=3, c=8, d=5))], - 1: [PathScheduleElement(scheduled_at=0, cell_pin=CellPin(r=3, c=8, d=3)), - PathScheduleElement(scheduled_at=3, cell_pin=CellPin(r=3, c=7, d=3)), - PathScheduleElement(scheduled_at=5, cell_pin=CellPin(r=3, c=6, d=3)), - PathScheduleElement(scheduled_at=7, cell_pin=CellPin(r=3, c=5, d=3)), - PathScheduleElement(scheduled_at=9, cell_pin=CellPin(r=3, c=4, d=3)), - PathScheduleElement(scheduled_at=11, cell_pin=CellPin(r=3, c=3, d=3)), - PathScheduleElement(scheduled_at=13, cell_pin=CellPin(r=2, c=3, d=0)), - PathScheduleElement(scheduled_at=15, cell_pin=CellPin(r=1, c=3, d=0)), - PathScheduleElement(scheduled_at=17, cell_pin=CellPin(r=0, c=3, d=0)), - PathScheduleElement(scheduled_at=18, cell_pin=CellPin(r=0, c=3, d=5))]} + chosen_path_dict = {0: [TrainRunWayPoint(scheduled_at=0, way_point=WayPoint(position=(3, 0), direction=3)), + TrainRunWayPoint(scheduled_at=2, way_point=WayPoint(position=(3, 1), direction=1)), + TrainRunWayPoint(scheduled_at=3, way_point=WayPoint(position=(3, 2), direction=1)), + TrainRunWayPoint(scheduled_at=14, way_point=WayPoint(position=(3, 3), direction=1)), + TrainRunWayPoint(scheduled_at=15, way_point=WayPoint(position=(3, 4), direction=1)), + TrainRunWayPoint(scheduled_at=16, way_point=WayPoint(position=(3, 5), direction=1)), + TrainRunWayPoint(scheduled_at=17, way_point=WayPoint(position=(3, 6), direction=1)), + TrainRunWayPoint(scheduled_at=18, way_point=WayPoint(position=(3, 7), direction=1)), + TrainRunWayPoint(scheduled_at=19, way_point=WayPoint(position=(3, 8), direction=1)), + TrainRunWayPoint(scheduled_at=20, way_point=WayPoint(position=(3, 8), direction=5))], + 1: [TrainRunWayPoint(scheduled_at=0, way_point=WayPoint(position=(3, 8), direction=3)), + TrainRunWayPoint(scheduled_at=3, way_point=WayPoint(position=(3, 7), direction=3)), + TrainRunWayPoint(scheduled_at=5, way_point=WayPoint(position=(3, 6), direction=3)), + TrainRunWayPoint(scheduled_at=7, way_point=WayPoint(position=(3, 5), direction=3)), + TrainRunWayPoint(scheduled_at=9, way_point=WayPoint(position=(3, 4), direction=3)), + TrainRunWayPoint(scheduled_at=11, way_point=WayPoint(position=(3, 3), direction=3)), + TrainRunWayPoint(scheduled_at=13, way_point=WayPoint(position=(2, 3), direction=0)), + TrainRunWayPoint(scheduled_at=15, way_point=WayPoint(position=(1, 3), direction=0)), + TrainRunWayPoint(scheduled_at=17, way_point=WayPoint(position=(0, 3), direction=0))]} expected_action_plan = [[ # take action to enter the grid - (0, WalkingElement(position=None, direction=3, next_action=RailEnvActions.MOVE_FORWARD)), + ActionPlanElement(0, RailEnvActions.MOVE_FORWARD), # take action to enter the cell properly - (1, WalkingElement(position=(3, 0), direction=3, next_action=RailEnvActions.MOVE_FORWARD)), - (2, WalkingElement(position=(3, 1), direction=1, next_action=RailEnvActions.MOVE_FORWARD)), - (3, WalkingElement(position=(3, 2), direction=1, next_action=RailEnvActions.STOP_MOVING)), - (13, WalkingElement(position=(3, 2), direction=1, next_action=RailEnvActions.MOVE_FORWARD)), - (14, WalkingElement(position=(3, 3), direction=1, next_action=RailEnvActions.MOVE_FORWARD)), - (15, WalkingElement(position=(3, 4), direction=1, next_action=RailEnvActions.MOVE_FORWARD)), - (16, WalkingElement(position=(3, 5), direction=1, next_action=RailEnvActions.MOVE_FORWARD)), - (17, WalkingElement(position=(3, 6), direction=1, next_action=RailEnvActions.MOVE_FORWARD)), - (18, WalkingElement(position=(3, 7), direction=1, next_action=RailEnvActions.MOVE_FORWARD)), - (19, WalkingElement(position=None, direction=1, next_action=RailEnvActions.STOP_MOVING)) + ActionPlanElement(1, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(2, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(3, RailEnvActions.STOP_MOVING), + ActionPlanElement(13, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(14, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(15, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(16, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(17, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(18, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(19, RailEnvActions.STOP_MOVING) ], [ - (0, WalkingElement(position=None, direction=3, next_action=RailEnvActions.MOVE_FORWARD)), - (1, WalkingElement(position=(3, 8), direction=3, next_action=RailEnvActions.MOVE_FORWARD)), - (3, WalkingElement(position=(3, 7), direction=3, next_action=RailEnvActions.MOVE_FORWARD)), - (5, WalkingElement(position=(3, 6), direction=3, next_action=RailEnvActions.MOVE_FORWARD)), - (7, WalkingElement(position=(3, 5), direction=3, next_action=RailEnvActions.MOVE_FORWARD)), - (9, WalkingElement(position=(3, 4), direction=3, next_action=RailEnvActions.MOVE_FORWARD)), - (11, WalkingElement(position=(3, 3), direction=3, next_action=RailEnvActions.MOVE_RIGHT)), - (13, WalkingElement(position=(2, 3), direction=0, next_action=RailEnvActions.MOVE_FORWARD)), - (15, WalkingElement(position=(1, 3), direction=0, next_action=RailEnvActions.MOVE_FORWARD)), - (17, WalkingElement(position=None, direction=0, next_action=RailEnvActions.STOP_MOVING)), + ActionPlanElement(0, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(1, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(3, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(5, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(7, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(9, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(11, RailEnvActions.MOVE_RIGHT), + ActionPlanElement(13, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(15, RailEnvActions.MOVE_FORWARD), + ActionPlanElement(17, RailEnvActions.STOP_MOVING), ]] diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index f2084ac3..5b8c385e 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -9,8 +9,9 @@ from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions -from flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement +from flatland.envs.rail_env_shortest_paths import get_shortest_paths from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.rail_train_run_data_structures import WayPoint from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail @@ -146,12 +147,12 @@ def test_shortest_path_predictor(rendering=False): paths = get_shortest_paths(env.distance_map)[0] assert paths == [ - WalkingElement((5, 6), 0, RailEnvActions.MOVE_FORWARD), - WalkingElement((4, 6), 0, RailEnvActions.MOVE_FORWARD), - WalkingElement((3, 6), 0, RailEnvActions.MOVE_FORWARD), - WalkingElement((3, 7), 1, RailEnvActions.MOVE_FORWARD), - WalkingElement((3, 8), 1, RailEnvActions.MOVE_FORWARD), - WalkingElement((3, 9), 1, RailEnvActions.STOP_MOVING) + WayPoint((5, 6), 0), + WayPoint((4, 6), 0), + WayPoint((3, 6), 0), + WayPoint((3, 7), 1), + WayPoint((3, 8), 1), + WayPoint((3, 9), 1) ] # extract the data diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index 0d0927f5..5f17bb3d 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -5,10 +5,11 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv -from flatland.envs.rail_env import RailEnvActions, RailEnv -from flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement, get_k_shortest_paths +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_env_shortest_paths import get_shortest_paths, get_k_shortest_paths from flatland.envs.rail_env_utils import load_flatland_environment_from_file from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.rail_train_run_data_structures import WayPoint from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_disconnected_simple_rail, make_simple_rail_with_alternatives @@ -47,45 +48,45 @@ def test_get_shortest_paths(): expected = { 0: [ - WalkingElement(position=(1, 1), direction=1, next_action=RailEnvActions.MOVE_FORWARD, ), - WalkingElement(position=(1, 2), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(1, 3), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 3), direction=2, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 4), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 5), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 6), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 7), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 8), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 9), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 10), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 11), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 12), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 13), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 14), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 15), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 16), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 17), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 18), direction=1, next_action=RailEnvActions.STOP_MOVING)], + WayPoint(position=(1, 1), direction=1), + WayPoint(position=(1, 2), direction=1), + WayPoint(position=(1, 3), direction=1), + WayPoint(position=(2, 3), direction=2), + WayPoint(position=(2, 4), direction=1), + WayPoint(position=(2, 5), direction=1), + WayPoint(position=(2, 6), direction=1), + WayPoint(position=(2, 7), direction=1), + WayPoint(position=(2, 8), direction=1), + WayPoint(position=(2, 9), direction=1), + WayPoint(position=(2, 10), direction=1), + WayPoint(position=(2, 11), direction=1), + WayPoint(position=(2, 12), direction=1), + WayPoint(position=(2, 13), direction=1), + WayPoint(position=(2, 14), direction=1), + WayPoint(position=(2, 15), direction=1), + WayPoint(position=(2, 16), direction=1), + WayPoint(position=(2, 17), direction=1), + WayPoint(position=(2, 18), direction=1)], 1: [ - WalkingElement(position=(3, 18), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(3, 17), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(3, 16), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 16), direction=0, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 15), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 14), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 13), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 12), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 11), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 10), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 9), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 8), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 7), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 6), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 5), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 4), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 3), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 2), direction=3, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 1), direction=3, next_action=RailEnvActions.STOP_MOVING)] + WayPoint(position=(3, 18), direction=3), + WayPoint(position=(3, 17), direction=3), + WayPoint(position=(3, 16), direction=3), + WayPoint(position=(2, 16), direction=0), + WayPoint(position=(2, 15), direction=3), + WayPoint(position=(2, 14), direction=3), + WayPoint(position=(2, 13), direction=3), + WayPoint(position=(2, 12), direction=3), + WayPoint(position=(2, 11), direction=3), + WayPoint(position=(2, 10), direction=3), + WayPoint(position=(2, 9), direction=3), + WayPoint(position=(2, 8), direction=3), + WayPoint(position=(2, 7), direction=3), + WayPoint(position=(2, 6), direction=3), + WayPoint(position=(2, 5), direction=3), + WayPoint(position=(2, 4), direction=3), + WayPoint(position=(2, 3), direction=3), + WayPoint(position=(2, 2), direction=3), + WayPoint(position=(2, 1), direction=3)] } for agent_handle in expected: @@ -102,16 +103,12 @@ def test_get_shortest_paths_max_depth(): expected = { 0: [ - WalkingElement(position=(1, 1), direction=1, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(1, 2), direction=1, - next_action=RailEnvActions.MOVE_FORWARD) + WayPoint(position=(1, 1), direction=1), + WayPoint(position=(1, 2), direction=1) ], 1: [ - WalkingElement(position=(3, 18), direction=3, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(3, 17), direction=3, - next_action=RailEnvActions.MOVE_FORWARD), + WayPoint(position=(3, 18), direction=3), + WayPoint(position=(3, 17), direction=3), ] } @@ -130,114 +127,79 @@ def test_get_shortest_paths_agent_handle(): print(actual, file=sys.stderr) expected = {6: - [WalkingElement(position=(5, 5), - direction=0, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(4, 5), - direction=0, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(3, 5), - direction=0, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 5), - direction=0, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(1, 5), - direction=0, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(0, 5), - direction=0, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(0, 6), - direction=1, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(0, 7), direction=1, next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(0, 8), - direction=1, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(0, 9), - direction=1, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(0, 10), - direction=1, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(1, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(2, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(3, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(4, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(5, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(6, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(7, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(8, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(9, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(10, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(11, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(12, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(13, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(14, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(15, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(16, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(17, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(18, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(19, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(20, 10), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(20, 9), - direction=3, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(20, 8), - direction=3, next_action=RailEnvActions.MOVE_LEFT), - WalkingElement(position=(21, 8), - direction=2, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(21, 7), - direction=3, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(21, 6), - direction=3, - next_action=RailEnvActions.MOVE_FORWARD), - WalkingElement(position=(21, 5), - direction=3, - next_action=RailEnvActions.STOP_MOVING) + [WayPoint(position=(5, 5), + direction=0), + WayPoint(position=(4, 5), + direction=0), + WayPoint(position=(3, 5), + direction=0), + WayPoint(position=(2, 5), + direction=0), + WayPoint(position=(1, 5), + direction=0), + WayPoint(position=(0, 5), + direction=0), + WayPoint(position=(0, 6), + direction=1), + WayPoint(position=(0, 7), direction=1), + WayPoint(position=(0, 8), + direction=1), + WayPoint(position=(0, 9), + direction=1), + WayPoint(position=(0, 10), + direction=1), + WayPoint(position=(1, 10), + direction=2), + WayPoint(position=(2, 10), + direction=2), + WayPoint(position=(3, 10), + direction=2), + WayPoint(position=(4, 10), + direction=2), + WayPoint(position=(5, 10), + direction=2), + WayPoint(position=(6, 10), + direction=2), + WayPoint(position=(7, 10), + direction=2), + WayPoint(position=(8, 10), + direction=2), + WayPoint(position=(9, 10), + direction=2), + WayPoint(position=(10, 10), + direction=2), + WayPoint(position=(11, 10), + direction=2), + WayPoint(position=(12, 10), + direction=2), + WayPoint(position=(13, 10), + direction=2), + WayPoint(position=(14, 10), + direction=2), + WayPoint(position=(15, 10), + direction=2), + WayPoint(position=(16, 10), + direction=2), + WayPoint(position=(17, 10), + direction=2), + WayPoint(position=(18, 10), + direction=2), + WayPoint(position=(19, 10), + direction=2), + WayPoint(position=(20, 10), + direction=2), + WayPoint(position=(20, 9), + direction=3), + WayPoint(position=(20, 8), + direction=3), + WayPoint(position=(21, 8), + direction=2), + WayPoint(position=(21, 7), + direction=3), + WayPoint(position=(21, 6), + direction=3), + WayPoint(position=(21, 5), + direction=3) ]} for agent_handle in expected: @@ -285,41 +247,41 @@ def test_get_k_shortest_paths(rendering=False): expected = set([ ( - WalkingElement(position=(3, 1), direction=3, next_action=2), - WalkingElement(position=(3, 0), direction=3, next_action=2), - WalkingElement(position=(3, 1), direction=1, next_action=2), - WalkingElement(position=(3, 2), direction=1, next_action=2), - WalkingElement(position=(3, 3), direction=1, next_action=1), - WalkingElement(position=(2, 3), direction=0, next_action=2), - WalkingElement(position=(1, 3), direction=0, next_action=2), - WalkingElement(position=(0, 3), direction=0, next_action=2), - WalkingElement(position=(0, 4), direction=1, next_action=2), - WalkingElement(position=(0, 5), direction=1, next_action=2), - WalkingElement(position=(0, 6), direction=1, next_action=2), - WalkingElement(position=(0, 7), direction=1, next_action=2), - WalkingElement(position=(0, 8), direction=1, next_action=2), - WalkingElement(position=(0, 9), direction=1, next_action=2), - WalkingElement(position=(1, 9), direction=2, next_action=2), - WalkingElement(position=(2, 9), direction=2, next_action=2), - WalkingElement(position=(3, 9), direction=2, next_action=0)), + WayPoint(position=(3, 1), direction=3), + WayPoint(position=(3, 0), direction=3), + WayPoint(position=(3, 1), direction=1), + WayPoint(position=(3, 2), direction=1), + WayPoint(position=(3, 3), direction=1), + WayPoint(position=(2, 3), direction=0), + WayPoint(position=(1, 3), direction=0), + WayPoint(position=(0, 3), direction=0), + WayPoint(position=(0, 4), direction=1), + WayPoint(position=(0, 5), direction=1), + WayPoint(position=(0, 6), direction=1), + WayPoint(position=(0, 7), direction=1), + WayPoint(position=(0, 8), direction=1), + WayPoint(position=(0, 9), direction=1), + WayPoint(position=(1, 9), direction=2), + WayPoint(position=(2, 9), direction=2), + WayPoint(position=(3, 9), direction=2)), ( - WalkingElement(position=(3, 1), direction=3, next_action=2), - WalkingElement(position=(3, 0), direction=3, next_action=2), - WalkingElement(position=(3, 1), direction=1, next_action=2), - WalkingElement(position=(3, 2), direction=1, next_action=2), - WalkingElement(position=(3, 3), direction=1, next_action=2), - WalkingElement(position=(3, 4), direction=1, next_action=2), - WalkingElement(position=(3, 5), direction=1, next_action=2), - WalkingElement(position=(3, 6), direction=1, next_action=2), - WalkingElement(position=(4, 6), direction=2, next_action=2), - WalkingElement(position=(5, 6), direction=2, next_action=2), - WalkingElement(position=(6, 6), direction=2, next_action=2), - WalkingElement(position=(5, 6), direction=0, next_action=2), - WalkingElement(position=(4, 6), direction=0, next_action=3), - WalkingElement(position=(4, 7), direction=1, next_action=2), - WalkingElement(position=(4, 8), direction=1, next_action=2), - WalkingElement(position=(4, 9), direction=1, next_action=2), - WalkingElement(position=(3, 9), direction=0, next_action=0)) + WayPoint(position=(3, 1), direction=3), + WayPoint(position=(3, 0), direction=3), + WayPoint(position=(3, 1), direction=1), + WayPoint(position=(3, 2), direction=1), + WayPoint(position=(3, 3), direction=1), + WayPoint(position=(3, 4), direction=1), + WayPoint(position=(3, 5), direction=1), + WayPoint(position=(3, 6), direction=1), + WayPoint(position=(4, 6), direction=2), + WayPoint(position=(5, 6), direction=2), + WayPoint(position=(6, 6), direction=2), + WayPoint(position=(5, 6), direction=0), + WayPoint(position=(4, 6), direction=0), + WayPoint(position=(4, 7), direction=1), + WayPoint(position=(4, 8), direction=1), + WayPoint(position=(4, 9), direction=1), + WayPoint(position=(3, 9), direction=0)) ]) assert actual == expected, "actual={},expected={}".format(actual, expected) -- GitLab