From 5f24b5fe6891511b663e2eb96cc3955761f624cd Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 12 Nov 2019 20:40:44 -0500 Subject: [PATCH] SIM-119 refactoring review from Erik and Christian B. --- flatland/action_plan/action_plan.py | 156 +++++++++--------- flatland/envs/rail_env_shortest_paths.py | 9 +- .../envs/rail_train_run_data_structures.py | 7 +- tests/test_action_plan.py | 8 +- ...t_flatland_envs_rail_env_shortest_paths.py | 7 +- 5 files changed, 92 insertions(+), 95 deletions(-) diff --git a/flatland/action_plan/action_plan.py b/flatland/action_plan/action_plan.py index 7c103cd5..59c4f1f3 100644 --- a/flatland/action_plan/action_plan.py +++ b/flatland/action_plan/action_plan.py @@ -10,15 +10,16 @@ from flatland.envs.rail_train_run_data_structures import WayPoint, TrainRun, Tra from flatland.utils.rendertools import RenderTool, AgentRenderVariant # ---- ActionPlan --------------- -# represents the actions to be taken by an agent at deterministic time steps -# plus the position before the action +# an action plan element represents the actions to be taken by an agent at the given time step ActionPlanElement = NamedTuple('ActionPlanElement', [ ('scheduled_at', int), ('action', RailEnvActions) ]) -# An action plan deterministically represents all the actions to be taken by an agent -# plus its position before the actions are taken -ActionPlan = Dict[int, List[ActionPlanElement]] +# an action plan gathers all the the actions to be taken by a single agent at the corresponding time steps +ActionPlan = List[ActionPlanElement] + +# An action plan dict gathers all the actions for every agent identified by the dictionary key +ActionPlanDict = Dict[int, ActionPlan] class DeterministicController(): @@ -30,10 +31,8 @@ class DeterministicController(): self.env = env self.train_run_dict: Dict[int, TrainRun] = train_run_dict - self.action_plan = [[] for _ in range(self.env.get_num_agents())] - - for agent_id, chosen_path in train_run_dict.items(): - self._add_aggent_to_action_plan(self.action_plan, agent_id, chosen_path) + self.action_plan = [self._create_action_plan_for_agent(agent_id, chosen_path) + for agent_id, chosen_path in train_run_dict.items()] def get_way_point_before_or_at_step(self, agent_id: int, step: int) -> WayPoint: """ @@ -84,13 +83,13 @@ class DeterministicController(): WalkingElement, optional """ - for action_plan_step in self.action_plan[agent_id]: - action_plan_step: ActionPlanElement = action_plan_step - scheduled_at = action_plan_step.scheduled_at + for action_plan_element in self.action_plan[agent_id]: + action_plan_element: ActionPlanElement = action_plan_element + scheduled_at = action_plan_element.scheduled_at if scheduled_at > current_step: return None - elif np.isclose(current_step, scheduled_at): - return action_plan_step.action + elif current_step == scheduled_at: + return action_plan_element.action return None def act(self, current_step: int) -> Dict[int, RailEnvActions]: @@ -107,7 +106,7 @@ class DeterministicController(): """ action_dict = {} - for agent_id, agent in enumerate(self.env.agents): + for agent_id in range(len(self.env.agents)): action: Optional[RailEnvActions] = self.get_action_at_step(agent_id, current_step) if action is not None: action_dict[agent_id] = action @@ -120,76 +119,76 @@ class DeterministicController(): print(" {}".format(step)) @staticmethod - def compare_action_plans(expected_action_plan: ActionPlan, actual_action_plan: ActionPlan): + def compare_action_plans(expected_action_plan: ActionPlanDict, actual_action_plan: ActionPlanDict): assert len(expected_action_plan) == len(actual_action_plan) for k in range(len(expected_action_plan)): assert len(expected_action_plan[k]) == len(actual_action_plan[k]), \ "len for agent {} should be the same.\n\n expected ({}) = {}\n\n actual ({}) = {}".format( k, len(expected_action_plan[k]), - DeterministicControllerReplayer.pp.pformat(expected_action_plan[k]), + DeterministicController.pp.pformat(expected_action_plan[k]), len(actual_action_plan[k]), - DeterministicControllerReplayer.pp.pformat(actual_action_plan[k])) + DeterministicController.pp.pformat(actual_action_plan[k])) for i in range(len(expected_action_plan[k])): assert expected_action_plan[k][i] == actual_action_plan[k][i], \ "not the same at agent {} at step {}\n\n expected = {}\n\n actual = {}".format( k, i, - DeterministicControllerReplayer.pp.pformat(expected_action_plan[k][i]), - DeterministicControllerReplayer.pp.pformat(actual_action_plan[k][i])) + DeterministicController.pp.pformat(expected_action_plan[k][i]), + DeterministicController.pp.pformat(actual_action_plan[k][i])) + assert expected_action_plan == actual_action_plan, \ + "expected {}, found {}".format(expected_action_plan, actual_action_plan) - def _add_aggent_to_action_plan(self, action_plan, agent_id, agent_path_new): + def _create_action_plan_for_agent(self, agent_id, agent_path_new) -> ActionPlan: + action_plan = [] agent = self.env.agents[agent_id] minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed'])) - for path_loop, path_schedule_element in enumerate(agent_path_new): - path_schedule_element: TrainRunWayPoint = path_schedule_element + for path_loop, train_run_way_point in enumerate(agent_path_new): + train_run_way_point: TrainRunWayPoint = train_run_way_point - position = path_schedule_element.way_point.position + position = train_run_way_point.way_point.position if Vec2d.is_equal(agent.target, position): break - next_path_schedule_element: TrainRunWayPoint = agent_path_new[path_loop + 1] - next_position = next_path_schedule_element.way_point.position + next_train_run_way_point: TrainRunWayPoint = agent_path_new[path_loop + 1] + next_position = next_train_run_way_point.way_point.position if path_loop == 0: - self._create_action_plan_for_first_path_element_of_agent( + self._add_action_plan_elements_for_first_path_element_of_agent( action_plan, - agent_id, - path_schedule_element, - next_path_schedule_element) + train_run_way_point, + next_train_run_way_point) continue just_before_target = Vec2d.is_equal(agent.target, next_position) - self._create_action_plan_for_current_path_element( + self._add_action_plan_elements_for_current_path_element( action_plan, - agent_id, minimum_cell_time, - path_schedule_element, - next_path_schedule_element) + train_run_way_point, + next_train_run_way_point) # add a final element if just_before_target: - self._create_action_plan_for_target_at_path_element_just_before_target( + self._add_action_plan_elements_for_target_at_path_element_just_before_target( action_plan, - agent_id, minimum_cell_time, - path_schedule_element, - next_path_schedule_element) - - def _create_action_plan_for_current_path_element(self, - action_plan: ActionPlan, - agent_id: int, - minimum_cell_time: int, - path_schedule_element: TrainRunWayPoint, - next_path_schedule_element: TrainRunWayPoint): - scheduled_at = path_schedule_element.scheduled_at - next_entry_value = next_path_schedule_element.scheduled_at - - position = path_schedule_element.way_point.position - direction = path_schedule_element.way_point.direction - next_position = next_path_schedule_element.way_point.position - next_direction = next_path_schedule_element.way_point.direction + train_run_way_point, + next_train_run_way_point) + return action_plan + + def _add_action_plan_elements_for_current_path_element(self, + action_plan: ActionPlan, + minimum_cell_time: int, + train_run_way_point: TrainRunWayPoint, + next_train_run_way_point: TrainRunWayPoint): + scheduled_at = train_run_way_point.scheduled_at + next_entry_value = next_train_run_way_point.scheduled_at + + position = train_run_way_point.way_point.position + direction = train_run_way_point.way_point.direction + next_position = next_train_run_way_point.way_point.position + next_direction = next_train_run_way_point.way_point.direction next_action = get_action_for_move(position, direction, next_position, @@ -201,44 +200,41 @@ class DeterministicController(): # we have to do this since agents in the RailEnv are processed in the step() in the order of their handle if next_entry_value > scheduled_at + minimum_cell_time: action = ActionPlanElement(scheduled_at, RailEnvActions.STOP_MOVING) - action_plan[agent_id].append(action) + action_plan.append(action) action = ActionPlanElement(next_entry_value - minimum_cell_time, next_action) - action_plan[agent_id].append(action) + action_plan.append(action) else: action = ActionPlanElement(scheduled_at, next_action) - action_plan[agent_id].append(action) + action_plan.append(action) - def _create_action_plan_for_target_at_path_element_just_before_target(self, - action_plan: ActionPlan, - agent_id: int, - minimum_cell_time: int, - path_schedule_element: TrainRunWayPoint, - next_path_schedule_element: TrainRunWayPoint): - scheduled_at = path_schedule_element.scheduled_at - next_path_schedule_element.way_point + def _add_action_plan_elements_for_target_at_path_element_just_before_target(self, + action_plan: ActionPlan, + minimum_cell_time: int, + train_run_way_point: TrainRunWayPoint, + next_train_run_way_point: TrainRunWayPoint): + scheduled_at = train_run_way_point.scheduled_at action = ActionPlanElement(scheduled_at + minimum_cell_time, RailEnvActions.STOP_MOVING) - action_plan[agent_id].append(action) - - def _create_action_plan_for_first_path_element_of_agent(self, - action_plan: ActionPlan, - agent_id: int, - path_schedule_element: TrainRunWayPoint, - next_path_schedule_element: TrainRunWayPoint): - scheduled_at = path_schedule_element.scheduled_at - position = path_schedule_element.way_point.position - direction = path_schedule_element.way_point.direction - next_position = next_path_schedule_element.way_point.position - next_direction = next_path_schedule_element.way_point.direction + action_plan.append(action) + + def _add_action_plan_elements_for_first_path_element_of_agent(self, + action_plan: ActionPlan, + train_run_way_point: TrainRunWayPoint, + next_train_run_way_point: TrainRunWayPoint): + scheduled_at = train_run_way_point.scheduled_at + position = train_run_way_point.way_point.position + direction = train_run_way_point.way_point.direction + next_position = next_train_run_way_point.way_point.position + next_direction = next_train_run_way_point.way_point.direction # add intial do nothing if we do not enter immediately if scheduled_at > 0: action = ActionPlanElement(0, RailEnvActions.DO_NOTHING) - action_plan[agent_id].append(action) + action_plan.append(action) # add action to enter the grid action = ActionPlanElement(scheduled_at, RailEnvActions.MOVE_FORWARD) - action_plan[agent_id].append(action) + action_plan.append(action) next_action = get_action_for_move(position, direction, @@ -248,14 +244,14 @@ class DeterministicController(): # now, we have a position need to perform the action action = ActionPlanElement(scheduled_at + 1, next_action) - action_plan[agent_id].append(action) + action_plan.append(action) class DeterministicControllerReplayer(): """Allows to verify a `DeterministicController` by replaying it against a FLATland env without malfunction.""" @staticmethod - def replay_verify(MAX_EPISODE_STEPS: int, ctl: DeterministicController, env: RailEnv, rendering: bool): + def replay_verify(max_episode_steps: int, ctl: DeterministicController, env: RailEnv, rendering: bool): """Replays this deterministic `ActionPlan` and verifies whether it is feasible.""" if rendering: renderer = RenderTool(env, gl="PILSVG", @@ -266,7 +262,7 @@ class DeterministicControllerReplayer(): screen_width=1000) renderer.render_env(show=True, show_observations=False, show_predictions=False) i = 0 - while not env.dones['__all__'] and i <= MAX_EPISODE_STEPS: + while not env.dones['__all__'] and i <= max_episode_steps: for agent_id, agent in enumerate(env.agents): way_point: WayPoint = ctl.get_way_point_before_or_at_step(agent_id, i) assert agent.position == way_point.position, \ diff --git a/flatland/envs/rail_env_shortest_paths.py b/flatland/envs/rail_env_shortest_paths.py index ffbd9e81..d5f91cb0 100644 --- a/flatland/envs/rail_env_shortest_paths.py +++ b/flatland/envs/rail_env_shortest_paths.py @@ -137,6 +137,8 @@ def get_action_for_move( """ Get the action (if any) to move from a position and direction to another. + The implementation could probably be more efficient I believe. But given the few calls this has no priority now. + Parameters ---------- agent_position @@ -274,7 +276,6 @@ def get_k_shortest_paths(env: RailEnv, Computes the k shortest paths using modified Dijkstra following pseudo-code https://en.wikipedia.org/wiki/K_shortest_path_routing In contrast to the pseudo-code in wikipedia, we do not a allow for loopy paths. - We add the next_action Parameters ---------- @@ -313,12 +314,12 @@ def get_k_shortest_paths(env: RailEnv, if debug: print("iteration heap={}, shortest_paths={}".format(heap, shortest_paths)) # – let Pu be the shortest cost path in B with cost C - c = np.inf + cost = np.inf pu = None for path in heap: - if len(path) < c: + if len(path) < cost: pu = path - c = len(path) + cost = len(path) u: WayPoint = pu[-1] if debug: print(" looking at pu={}".format(pu)) diff --git a/flatland/envs/rail_train_run_data_structures.py b/flatland/envs/rail_train_run_data_structures.py index 775b183b..aeb34a00 100644 --- a/flatland/envs/rail_train_run_data_structures.py +++ b/flatland/envs/rail_train_run_data_structures.py @@ -1,8 +1,13 @@ from typing import NamedTuple, Tuple, List +# A way point is the entry into a cell defined by +# - the row and column coordinates of the cell entered +# - direction, in which the agent is facing to enter the cell. +# This induces a graph on top of the FLATland cells: +# - four possible way points per cell +# - edges are the possible transitions in the cell. WayPoint = NamedTuple('WayPoint', [('position', Tuple[int, int]), ('direction', int)]) - # A train run is represented by the waypoints traversed and the times of traversal # The terminology follows https://github.com/crowdAI/train-schedule-optimisation-challenge-starter-kit/blob/master/documentation/output_data_model.md TrainRunWayPoint = NamedTuple('TrainRunWayPoint', [ diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py index 91c64d9b..323d653f 100644 --- a/tests/test_action_plan.py +++ b/tests/test_action_plan.py @@ -1,8 +1,7 @@ from flatland.action_plan.action_plan import TrainRunWayPoint, DeterministicControllerReplayer, ActionPlanElement, \ DeterministicController from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.rail_train_run_data_structures import WayPoint @@ -18,7 +17,7 @@ def test_action_plan(rendering: bool = False): rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(seed=77), number_of_agents=2, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=True ) env.reset() @@ -86,7 +85,4 @@ def test_action_plan(rendering: bool = False): deterministic_controller = DeterministicController(env, chosen_path_dict) deterministic_controller.print_action_plan() DeterministicController.compare_action_plans(expected_action_plan, deterministic_controller.action_plan) - assert deterministic_controller.action_plan == expected_action_plan, \ - "expected {}, found {}".format(expected_action_plan, deterministic_controller.action_plan) - DeterministicControllerReplayer.replay_verify(MAX_EPISODE_STEPS, deterministic_controller, env, rendering) diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index 5f17bb3d..ffd26cfb 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -3,8 +3,7 @@ import sys import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum -from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import DummyPredictorForRailEnv +from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env_shortest_paths import get_shortest_paths, get_k_shortest_paths from flatland.envs.rail_env_utils import load_flatland_environment_from_file @@ -20,7 +19,7 @@ def test_get_shortest_paths_unreachable(): env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10))) + obs_builder_object=GlobalObsForRailEnv()) env.reset() # set the initial position @@ -215,7 +214,7 @@ def test_get_k_shortest_paths(rendering=False): rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), + obs_builder_object=GlobalObsForRailEnv(), ) env.reset() -- GitLab