Commit fe5f6f46 authored by u214892's avatar u214892
Browse files

SIM-119 refactoring ActionPlan; TODO extract Agent from ActionPlanReplayer

parent d4e6af1c
Pipeline #2792 passed with stages
in 36 minutes and 44 seconds
......@@ -2,38 +2,25 @@ import pprint
from typing import Dict, List, Optional, NamedTuple
import numpy as np
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import WalkingElement, get_action_for_move
from flatland.envs.rail_env_shortest_paths import get_action_for_move
from flatland.envs.rail_train_run_data_structures import WayPoint, TrainRun, TrainRunWayPoint
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
#---- Input Data Structures (graph representation) ---------------------------------------------
# A cell pin represents the one of the four pins in which the cell at row,column may be entered.
CellPin = NamedTuple('CellPin', [('r', int), ('c', int), ('d', int)])
# A path schedule element represents the entry time of agent at a cell pin.
PathScheduleElement = NamedTuple('PathScheduleElement', [
('scheduled_at', int),
('cell_pin', CellPin)
])
# A path schedule is the list of an agent's cell pin entries
PathSchedule = List[PathScheduleElement]
#---- Output Data Structures (FLATland representation) ---------------------------------------------
# ---- Output Data Structures (FLATland representation) ---------------------------------------------
# An action plan element represents the actions to be taken by an agent at deterministic time steps
# plus the position before the action
ActionPlanElement = NamedTuple('ActionPlanElement', [
('scheduled_at', int),
('walking_element', WalkingElement)
('action', RailEnvActions)
])
# An action plan deterministically represents all the actions to be taken by an agent
# plus its position before the actions are taken
ActionPlan = Dict[int, List[ActionPlanElement]]
class ActionPlanReplayer():
"""Allows to deduce an `ActionPlan` from the agents' `PathSchedule` and
to be replayed/verified in a FLATland env without malfunction."""
......@@ -42,15 +29,17 @@ class ActionPlanReplayer():
def __init__(self,
env: RailEnv,
chosen_path_dict: Dict[int, PathSchedule]):
train_run_dict: Dict[int, TrainRun]):
self.env = env
self.train_run_dict: Dict[int, TrainRun] = train_run_dict
print(train_run_dict)
self.action_plan = [[] for _ in range(self.env.get_num_agents())]
for agent_id, chosen_path in chosen_path_dict.items():
for agent_id, chosen_path in train_run_dict.items():
self._add_aggent_to_action_plan(self.action_plan, agent_id, chosen_path)
def get_walking_element_before_or_at_step(self, agent_id: int, step: int) -> WalkingElement:
def get_way_point_before_or_at_step(self, agent_id: int, step: int) -> WayPoint:
"""
Get the walking element from which the current position can be extracted.
......@@ -64,14 +53,26 @@ class ActionPlanReplayer():
WalkingElement
"""
walking_element = None
for action in self.action_plan[agent_id]:
if step < action.scheduled_at:
return walking_element
if step >= action.scheduled_at:
walking_element = action.walking_element
assert walking_element is not None
return walking_element
train_run = self.train_run_dict[agent_id]
entry_time_step = train_run[0].scheduled_at
# the agent has no position before and at choosing to enter the grid (one tick elapses before the agent enters the grid)
if step <= entry_time_step:
return WayPoint(position=None, direction=self.env.agents[agent_id].initial_direction)
# the agent has no position as soon as the target is reached
exit_time_step = train_run[-1].scheduled_at
if step >= exit_time_step:
# agent loses position as soon as target cell is reached
return WayPoint(position=None, direction=train_run[-1].way_point.direction)
way_point = None
for train_run_way_point in train_run:
if step < train_run_way_point.scheduled_at:
return way_point
if step >= train_run_way_point.scheduled_at:
way_point = train_run_way_point.way_point
assert way_point is not None
return way_point
def get_action_at_step(self, agent_id: int, current_step: int) -> Optional[RailEnvActions]:
"""
......@@ -90,11 +91,10 @@ class ActionPlanReplayer():
for action_plan_step in self.action_plan[agent_id]:
action_plan_step: ActionPlanElement = action_plan_step
scheduled_at = action_plan_step.scheduled_at
walking_element: WalkingElement = action_plan_step.walking_element
if scheduled_at > current_step:
return None
elif np.isclose(current_step, scheduled_at):
return walking_element.next_action
return action_plan_step.action
return None
def get_action_dict_for_step_replay(self, current_step: int) -> Dict[int, RailEnvActions]:
......@@ -130,10 +130,10 @@ class ActionPlanReplayer():
i = 0
while not env.dones['__all__'] and i <= MAX_EPISODE_STEPS:
for agent_id, agent in enumerate(env.agents):
walking_element: WalkingElement = self.get_walking_element_before_or_at_step(agent_id, i)
assert agent.position == walking_element.position, \
way_point: WayPoint = self.get_way_point_before_or_at_step(agent_id, i)
assert agent.position == way_point.position, \
"before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
walking_element.position)
way_point.position)
actions = self.get_action_dict_for_step_replay(i)
print("actions for {}: {}".format(i, actions))
......@@ -172,15 +172,15 @@ class ActionPlanReplayer():
agent = self.env.agents[agent_id]
minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed']))
for path_loop, path_schedule_element in enumerate(agent_path_new):
path_schedule_element: PathScheduleElement = path_schedule_element
path_schedule_element: TrainRunWayPoint = path_schedule_element
position = (path_schedule_element.cell_pin.r, path_schedule_element.cell_pin.c)
position = path_schedule_element.way_point.position
if Vec2d.is_equal(agent.target, position):
break
next_path_schedule_element: PathScheduleElement = agent_path_new[path_loop + 1]
next_position = (next_path_schedule_element.cell_pin.r, next_path_schedule_element.cell_pin.c)
next_path_schedule_element: TrainRunWayPoint = agent_path_new[path_loop + 1]
next_position = next_path_schedule_element.way_point.position
if path_loop == 0:
self._create_action_plan_for_first_path_element_of_agent(
......@@ -212,81 +212,63 @@ class ActionPlanReplayer():
action_plan: ActionPlan,
agent_id: int,
minimum_cell_time: int,
path_schedule_element: PathScheduleElement,
next_path_schedule_element: PathScheduleElement):
path_schedule_element: TrainRunWayPoint,
next_path_schedule_element: TrainRunWayPoint):
scheduled_at = path_schedule_element.scheduled_at
next_entry_value = next_path_schedule_element.scheduled_at
position = (path_schedule_element.cell_pin.r, path_schedule_element.cell_pin.c)
direction = path_schedule_element.cell_pin.d
next_position = next_path_schedule_element.cell_pin.r, next_path_schedule_element.cell_pin.c
next_direction = next_path_schedule_element.cell_pin.d
position = path_schedule_element.way_point.position
direction = path_schedule_element.way_point.direction
next_position = next_path_schedule_element.way_point.position
next_direction = next_path_schedule_element.way_point.direction
next_action = get_action_for_move(position,
direction,
next_position,
next_direction,
self.env.rail)
walking_element = WalkingElement(position, direction, next_action)
# if the next entry is later than minimum_cell_time, then stop here and
# move minimum_cell_time before the exit
# we have to do this since agents in the RailEnv are processed in the step() in the order of their handle
if next_entry_value > scheduled_at + minimum_cell_time:
action = ActionPlanElement(scheduled_at,
WalkingElement(
position=position,
direction=direction,
next_action=RailEnvActions.STOP_MOVING))
action = ActionPlanElement(scheduled_at, RailEnvActions.STOP_MOVING)
action_plan[agent_id].append(action)
action = ActionPlanElement(next_entry_value - minimum_cell_time, walking_element)
action = ActionPlanElement(next_entry_value - minimum_cell_time, next_action)
action_plan[agent_id].append(action)
else:
action = ActionPlanElement(scheduled_at, walking_element)
action = ActionPlanElement(scheduled_at, next_action)
action_plan[agent_id].append(action)
def _create_action_plan_for_target_at_path_element_just_before_target(self,
action_plan: ActionPlan,
agent_id: int,
minimum_cell_time: int,
path_schedule_element: PathScheduleElement,
next_path_schedule_element: PathScheduleElement):
path_schedule_element: TrainRunWayPoint,
next_path_schedule_element: TrainRunWayPoint):
scheduled_at = path_schedule_element.scheduled_at
next_path_schedule_element.cell_pin
next_path_schedule_element.way_point
action = ActionPlanElement(scheduled_at + minimum_cell_time,
WalkingElement(
position=None,
direction=next_path_schedule_element.cell_pin.d,
next_action=RailEnvActions.STOP_MOVING))
action = ActionPlanElement(scheduled_at + minimum_cell_time, RailEnvActions.STOP_MOVING)
action_plan[agent_id].append(action)
def _create_action_plan_for_first_path_element_of_agent(self,
action_plan: ActionPlan,
agent_id: int,
path_schedule_element: PathScheduleElement,
next_path_schedule_element: PathScheduleElement):
path_schedule_element: TrainRunWayPoint,
next_path_schedule_element: TrainRunWayPoint):
scheduled_at = path_schedule_element.scheduled_at
position = (path_schedule_element.cell_pin.r, path_schedule_element.cell_pin.c)
direction = path_schedule_element.cell_pin.d
next_position = next_path_schedule_element.cell_pin.r, next_path_schedule_element.cell_pin.c
next_direction = next_path_schedule_element.cell_pin.d
position = path_schedule_element.way_point.position
direction = path_schedule_element.way_point.direction
next_position = next_path_schedule_element.way_point.position
next_direction = next_path_schedule_element.way_point.direction
# add intial do nothing if we do not enter immediately
if scheduled_at > 0:
action = ActionPlanElement(0,
WalkingElement(
position=None,
direction=direction,
next_action=RailEnvActions.DO_NOTHING))
action = ActionPlanElement(0, RailEnvActions.DO_NOTHING)
action_plan[agent_id].append(action)
# add action to enter the grid
action = ActionPlanElement(scheduled_at,
WalkingElement(
position=None,
direction=direction,
next_action=RailEnvActions.MOVE_FORWARD))
action = ActionPlanElement(scheduled_at, RailEnvActions.MOVE_FORWARD)
action_plan[agent_id].append(action)
next_action = get_action_for_move(position,
......@@ -296,9 +278,5 @@ class ActionPlanReplayer():
self.env.rail)
# now, we have a position need to perform the action
action = ActionPlanElement(scheduled_at + 1,
WalkingElement(
position=position,
direction=direction,
next_action=next_action))
action = ActionPlanElement(scheduled_at + 1, next_action)
action_plan[agent_id].append(action)
import math
from typing import Dict, List, Optional, NamedTuple, Tuple, Set
from typing import Dict, List, Optional, Tuple, Set
import matplotlib.pyplot as plt
import numpy as np
......@@ -10,12 +10,9 @@ from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions, RailEnv
from flatland.envs.rail_train_run_data_structures import WayPoint
from flatland.utils.ordered_set import OrderedSet
WalkingElement = \
NamedTuple('WalkingElement',
[('position', Tuple[int, int]), ('direction', int), ('next_action', Optional[RailEnvActions])])
def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
agent_position: Tuple[int, int],
......@@ -195,7 +192,7 @@ def get_action_for_move(
# N.B. get_shortest_paths is not part of distance_map since it refers to RailEnvActions (would lead to circularity!)
def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = None, agent_handle: Optional[int] = None) \
-> Dict[int, Optional[List[WalkingElement]]]:
-> Dict[int, Optional[List[WayPoint]]]:
"""
Computes the shortest path for each agent to its target and the action to be taken to do so.
The paths are derived from a `DistanceMap`.
......@@ -245,8 +242,7 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
best_next_action = next_action
distance = next_action_distance
shortest_paths[agent.handle].append(
WalkingElement(position, direction, best_next_action.action if best_next_action is not None else None))
shortest_paths[agent.handle].append(WayPoint(position, direction))
depth += 1
# if there is no way to continue, the rail must be disconnected!
......@@ -258,7 +254,7 @@ def get_shortest_paths(distance_map: DistanceMap, max_depth: Optional[int] = Non
position = best_next_action.next_position
direction = best_next_action.next_direction
if max_depth is None or depth < max_depth:
shortest_paths[agent.handle].append(WalkingElement(position, direction, RailEnvActions.STOP_MOVING))
shortest_paths[agent.handle].append(WayPoint(position, direction))
if agent_handle is not None:
_shortest_path_for_agent(distance_map.agents[agent_handle])
......@@ -273,7 +269,7 @@ def get_k_shortest_paths(env: RailEnv,
source_position: Tuple[int, int],
source_direction: int,
target_position=Tuple[int, int],
k: int = 1, debug=False) -> List[Tuple[WalkingElement]]:
k: int = 1, debug=False) -> List[Tuple[WayPoint]]:
"""
Computes the k shortest paths using modified Dijkstra
following pseudo-code https://en.wikipedia.org/wiki/K_shortest_path_routing
......@@ -300,17 +296,17 @@ def get_k_shortest_paths(env: RailEnv,
# P: set of shortest paths from s to t
# P =empty,
shortest_paths: List[Tuple[WalkingElement]] = []
shortest_paths: List[Tuple[WayPoint]] = []
# countu: number of shortest paths found to node u
# countu = 0, for all u in V
count = {(r, c, d): 0 for r in range(env.height) for c in range(env.width) for d in range(4)}
# B is a heap data structure containing paths
heap: Set[Tuple[WalkingElement]] = set()
heap: Set[Tuple[WayPoint]] = set()
# insert path Ps = {s} into B with cost 0
heap.add((WalkingElement(source_position, source_direction, None),))
heap.add((WayPoint(source_position, source_direction),))
# while B is not empty and countt < K:
while len(heap) > 0 and len(shortest_paths) < k:
......@@ -323,7 +319,7 @@ def get_k_shortest_paths(env: RailEnv,
if len(path) < c:
pu = path
c = len(path)
u: WalkingElement = pu[-1]
u: WayPoint = pu[-1]
if debug:
print(" looking at pu={}".format(pu))
......@@ -355,7 +351,7 @@ def get_k_shortest_paths(env: RailEnv,
if debug:
print(" looking at neighbor v={}".format((*new_position, new_direction)))
v = WalkingElement(position=new_position, direction=new_direction, next_action=None)
v = WayPoint(position=new_position, direction=new_direction)
# CAVEAT: do not allow for loopy paths
if v in pu:
continue
......@@ -364,25 +360,9 @@ def get_k_shortest_paths(env: RailEnv,
pv = pu + (v,)
# – insert Pv into B
heap.add(pv)
# add actions to shortest paths
shortest_paths_with_action = []
for p in shortest_paths:
p_with_action = tuple(
WalkingElement(position=el.position,
direction=el.direction,
next_action=int(get_action_for_move(el.position,
el.direction,
p[i + 1].position,
p[i + 1].direction,
env.rail))) for i, el in
enumerate(p[:-1]))
target_walking_element = WalkingElement(position=p[-1].position,
direction=p[-1].direction,
next_action=int(RailEnvActions.DO_NOTHING))
shortest_paths_with_action.append(p_with_action + (target_walking_element,))
# return P
return shortest_paths_with_action
return shortest_paths
def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0):
......
from typing import NamedTuple, Tuple, List
WayPoint = NamedTuple('WayPoint', [('position', Tuple[int, int]), ('direction', int)])
# A train run is represented by the waypoints traversed and the times of traversal
# The terminology follows https://github.com/crowdAI/train-schedule-optimisation-challenge-starter-kit/blob/master/documentation/output_data_model.md
TrainRunWayPoint = NamedTuple('TrainRunWayPoint', [
('scheduled_at', int),
('way_point', WayPoint)
])
# A path schedule is the list of an agent's cell pin entries
TrainRun = List[TrainRunWayPoint]
from flatland.action_plan.action_plan import PathScheduleElement, CellPin, ActionPlanReplayer
from flatland.action_plan.action_plan import TrainRunWayPoint, ActionPlanReplayer, ActionPlanElement
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import WalkingElement
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_train_run_data_structures import WayPoint
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail
def test_action_plan(rendering: bool = False):
"""Tests ActionPlanReplayer: does action plan generation and replay work as expected."""
rail, rail_map = make_simple_rail()
......@@ -33,52 +32,51 @@ def test_action_plan(rendering: bool = False):
for handle, agent in enumerate(env.agents):
print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))
chosen_path_dict = {0: [PathScheduleElement(scheduled_at=0, cell_pin=CellPin(r=3, c=0, d=3)),
PathScheduleElement(scheduled_at=2, cell_pin=CellPin(r=3, c=1, d=1)),
PathScheduleElement(scheduled_at=3, cell_pin=CellPin(r=3, c=2, d=1)),
PathScheduleElement(scheduled_at=14, cell_pin=CellPin(r=3, c=3, d=1)),
PathScheduleElement(scheduled_at=15, cell_pin=CellPin(r=3, c=4, d=1)),
PathScheduleElement(scheduled_at=16, cell_pin=CellPin(r=3, c=5, d=1)),
PathScheduleElement(scheduled_at=17, cell_pin=CellPin(r=3, c=6, d=1)),
PathScheduleElement(scheduled_at=18, cell_pin=CellPin(r=3, c=7, d=1)),
PathScheduleElement(scheduled_at=19, cell_pin=CellPin(r=3, c=8, d=1)),
PathScheduleElement(scheduled_at=20, cell_pin=CellPin(r=3, c=8, d=5))],
1: [PathScheduleElement(scheduled_at=0, cell_pin=CellPin(r=3, c=8, d=3)),
PathScheduleElement(scheduled_at=3, cell_pin=CellPin(r=3, c=7, d=3)),
PathScheduleElement(scheduled_at=5, cell_pin=CellPin(r=3, c=6, d=3)),
PathScheduleElement(scheduled_at=7, cell_pin=CellPin(r=3, c=5, d=3)),
PathScheduleElement(scheduled_at=9, cell_pin=CellPin(r=3, c=4, d=3)),
PathScheduleElement(scheduled_at=11, cell_pin=CellPin(r=3, c=3, d=3)),
PathScheduleElement(scheduled_at=13, cell_pin=CellPin(r=2, c=3, d=0)),
PathScheduleElement(scheduled_at=15, cell_pin=CellPin(r=1, c=3, d=0)),
PathScheduleElement(scheduled_at=17, cell_pin=CellPin(r=0, c=3, d=0)),
PathScheduleElement(scheduled_at=18, cell_pin=CellPin(r=0, c=3, d=5))]}
chosen_path_dict = {0: [TrainRunWayPoint(scheduled_at=0, way_point=WayPoint(position=(3, 0), direction=3)),
TrainRunWayPoint(scheduled_at=2, way_point=WayPoint(position=(3, 1), direction=1)),
TrainRunWayPoint(scheduled_at=3, way_point=WayPoint(position=(3, 2), direction=1)),
TrainRunWayPoint(scheduled_at=14, way_point=WayPoint(position=(3, 3), direction=1)),
TrainRunWayPoint(scheduled_at=15, way_point=WayPoint(position=(3, 4), direction=1)),
TrainRunWayPoint(scheduled_at=16, way_point=WayPoint(position=(3, 5), direction=1)),
TrainRunWayPoint(scheduled_at=17, way_point=WayPoint(position=(3, 6), direction=1)),
TrainRunWayPoint(scheduled_at=18, way_point=WayPoint(position=(3, 7), direction=1)),
TrainRunWayPoint(scheduled_at=19, way_point=WayPoint(position=(3, 8), direction=1)),
TrainRunWayPoint(scheduled_at=20, way_point=WayPoint(position=(3, 8), direction=5))],
1: [TrainRunWayPoint(scheduled_at=0, way_point=WayPoint(position=(3, 8), direction=3)),
TrainRunWayPoint(scheduled_at=3, way_point=WayPoint(position=(3, 7), direction=3)),
TrainRunWayPoint(scheduled_at=5, way_point=WayPoint(position=(3, 6), direction=3)),
TrainRunWayPoint(scheduled_at=7, way_point=WayPoint(position=(3, 5), direction=3)),
TrainRunWayPoint(scheduled_at=9, way_point=WayPoint(position=(3, 4), direction=3)),
TrainRunWayPoint(scheduled_at=11, way_point=WayPoint(position=(3, 3), direction=3)),
TrainRunWayPoint(scheduled_at=13, way_point=WayPoint(position=(2, 3), direction=0)),
TrainRunWayPoint(scheduled_at=15, way_point=WayPoint(position=(1, 3), direction=0)),
TrainRunWayPoint(scheduled_at=17, way_point=WayPoint(position=(0, 3), direction=0))]}
expected_action_plan = [[
# take action to enter the grid
(0, WalkingElement(position=None, direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
ActionPlanElement(0, RailEnvActions.MOVE_FORWARD),
# take action to enter the cell properly
(1, WalkingElement(position=(3, 0), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
(2, WalkingElement(position=(3, 1), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
(3, WalkingElement(position=(3, 2), direction=1, next_action=RailEnvActions.STOP_MOVING)),
(13, WalkingElement(position=(3, 2), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
(14, WalkingElement(position=(3, 3), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
(15, WalkingElement(position=(3, 4), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
(16, WalkingElement(position=(3, 5), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
(17, WalkingElement(position=(3, 6), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
(18, WalkingElement(position=(3, 7), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
(19, WalkingElement(position=None, direction=1, next_action=RailEnvActions.STOP_MOVING))
ActionPlanElement(1, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(2, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(3, RailEnvActions.STOP_MOVING),
ActionPlanElement(13, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(14, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(15, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(16, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(17, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(18, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(19, RailEnvActions.STOP_MOVING)
], [
(0, WalkingElement(position=None, direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
(1, WalkingElement(position=(3, 8), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
(3, WalkingElement(position=(3, 7), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
(5, WalkingElement(position=(3, 6), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
(7, WalkingElement(position=(3, 5), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
(9, WalkingElement(position=(3, 4), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
(11, WalkingElement(position=(3, 3), direction=3, next_action=RailEnvActions.MOVE_RIGHT)),
(13, WalkingElement(position=(2, 3), direction=0, next_action=RailEnvActions.MOVE_FORWARD)),
(15, WalkingElement(position=(1, 3), direction=0, next_action=RailEnvActions.MOVE_FORWARD)),
(17, WalkingElement(position=None, direction=0, next_action=RailEnvActions.STOP_MOVING)),
ActionPlanElement(0, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(1, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(3, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(5, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(7, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(9, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(11, RailEnvActions.MOVE_RIGHT),
ActionPlanElement(13, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(15, RailEnvActions.MOVE_FORWARD),
ActionPlanElement(17, RailEnvActions.STOP_MOVING),
]]
......
......@@ -9,8 +9,9 @@ from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths, WalkingElement
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_train_run_data_structures import WayPoint
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
......@@ -146,12 +147,12 @@ def test_shortest_path_predictor(rendering=False):
paths = get_shortest_paths(env.distance_map)[0]
assert paths == [
WalkingElement((5, 6), 0, RailEnvActions.MOVE_FORWARD),
WalkingElement((4, 6), 0, RailEnvActions.MOVE_FORWARD),
WalkingElement((3, 6), 0, RailEnvActions.MOVE_FORWARD),
WalkingElement((3, 7), 1, RailEnvActions.MOVE_FORWARD),
WalkingElement((3, 8), 1, RailEnvActions.MOVE_FORWARD),
WalkingElement((3, 9), 1, RailEnvActions.STOP_MOVING)
WayPoint((5, 6), 0),
WayPoint((4, 6), 0),
WayPoint((3, 6), 0),
WayPoint((3, 7), 1),
WayPoint((3, 8), 1),
WayPoint((3, 9), 1)
]
# extract the data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment