From f0149144454e558e05843b263a0ca562dbf48989 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 13 Nov 2019 09:36:50 -0500 Subject: [PATCH] SIM-119 refactoring review from Erik --- flatland/action_plan/action_plan.py | 23 ++++++++++++----------- tests/test_action_plan.py | 10 +++++----- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/flatland/action_plan/action_plan.py b/flatland/action_plan/action_plan.py index 5edad878..32171be1 100644 --- a/flatland/action_plan/action_plan.py +++ b/flatland/action_plan/action_plan.py @@ -22,7 +22,8 @@ ActionPlan = List[ActionPlanElement] ActionPlanDict = Dict[int, ActionPlan] -class DeterministicController(): +class ControllerFromTrainRuns(): +    """Takes train runs, derives the actions from them and re-enacts them.""" pp = pprint.PrettyPrinter(indent=4) def __init__(self, @@ -36,7 +37,7 @@ class DeterministicController(): def get_way_point_before_or_at_step(self, agent_id: int, step: int) -> WayPoint: """ -        Get the walking element from which the current position can be extracted. +        Get the way point from which the current position can be extracted. 
Parameters ---------- @@ -126,23 +127,23 @@ class DeterministicController(): "len for agent {} should be the same.\n\n expected ({}) = {}\n\n actual ({}) = {}".format( k, len(expected_action_plan[k]), - DeterministicController.pp.pformat(expected_action_plan[k]), + ControllerFromTrainRuns.pp.pformat(expected_action_plan[k]), len(actual_action_plan[k]), - DeterministicController.pp.pformat(actual_action_plan[k])) + ControllerFromTrainRuns.pp.pformat(actual_action_plan[k])) for i in range(len(expected_action_plan[k])): assert expected_action_plan[k][i] == actual_action_plan[k][i], \ "not the same at agent {} at step {}\n\n expected = {}\n\n actual = {}".format( k, i, - DeterministicController.pp.pformat(expected_action_plan[k][i]), - DeterministicController.pp.pformat(actual_action_plan[k][i])) + ControllerFromTrainRuns.pp.pformat(expected_action_plan[k][i]), + ControllerFromTrainRuns.pp.pformat(actual_action_plan[k][i])) assert expected_action_plan == actual_action_plan, \ "expected {}, found {}".format(expected_action_plan, actual_action_plan) - def _create_action_plan_for_agent(self, agent_id, agent_path_new) -> ActionPlan: + def _create_action_plan_for_agent(self, agent_id, train_run) -> ActionPlan: action_plan = [] agent = self.env.agents[agent_id] minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed'])) - for path_loop, train_run_way_point in enumerate(agent_path_new): + for path_loop, train_run_way_point in enumerate(train_run): train_run_way_point: TrainRunWayPoint = train_run_way_point position = train_run_way_point.way_point.position @@ -150,7 +151,7 @@ class DeterministicController(): if Vec2d.is_equal(agent.target, position): break - next_train_run_way_point: TrainRunWayPoint = agent_path_new[path_loop + 1] + next_train_run_way_point: TrainRunWayPoint = train_run[path_loop + 1] next_position = next_train_run_way_point.way_point.position if path_loop == 0: @@ -247,11 +248,11 @@ class DeterministicController(): action_plan.append(action) -class 
DeterministicControllerReplayer(): +class ControllerFromTrainRunsReplayer(): """Allows to verify a `DeterministicController` by replaying it against a FLATland env without malfunction.""" @staticmethod - def replay_verify(max_episode_steps: int, ctl: DeterministicController, env: RailEnv, rendering: bool): + def replay_verify(max_episode_steps: int, ctl: ControllerFromTrainRuns, env: RailEnv, rendering: bool): """Replays this deterministic `ActionPlan` and verifies whether it is feasible.""" if rendering: renderer = RenderTool(env, gl="PILSVG", diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py index c6176689..2bd175fb 100644 --- a/tests/test_action_plan.py +++ b/tests/test_action_plan.py @@ -1,5 +1,5 @@ -from flatland.action_plan.action_plan import TrainRunWayPoint, DeterministicControllerReplayer, ActionPlanElement, \ - DeterministicController +from flatland.action_plan.action_plan import TrainRunWayPoint, ControllerFromTrainRunsReplayer, ActionPlanElement, \ + ControllerFromTrainRuns from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions @@ -82,7 +82,7 @@ def test_action_plan(rendering: bool = False): MAX_EPISODE_STEPS = 50 - deterministic_controller = DeterministicController(env, chosen_path_dict) + deterministic_controller = ControllerFromTrainRuns(env, chosen_path_dict) deterministic_controller.print_action_plan() - DeterministicController.assert_actions_plans_equal(expected_action_plan, deterministic_controller.action_plan) - DeterministicControllerReplayer.replay_verify(MAX_EPISODE_STEPS, deterministic_controller, env, rendering) + ControllerFromTrainRuns.assert_actions_plans_equal(expected_action_plan, deterministic_controller.action_plan) + ControllerFromTrainRunsReplayer.replay_verify(MAX_EPISODE_STEPS, deterministic_controller, env, rendering) -- GitLab