Commit 9916ed52 authored by u214892's avatar u214892
Browse files

trainrun instead of train_run and waypoint instead of way_point for...

trainrun instead of train_run and waypoint instead of way_point for compatibility with Trainrun and Waypoint for readability
parent 8be8160f
Pipeline #2839 passed with stages
in 38 minutes and 37 seconds
......@@ -6,7 +6,7 @@ import numpy as np
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_action_for_move
from flatland.envs.rail_train_run_data_structures import Waypoint, Trainrun, TrainrunWaypoint
from flatland.envs.rail_trainrun_data_structures import Waypoint, Trainrun, TrainrunWaypoint
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
# ---- ActionPlan ---------------
......@@ -28,14 +28,14 @@ class ControllerFromTrainruns():
def __init__(self,
env: RailEnv,
train_run_dict: Dict[int, Trainrun]):
trainrun_dict: Dict[int, Trainrun]):
self.env: RailEnv = env
self.train_run_dict: Dict[int, Trainrun] = train_run_dict
self.trainrun_dict: Dict[int, Trainrun] = trainrun_dict
self.action_plan: ActionPlanDict = [self._create_action_plan_for_agent(agent_id, chosen_path)
for agent_id, chosen_path in train_run_dict.items()]
for agent_id, chosen_path in trainrun_dict.items()]
def get_way_point_before_or_at_step(self, agent_id: int, step: int) -> Waypoint:
def get_waypoint_before_or_at_step(self, agent_id: int, step: int) -> Waypoint:
"""
Get the way point point from which the current position can be extracted.
......@@ -49,26 +49,26 @@ class ControllerFromTrainruns():
WalkingElement
"""
train_run = self.train_run_dict[agent_id]
entry_time_step = train_run[0].scheduled_at
trainrun = self.trainrun_dict[agent_id]
entry_time_step = trainrun[0].scheduled_at
# the agent has no position before and at choosing to enter the grid (one tick elapses before the agent enters the grid)
if step <= entry_time_step:
return Waypoint(position=None, direction=self.env.agents[agent_id].initial_direction)
# the agent has no position as soon as the target is reached
exit_time_step = train_run[-1].scheduled_at
exit_time_step = trainrun[-1].scheduled_at
if step >= exit_time_step:
# agent loses position as soon as target cell is reached
return Waypoint(position=None, direction=train_run[-1].way_point.direction)
return Waypoint(position=None, direction=trainrun[-1].waypoint.direction)
way_point = None
for train_run_way_point in train_run:
if step < train_run_way_point.scheduled_at:
return way_point
if step >= train_run_way_point.scheduled_at:
way_point = train_run_way_point.way_point
assert way_point is not None
return way_point
waypoint = None
for trainrun_waypoint in trainrun:
if step < trainrun_waypoint.scheduled_at:
return waypoint
if step >= trainrun_waypoint.scheduled_at:
waypoint = trainrun_waypoint.waypoint
assert waypoint is not None
return waypoint
def get_action_at_step(self, agent_id: int, current_step: int) -> Optional[RailEnvActions]:
"""
......@@ -142,26 +142,26 @@ class ControllerFromTrainruns():
assert expected_action_plan == actual_action_plan, \
"expected {}, found {}".format(expected_action_plan, actual_action_plan)
def _create_action_plan_for_agent(self, agent_id, train_run) -> ActionPlan:
def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
action_plan = []
agent = self.env.agents[agent_id]
minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed']))
for path_loop, train_run_way_point in enumerate(train_run):
train_run_way_point: TrainrunWaypoint = train_run_way_point
for path_loop, trainrun_waypoint in enumerate(trainrun):
trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint
position = train_run_way_point.way_point.position
position = trainrun_waypoint.waypoint.position
if Vec2d.is_equal(agent.target, position):
break
next_train_run_way_point: TrainrunWaypoint = train_run[path_loop + 1]
next_position = next_train_run_way_point.way_point.position
next_trainrun_waypoint: TrainrunWaypoint = trainrun[path_loop + 1]
next_position = next_trainrun_waypoint.waypoint.position
if path_loop == 0:
self._add_action_plan_elements_for_first_path_element_of_agent(
action_plan,
train_run_way_point,
next_train_run_way_point,
trainrun_waypoint,
next_trainrun_waypoint,
minimum_cell_time
)
continue
......@@ -171,30 +171,30 @@ class ControllerFromTrainruns():
self._add_action_plan_elements_for_current_path_element(
action_plan,
minimum_cell_time,
train_run_way_point,
next_train_run_way_point)
trainrun_waypoint,
next_trainrun_waypoint)
# add a final element
if just_before_target:
self._add_action_plan_elements_for_target_at_path_element_just_before_target(
action_plan,
minimum_cell_time,
train_run_way_point,
next_train_run_way_point)
trainrun_waypoint,
next_trainrun_waypoint)
return action_plan
def _add_action_plan_elements_for_current_path_element(self,
action_plan: ActionPlan,
minimum_cell_time: int,
train_run_way_point: TrainrunWaypoint,
next_train_run_way_point: TrainrunWaypoint):
scheduled_at = train_run_way_point.scheduled_at
next_entry_value = next_train_run_way_point.scheduled_at
position = train_run_way_point.way_point.position
direction = train_run_way_point.way_point.direction
next_position = next_train_run_way_point.way_point.position
next_direction = next_train_run_way_point.way_point.direction
trainrun_waypoint: TrainrunWaypoint,
next_trainrun_waypoint: TrainrunWaypoint):
scheduled_at = trainrun_waypoint.scheduled_at
next_entry_value = next_trainrun_waypoint.scheduled_at
position = trainrun_waypoint.waypoint.position
direction = trainrun_waypoint.waypoint.direction
next_position = next_trainrun_waypoint.waypoint.position
next_direction = next_trainrun_waypoint.waypoint.direction
next_action = get_action_for_move(position,
direction,
next_position,
......@@ -217,23 +217,23 @@ class ControllerFromTrainruns():
def _add_action_plan_elements_for_target_at_path_element_just_before_target(self,
action_plan: ActionPlan,
minimum_cell_time: int,
train_run_way_point: TrainrunWaypoint,
next_train_run_way_point: TrainrunWaypoint):
scheduled_at = train_run_way_point.scheduled_at
trainrun_waypoint: TrainrunWaypoint,
next_trainrun_waypoint: TrainrunWaypoint):
scheduled_at = trainrun_waypoint.scheduled_at
action = ActionPlanElement(scheduled_at + minimum_cell_time, RailEnvActions.STOP_MOVING)
action_plan.append(action)
def _add_action_plan_elements_for_first_path_element_of_agent(self,
action_plan: ActionPlan,
train_run_way_point: TrainrunWaypoint,
next_train_run_way_point: TrainrunWaypoint,
trainrun_waypoint: TrainrunWaypoint,
next_trainrun_waypoint: TrainrunWaypoint,
minimum_cell_time: int):
scheduled_at = train_run_way_point.scheduled_at
position = train_run_way_point.way_point.position
direction = train_run_way_point.way_point.direction
next_position = next_train_run_way_point.way_point.position
next_direction = next_train_run_way_point.way_point.direction
scheduled_at = trainrun_waypoint.scheduled_at
position = trainrun_waypoint.waypoint.position
direction = trainrun_waypoint.waypoint.direction
next_position = next_trainrun_waypoint.waypoint.position
next_direction = next_trainrun_waypoint.waypoint.direction
# add intial do nothing if we do not enter immediately, actually not necessary
if scheduled_at > 0:
......@@ -251,12 +251,12 @@ class ControllerFromTrainruns():
self.env.rail)
# if the agent is blocked in the cell, we have to call stop upon entering!
if next_train_run_way_point.scheduled_at > scheduled_at + 1 + minimum_cell_time:
if next_trainrun_waypoint.scheduled_at > scheduled_at + 1 + minimum_cell_time:
action = ActionPlanElement(scheduled_at + 1, RailEnvActions.STOP_MOVING)
action_plan.append(action)
# execute the action exactly minimum_cell_time before the entry into the next cell
action = ActionPlanElement(next_train_run_way_point.scheduled_at - minimum_cell_time, next_action)
action = ActionPlanElement(next_trainrun_waypoint.scheduled_at - minimum_cell_time, next_action)
action_plan.append(action)
......@@ -277,10 +277,10 @@ class ControllerFromTrainrunsReplayer():
i = 0
while not env.dones['__all__'] and i <= env._max_episode_steps:
for agent_id, agent in enumerate(env.agents):
way_point: Waypoint = ctl.get_way_point_before_or_at_step(agent_id, i)
assert agent.position == way_point.position, \
waypoint: Waypoint = ctl.get_waypoint_before_or_at_step(agent_id, i)
assert agent.position == waypoint.position, \
"before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
way_point.position)
waypoint.position)
actions = ctl.act(i)
print("actions for {}: {}".format(i, actions))
......
......@@ -10,7 +10,7 @@ from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions, RailEnv
from flatland.envs.rail_train_run_data_structures import Waypoint
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.utils.ordered_set import OrderedSet
......
......@@ -12,7 +12,7 @@ Waypoint = NamedTuple('Waypoint', [('position', Tuple[int, int]), ('direction',
# The terminology follows https://github.com/crowdAI/train-schedule-optimisation-challenge-starter-kit/blob/master/documentation/output_data_model.md
TrainrunWaypoint = NamedTuple('TrainrunWaypoint', [
('scheduled_at', int),
('way_point', Waypoint)
('waypoint', Waypoint)
])
# A train run is the list of an agent's way points and their scheduled time
Trainrun = List[TrainrunWaypoint]
......@@ -4,7 +4,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv, RailEnvActions
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_train_run_data_structures import Waypoint
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.simple_rail import make_simple_rail
......@@ -32,25 +32,25 @@ def test_action_plan(rendering: bool = False):
for handle, agent in enumerate(env.agents):
print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))
chosen_path_dict = {0: [TrainrunWaypoint(scheduled_at=0, way_point=Waypoint(position=(3, 0), direction=3)),
TrainrunWaypoint(scheduled_at=2, way_point=Waypoint(position=(3, 1), direction=1)),
TrainrunWaypoint(scheduled_at=3, way_point=Waypoint(position=(3, 2), direction=1)),
TrainrunWaypoint(scheduled_at=14, way_point=Waypoint(position=(3, 3), direction=1)),
TrainrunWaypoint(scheduled_at=15, way_point=Waypoint(position=(3, 4), direction=1)),
TrainrunWaypoint(scheduled_at=16, way_point=Waypoint(position=(3, 5), direction=1)),
TrainrunWaypoint(scheduled_at=17, way_point=Waypoint(position=(3, 6), direction=1)),
TrainrunWaypoint(scheduled_at=18, way_point=Waypoint(position=(3, 7), direction=1)),
TrainrunWaypoint(scheduled_at=19, way_point=Waypoint(position=(3, 8), direction=1)),
TrainrunWaypoint(scheduled_at=20, way_point=Waypoint(position=(3, 8), direction=5))],
1: [TrainrunWaypoint(scheduled_at=0, way_point=Waypoint(position=(3, 8), direction=3)),
TrainrunWaypoint(scheduled_at=3, way_point=Waypoint(position=(3, 7), direction=3)),
TrainrunWaypoint(scheduled_at=5, way_point=Waypoint(position=(3, 6), direction=3)),
TrainrunWaypoint(scheduled_at=7, way_point=Waypoint(position=(3, 5), direction=3)),
TrainrunWaypoint(scheduled_at=9, way_point=Waypoint(position=(3, 4), direction=3)),
TrainrunWaypoint(scheduled_at=11, way_point=Waypoint(position=(3, 3), direction=3)),
TrainrunWaypoint(scheduled_at=13, way_point=Waypoint(position=(2, 3), direction=0)),
TrainrunWaypoint(scheduled_at=15, way_point=Waypoint(position=(1, 3), direction=0)),
TrainrunWaypoint(scheduled_at=17, way_point=Waypoint(position=(0, 3), direction=0))]}
chosen_path_dict = {0: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 0), direction=3)),
TrainrunWaypoint(scheduled_at=2, waypoint=Waypoint(position=(3, 1), direction=1)),
TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 2), direction=1)),
TrainrunWaypoint(scheduled_at=14, waypoint=Waypoint(position=(3, 3), direction=1)),
TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(3, 4), direction=1)),
TrainrunWaypoint(scheduled_at=16, waypoint=Waypoint(position=(3, 5), direction=1)),
TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(3, 6), direction=1)),
TrainrunWaypoint(scheduled_at=18, waypoint=Waypoint(position=(3, 7), direction=1)),
TrainrunWaypoint(scheduled_at=19, waypoint=Waypoint(position=(3, 8), direction=1)),
TrainrunWaypoint(scheduled_at=20, waypoint=Waypoint(position=(3, 8), direction=5))],
1: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 8), direction=3)),
TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 7), direction=3)),
TrainrunWaypoint(scheduled_at=5, waypoint=Waypoint(position=(3, 6), direction=3)),
TrainrunWaypoint(scheduled_at=7, waypoint=Waypoint(position=(3, 5), direction=3)),
TrainrunWaypoint(scheduled_at=9, waypoint=Waypoint(position=(3, 4), direction=3)),
TrainrunWaypoint(scheduled_at=11, waypoint=Waypoint(position=(3, 3), direction=3)),
TrainrunWaypoint(scheduled_at=13, waypoint=Waypoint(position=(2, 3), direction=0)),
TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(1, 3), direction=0)),
TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]}
expected_action_plan = [[
# take action to enter the grid
ActionPlanElement(0, RailEnvActions.MOVE_FORWARD),
......
......@@ -11,7 +11,7 @@ from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPred
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_train_run_data_structures import Waypoint
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail
......
......@@ -8,7 +8,7 @@ from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_env_shortest_paths import get_shortest_paths, get_k_shortest_paths
from flatland.envs.rail_env_utils import load_flatland_environment_from_file
from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.rail_train_run_data_structures import Waypoint
from flatland.envs.rail_trainrun_data_structures import Waypoint
from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool
from flatland.utils.simple_rail import make_disconnected_simple_rail, make_simple_rail_with_alternatives
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment