diff --git a/flatland/action_plan/action_plan.py b/flatland/action_plan/action_plan.py index e0616a444dc41121804b9154e1bcdc3b34341214..65a33fdfb070b6ae7f4fe1396c77825c57eed6ab 100644 --- a/flatland/action_plan/action_plan.py +++ b/flatland/action_plan/action_plan.py @@ -6,7 +6,7 @@ import numpy as np from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_env_shortest_paths import get_action_for_move -from flatland.envs.rail_train_run_data_structures import Waypoint, Trainrun, TrainrunWaypoint +from flatland.envs.rail_trainrun_data_structures import Waypoint, Trainrun, TrainrunWaypoint from flatland.utils.rendertools import RenderTool, AgentRenderVariant # ---- ActionPlan --------------- @@ -28,14 +28,14 @@ class ControllerFromTrainruns(): def __init__(self, env: RailEnv, - train_run_dict: Dict[int, Trainrun]): + trainrun_dict: Dict[int, Trainrun]): self.env: RailEnv = env - self.train_run_dict: Dict[int, Trainrun] = train_run_dict + self.trainrun_dict: Dict[int, Trainrun] = trainrun_dict self.action_plan: ActionPlanDict = [self._create_action_plan_for_agent(agent_id, chosen_path) - for agent_id, chosen_path in train_run_dict.items()] + for agent_id, chosen_path in trainrun_dict.items()] - def get_way_point_before_or_at_step(self, agent_id: int, step: int) -> Waypoint: + def get_waypoint_before_or_at_step(self, agent_id: int, step: int) -> Waypoint: """ Get the way point point from which the current position can be extracted. @@ -49,26 +49,26 @@ class ControllerFromTrainruns(): WalkingElement """ - train_run = self.train_run_dict[agent_id] - entry_time_step = train_run[0].scheduled_at + trainrun = self.trainrun_dict[agent_id] + entry_time_step = trainrun[0].scheduled_at # the agent has no position before and at choosing to enter the grid (one tick elapses before the agent enters the grid) if step <= entry_time_step: return Waypoint(position=None, direction=self.env.agents[agent_id].initial_direction) # the agent has no position as soon as the target is reached - exit_time_step = train_run[-1].scheduled_at + exit_time_step = trainrun[-1].scheduled_at if step >= exit_time_step: # agent loses position as soon as target cell is reached - return Waypoint(position=None, direction=train_run[-1].way_point.direction) + return Waypoint(position=None, direction=trainrun[-1].waypoint.direction) - way_point = None - for train_run_way_point in train_run: - if step < train_run_way_point.scheduled_at: - return way_point - if step >= train_run_way_point.scheduled_at: - way_point = train_run_way_point.way_point - assert way_point is not None - return way_point + waypoint = None + for trainrun_waypoint in trainrun: + if step < trainrun_waypoint.scheduled_at: + return waypoint + if step >= trainrun_waypoint.scheduled_at: + waypoint = trainrun_waypoint.waypoint + assert waypoint is not None + return waypoint def get_action_at_step(self, agent_id: int, current_step: int) -> Optional[RailEnvActions]: """ @@ -142,26 +142,26 @@ class ControllerFromTrainruns(): assert expected_action_plan == actual_action_plan, \ "expected {}, found {}".format(expected_action_plan, actual_action_plan) - def _create_action_plan_for_agent(self, agent_id, train_run) -> ActionPlan: + def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan: action_plan = [] agent = self.env.agents[agent_id] minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed'])) - for path_loop, train_run_way_point in enumerate(train_run): - train_run_way_point: TrainrunWaypoint = train_run_way_point + for path_loop, trainrun_waypoint in enumerate(trainrun): + trainrun_waypoint: TrainrunWaypoint = trainrun_waypoint - position = train_run_way_point.way_point.position + position = trainrun_waypoint.waypoint.position if Vec2d.is_equal(agent.target, position): break - next_train_run_way_point: TrainrunWaypoint = train_run[path_loop + 1] - next_position = next_train_run_way_point.way_point.position + next_trainrun_waypoint: TrainrunWaypoint = trainrun[path_loop + 1] + next_position = next_trainrun_waypoint.waypoint.position if path_loop == 0: self._add_action_plan_elements_for_first_path_element_of_agent( action_plan, - train_run_way_point, - next_train_run_way_point, + trainrun_waypoint, + next_trainrun_waypoint, minimum_cell_time ) continue @@ -171,30 +171,30 @@ class ControllerFromTrainruns(): self._add_action_plan_elements_for_current_path_element( action_plan, minimum_cell_time, - train_run_way_point, - next_train_run_way_point) + trainrun_waypoint, + next_trainrun_waypoint) # add a final element if just_before_target: self._add_action_plan_elements_for_target_at_path_element_just_before_target( action_plan, minimum_cell_time, - train_run_way_point, - next_train_run_way_point) + trainrun_waypoint, + next_trainrun_waypoint) return action_plan def _add_action_plan_elements_for_current_path_element(self, action_plan: ActionPlan, minimum_cell_time: int, - train_run_way_point: TrainrunWaypoint, - next_train_run_way_point: TrainrunWaypoint): - scheduled_at = train_run_way_point.scheduled_at - next_entry_value = next_train_run_way_point.scheduled_at - - position = train_run_way_point.way_point.position - direction = train_run_way_point.way_point.direction - next_position = next_train_run_way_point.way_point.position - next_direction = next_train_run_way_point.way_point.direction + trainrun_waypoint: TrainrunWaypoint, + next_trainrun_waypoint: TrainrunWaypoint): + scheduled_at = trainrun_waypoint.scheduled_at + next_entry_value = next_trainrun_waypoint.scheduled_at + + position = trainrun_waypoint.waypoint.position + direction = trainrun_waypoint.waypoint.direction + next_position = next_trainrun_waypoint.waypoint.position + next_direction = next_trainrun_waypoint.waypoint.direction next_action = get_action_for_move(position, direction, next_position, @@ -217,23 +217,23 @@ class ControllerFromTrainruns(): def _add_action_plan_elements_for_target_at_path_element_just_before_target(self, action_plan: ActionPlan, minimum_cell_time: int, - train_run_way_point: TrainrunWaypoint, - next_train_run_way_point: TrainrunWaypoint): - scheduled_at = train_run_way_point.scheduled_at + trainrun_waypoint: TrainrunWaypoint, + next_trainrun_waypoint: TrainrunWaypoint): + scheduled_at = trainrun_waypoint.scheduled_at action = ActionPlanElement(scheduled_at + minimum_cell_time, RailEnvActions.STOP_MOVING) action_plan.append(action) def _add_action_plan_elements_for_first_path_element_of_agent(self, action_plan: ActionPlan, - train_run_way_point: TrainrunWaypoint, - next_train_run_way_point: TrainrunWaypoint, + trainrun_waypoint: TrainrunWaypoint, + next_trainrun_waypoint: TrainrunWaypoint, minimum_cell_time: int): - scheduled_at = train_run_way_point.scheduled_at - position = train_run_way_point.way_point.position - direction = train_run_way_point.way_point.direction - next_position = next_train_run_way_point.way_point.position - next_direction = next_train_run_way_point.way_point.direction + scheduled_at = trainrun_waypoint.scheduled_at + position = trainrun_waypoint.waypoint.position + direction = trainrun_waypoint.waypoint.direction + next_position = next_trainrun_waypoint.waypoint.position + next_direction = next_trainrun_waypoint.waypoint.direction # add intial do nothing if we do not enter immediately, actually not necessary if scheduled_at > 0: @@ -251,12 +251,12 @@ class ControllerFromTrainruns(): self.env.rail) # if the agent is blocked in the cell, we have to call stop upon entering! - if next_train_run_way_point.scheduled_at > scheduled_at + 1 + minimum_cell_time: + if next_trainrun_waypoint.scheduled_at > scheduled_at + 1 + minimum_cell_time: action = ActionPlanElement(scheduled_at + 1, RailEnvActions.STOP_MOVING) action_plan.append(action) # execute the action exactly minimum_cell_time before the entry into the next cell - action = ActionPlanElement(next_train_run_way_point.scheduled_at - minimum_cell_time, next_action) + action = ActionPlanElement(next_trainrun_waypoint.scheduled_at - minimum_cell_time, next_action) action_plan.append(action) @@ -277,10 +277,10 @@ class ControllerFromTrainrunsReplayer(): i = 0 while not env.dones['__all__'] and i <= env._max_episode_steps: for agent_id, agent in enumerate(env.agents): - way_point: Waypoint = ctl.get_way_point_before_or_at_step(agent_id, i) - assert agent.position == way_point.position, \ + waypoint: Waypoint = ctl.get_waypoint_before_or_at_step(agent_id, i) + assert agent.position == waypoint.position, \ "before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position, - way_point.position) + waypoint.position) actions = ctl.act(i) print("actions for {}: {}".format(i, actions)) diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py index 772ad7a79f42e779e356381c13b7eeac93dd21be..6bfb4bb558f135388b41ee2b830f74984e62eddc 100644 --- a/flatland/envs/rail_env_shortest_paths.py +++ b/flatland/envs/rail_env_shortest_paths.py @@ -10,7 +10,7 @@ from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.distance_map import DistanceMap from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions, RailEnv -from flatland.envs.rail_train_run_data_structures import Waypoint +from flatland.envs.rail_trainrun_data_structures import Waypoint from flatland.utils.ordered_set import OrderedSet diff --git a/flatland/envs/rail_train_run_data_structures.py b/flatland/envs/rail_trainrun_data_structures.py similarity index 96% rename from flatland/envs/rail_train_run_data_structures.py rename to flatland/envs/rail_trainrun_data_structures.py index 3dddb5b0ea45f405f7e22527b44479581304bb35..5de955b3e78ad4be7145d1f514ca34eca9557894 100644 --- a/flatland/envs/rail_train_run_data_structures.py +++ b/flatland/envs/rail_trainrun_data_structures.py @@ -12,7 +12,7 @@ Waypoint = NamedTuple('Waypoint', [('position', Tuple[int, int]), ('direction', # The terminology follows https://github.com/crowdAI/train-schedule-optimisation-challenge-starter-kit/blob/master/documentation/output_data_model.md TrainrunWaypoint = NamedTuple('TrainrunWaypoint', [ ('scheduled_at', int), - ('way_point', Waypoint) + ('waypoint', Waypoint) ]) # A train run is the list of an agent's way points and their scheduled time Trainrun = List[TrainrunWaypoint] diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py index 914c3d8d8e8ebb1361d0baca31dbd36c72322593..f7e087d54111f27af8d0668122893ab7dad20644 100644 --- a/tests/test_action_plan.py +++ b/tests/test_action_plan.py @@ -4,7 +4,7 @@ from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.rail_train_run_data_structures import Waypoint +from flatland.envs.rail_trainrun_data_structures import Waypoint from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.simple_rail import make_simple_rail @@ -32,25 +32,25 @@ def test_action_plan(rendering: bool = False): for handle, agent in enumerate(env.agents): print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target)) - chosen_path_dict = {0: [TrainrunWaypoint(scheduled_at=0, way_point=Waypoint(position=(3, 0), direction=3)), - TrainrunWaypoint(scheduled_at=2, way_point=Waypoint(position=(3, 1), direction=1)), - TrainrunWaypoint(scheduled_at=3, way_point=Waypoint(position=(3, 2), direction=1)), - TrainrunWaypoint(scheduled_at=14, way_point=Waypoint(position=(3, 3), direction=1)), - TrainrunWaypoint(scheduled_at=15, way_point=Waypoint(position=(3, 4), direction=1)), - TrainrunWaypoint(scheduled_at=16, way_point=Waypoint(position=(3, 5), direction=1)), - TrainrunWaypoint(scheduled_at=17, way_point=Waypoint(position=(3, 6), direction=1)), - TrainrunWaypoint(scheduled_at=18, way_point=Waypoint(position=(3, 7), direction=1)), - TrainrunWaypoint(scheduled_at=19, way_point=Waypoint(position=(3, 8), direction=1)), - TrainrunWaypoint(scheduled_at=20, way_point=Waypoint(position=(3, 8), direction=5))], - 1: [TrainrunWaypoint(scheduled_at=0, way_point=Waypoint(position=(3, 8), direction=3)), - TrainrunWaypoint(scheduled_at=3, way_point=Waypoint(position=(3, 7), direction=3)), - TrainrunWaypoint(scheduled_at=5, way_point=Waypoint(position=(3, 6), direction=3)), - TrainrunWaypoint(scheduled_at=7, way_point=Waypoint(position=(3, 5), direction=3)), - TrainrunWaypoint(scheduled_at=9, way_point=Waypoint(position=(3, 4), direction=3)), - TrainrunWaypoint(scheduled_at=11, way_point=Waypoint(position=(3, 3), direction=3)), - TrainrunWaypoint(scheduled_at=13, way_point=Waypoint(position=(2, 3), direction=0)), - TrainrunWaypoint(scheduled_at=15, way_point=Waypoint(position=(1, 3), direction=0)), - TrainrunWaypoint(scheduled_at=17, way_point=Waypoint(position=(0, 3), direction=0))]} + chosen_path_dict = {0: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 0), direction=3)), + TrainrunWaypoint(scheduled_at=2, waypoint=Waypoint(position=(3, 1), direction=1)), + TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 2), direction=1)), + TrainrunWaypoint(scheduled_at=14, waypoint=Waypoint(position=(3, 3), direction=1)), + TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(3, 4), direction=1)), + TrainrunWaypoint(scheduled_at=16, waypoint=Waypoint(position=(3, 5), direction=1)), + TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(3, 6), direction=1)), + TrainrunWaypoint(scheduled_at=18, waypoint=Waypoint(position=(3, 7), direction=1)), + TrainrunWaypoint(scheduled_at=19, waypoint=Waypoint(position=(3, 8), direction=1)), + TrainrunWaypoint(scheduled_at=20, waypoint=Waypoint(position=(3, 8), direction=5))], + 1: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 8), direction=3)), + TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 7), direction=3)), + TrainrunWaypoint(scheduled_at=5, waypoint=Waypoint(position=(3, 6), direction=3)), + TrainrunWaypoint(scheduled_at=7, waypoint=Waypoint(position=(3, 5), direction=3)), + TrainrunWaypoint(scheduled_at=9, waypoint=Waypoint(position=(3, 4), direction=3)), + TrainrunWaypoint(scheduled_at=11, waypoint=Waypoint(position=(3, 3), direction=3)), + TrainrunWaypoint(scheduled_at=13, waypoint=Waypoint(position=(2, 3), direction=0)), + TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(1, 3), direction=0)), + TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]} expected_action_plan = [[ # take action to enter the grid ActionPlanElement(0, RailEnvActions.MOVE_FORWARD), diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index 30ad57f72235ad100fa4d1983ca59b1a5a3c51ea..3aa9d8a54e9e8db628f4fbc1c9a0e8db4f1b0305 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -11,7 +11,7 @@ from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPred from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env_shortest_paths import get_shortest_paths from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.rail_train_run_data_structures import Waypoint +from flatland.envs.rail_trainrun_data_structures import Waypoint from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make_invalid_simple_rail diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index 1143c20b5bc06102d2f978e14897a9b06728ecfd..effa1f866cda1c68d231e74cff7829e212cddba1 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -8,7 +8,7 @@ from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env_shortest_paths import get_shortest_paths, get_k_shortest_paths from flatland.envs.rail_env_utils import load_flatland_environment_from_file from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.rail_train_run_data_structures import Waypoint +from flatland.envs.rail_trainrun_data_structures import Waypoint from flatland.envs.schedule_generators import random_schedule_generator from flatland.utils.rendertools import RenderTool from flatland.utils.simple_rail import make_disconnected_simple_rail, make_simple_rail_with_alternatives