From d4e6af1c8f7f2b47e873e1e6c6c3169d382374d1 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Mon, 11 Nov 2019 19:20:25 -0500
Subject: [PATCH] SIM-119 refactoring ActionPlan

---
 flatland/action_plan/__init__.py    |   0
 flatland/action_plan/action_plan.py | 304 ++++++++++++++++++++++++++++
 tests/test_action_plan.py           |  93 +++++++++
 3 files changed, 397 insertions(+)
 create mode 100644 flatland/action_plan/__init__.py
 create mode 100644 flatland/action_plan/action_plan.py
 create mode 100644 tests/test_action_plan.py

diff --git a/flatland/action_plan/__init__.py b/flatland/action_plan/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/flatland/action_plan/action_plan.py b/flatland/action_plan/action_plan.py
new file mode 100644
index 00000000..017b9551
--- /dev/null
+++ b/flatland/action_plan/action_plan.py
@@ -0,0 +1,304 @@
+import pprint
+from typing import Dict, List, Optional, NamedTuple
+
+import numpy as np
+from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_env_shortest_paths import WalkingElement, get_action_for_move
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+
+
+#---- Input Data Structures (graph representation)  ---------------------------------------------
+#  A cell pin represents the one of the four pins in which the cell at row,column may be entered.
+CellPin = NamedTuple('CellPin', [('r', int), ('c', int), ('d', int)])
+
+# A path schedule element represents the entry time of agent at a cell pin.
+PathScheduleElement = NamedTuple('PathScheduleElement', [
+    ('scheduled_at', int),
+    ('cell_pin', CellPin)
+])
+# A path schedule is the list of an agent's cell pin entries
+PathSchedule = List[PathScheduleElement]
+
+
+#---- Output Data Structures (FLATland representation) ---------------------------------------------
+# An action plan element represents the actions to be taken by an agent at deterministic time steps
+#  plus the position before the action
+ActionPlanElement = NamedTuple('ActionPlanElement', [
+    ('scheduled_at', int),
+    ('walking_element', WalkingElement)
+])
+# An action plan deterministically represents all the actions to be taken by an agent
+#  plus its position before the actions are taken
+ActionPlan = Dict[int, List[ActionPlanElement]]
+
+
+
+class ActionPlanReplayer():
+    """Allows to deduce an `ActionPlan` from the agents' `PathSchedule` and
+    to be replayed/verified in a FLATland env without malfunction."""
+
+    pp = pprint.PrettyPrinter(indent=4)
+
+    def __init__(self,
+                 env: RailEnv,
+                 chosen_path_dict: Dict[int, PathSchedule]):
+
+        self.env = env
+        self.action_plan = [[] for _ in range(self.env.get_num_agents())]
+
+        for agent_id, chosen_path in chosen_path_dict.items():
+            self._add_aggent_to_action_plan(self.action_plan, agent_id, chosen_path)
+
+    def get_walking_element_before_or_at_step(self, agent_id: int, step: int) -> WalkingElement:
+        """
+        Get the walking element from which the current position can be extracted.
+
+        Parameters
+        ----------
+        agent_id
+        step
+
+        Returns
+        -------
+        WalkingElement
+
+        """
+        walking_element = None
+        for action in self.action_plan[agent_id]:
+            if step < action.scheduled_at:
+                return walking_element
+            if step >= action.scheduled_at:
+                walking_element = action.walking_element
+        assert walking_element is not None
+        return walking_element
+
+    def get_action_at_step(self, agent_id: int, current_step: int) -> Optional[RailEnvActions]:
+        """
+        Get the current action if any is defined in the `ActionPlan`.
+
+        Parameters
+        ----------
+        agent_id
+        current_step
+
+        Returns
+        -------
+        WalkingElement, optional
+
+        """
+        for action_plan_step in self.action_plan[agent_id]:
+            action_plan_step: ActionPlanElement = action_plan_step
+            scheduled_at = action_plan_step.scheduled_at
+            walking_element: WalkingElement = action_plan_step.walking_element
+            if scheduled_at > current_step:
+                return None
+            elif np.isclose(current_step, scheduled_at):
+                return walking_element.next_action
+        return None
+
+    def get_action_dict_for_step_replay(self, current_step: int) -> Dict[int, RailEnvActions]:
+        """
+        Get the action dictionary to be replayed at the current step.
+
+        Parameters
+        ----------
+        current_step: int
+
+        Returns
+        -------
+        Dict[int, RailEnvActions]
+
+        """
+        action_dict = {}
+        for agent_id, agent in enumerate(self.env.agents):
+            action: Optional[RailEnvActions] = self.get_action_at_step(agent_id, current_step)
+            if action is not None:
+                action_dict[agent_id] = action
+        return action_dict
+
+    def replay_verify(self, MAX_EPISODE_STEPS: int, env: RailEnv, rendering: bool):
+        """Replays this deterministic `ActionPlan` and verifies whether it is feasible."""
+        if rendering:
+            renderer = RenderTool(env, gl="PILSVG",
+                                  agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
+                                  show_debug=True,
+                                  clear_debug_text=True,
+                                  screen_height=1000,
+                                  screen_width=1000)
+            renderer.render_env(show=True, show_observations=False, show_predictions=False)
+        i = 0
+        while not env.dones['__all__'] and i <= MAX_EPISODE_STEPS:
+            for agent_id, agent in enumerate(env.agents):
+                walking_element: WalkingElement = self.get_walking_element_before_or_at_step(agent_id, i)
+                assert agent.position == walking_element.position, \
+                    "before {}, agent {} at {}, expected {}".format(i, agent_id, agent.position,
+                                                                    walking_element.position)
+            actions = self.get_action_dict_for_step_replay(i)
+            print("actions for {}: {}".format(i, actions))
+
+            obs, all_rewards, done, _ = env.step(actions)
+
+            if rendering:
+                renderer.render_env(show=True, show_observations=False, show_predictions=False)
+
+            i += 1
+
+    def print_action_plan(self):
+        for agent_id, plan in enumerate(self.action_plan):
+            print("{}: ".format(agent_id))
+            for step in plan:
+                print("  {}".format(step))
+
+    @staticmethod
+    def compare_action_plans(expected_action_plan: ActionPlan, actual_action_plan: ActionPlan):
+        assert len(expected_action_plan) == len(actual_action_plan)
+        for k in range(len(expected_action_plan)):
+            assert len(expected_action_plan[k]) == len(actual_action_plan[k]), \
+                "len for agent {} should be the same.\n\n  expected ({}) = {}\n\n  actual ({}) = {}".format(
+                    k,
+                    len(expected_action_plan[k]),
+                    ActionPlanReplayer.pp.pformat(expected_action_plan[k]),
+                    len(actual_action_plan[k]),
+                    ActionPlanReplayer.pp.pformat(actual_action_plan[k]))
+            for i in range(len(expected_action_plan[k])):
+                assert expected_action_plan[k][i] == actual_action_plan[k][i], \
+                    "not the same at agent {} at step {}\n\n  expected = {}\n\n  actual = {}".format(
+                        k, i,
+                        ActionPlanReplayer.pp.pformat(expected_action_plan[k][i]),
+                        ActionPlanReplayer.pp.pformat(actual_action_plan[k][i]))
+
+    def _add_aggent_to_action_plan(self, action_plan, agent_id, agent_path_new):
+        agent = self.env.agents[agent_id]
+        minimum_cell_time = int(np.ceil(1.0 / agent.speed_data['speed']))
+        for path_loop, path_schedule_element in enumerate(agent_path_new):
+            path_schedule_element: PathScheduleElement = path_schedule_element
+
+            position = (path_schedule_element.cell_pin.r, path_schedule_element.cell_pin.c)
+
+            if Vec2d.is_equal(agent.target, position):
+                break
+
+            next_path_schedule_element: PathScheduleElement = agent_path_new[path_loop + 1]
+            next_position = (next_path_schedule_element.cell_pin.r, next_path_schedule_element.cell_pin.c)
+
+            if path_loop == 0:
+                self._create_action_plan_for_first_path_element_of_agent(
+                    action_plan,
+                    agent_id,
+                    path_schedule_element,
+                    next_path_schedule_element)
+                continue
+
+            just_before_target = Vec2d.is_equal(agent.target, next_position)
+
+            self._create_action_plan_for_current_path_element(
+                action_plan,
+                agent_id,
+                minimum_cell_time,
+                path_schedule_element,
+                next_path_schedule_element)
+
+            # add a final element
+            if just_before_target:
+                self._create_action_plan_for_target_at_path_element_just_before_target(
+                    action_plan,
+                    agent_id,
+                    minimum_cell_time,
+                    path_schedule_element,
+                    next_path_schedule_element)
+
+    def _create_action_plan_for_current_path_element(self,
+                                                     action_plan: ActionPlan,
+                                                     agent_id: int,
+                                                     minimum_cell_time: int,
+                                                     path_schedule_element: PathScheduleElement,
+                                                     next_path_schedule_element: PathScheduleElement):
+        scheduled_at = path_schedule_element.scheduled_at
+        next_entry_value = next_path_schedule_element.scheduled_at
+
+        position = (path_schedule_element.cell_pin.r, path_schedule_element.cell_pin.c)
+        direction = path_schedule_element.cell_pin.d
+        next_position = next_path_schedule_element.cell_pin.r, next_path_schedule_element.cell_pin.c
+        next_direction = next_path_schedule_element.cell_pin.d
+        next_action = get_action_for_move(position,
+                                          direction,
+                                          next_position,
+                                          next_direction,
+                                          self.env.rail)
+
+        walking_element = WalkingElement(position, direction, next_action)
+
+        # if the next entry is later than minimum_cell_time, then stop here and
+        # move minimum_cell_time before the exit
+        # we have to do this since agents in the RailEnv are processed in the step() in the order of their handle
+        if next_entry_value > scheduled_at + minimum_cell_time:
+            action = ActionPlanElement(scheduled_at,
+                                       WalkingElement(
+                                           position=position,
+                                           direction=direction,
+                                           next_action=RailEnvActions.STOP_MOVING))
+            action_plan[agent_id].append(action)
+
+            action = ActionPlanElement(next_entry_value - minimum_cell_time, walking_element)
+            action_plan[agent_id].append(action)
+        else:
+            action = ActionPlanElement(scheduled_at, walking_element)
+            action_plan[agent_id].append(action)
+
+    def _create_action_plan_for_target_at_path_element_just_before_target(self,
+                                                                          action_plan: ActionPlan,
+                                                                          agent_id: int,
+                                                                          minimum_cell_time: int,
+                                                                          path_schedule_element: PathScheduleElement,
+                                                                          next_path_schedule_element: PathScheduleElement):
+        scheduled_at = path_schedule_element.scheduled_at
+        next_path_schedule_element.cell_pin
+
+        action = ActionPlanElement(scheduled_at + minimum_cell_time,
+                                   WalkingElement(
+                                       position=None,
+                                       direction=next_path_schedule_element.cell_pin.d,
+                                       next_action=RailEnvActions.STOP_MOVING))
+        action_plan[agent_id].append(action)
+
+    def _create_action_plan_for_first_path_element_of_agent(self,
+                                                            action_plan: ActionPlan,
+                                                            agent_id: int,
+                                                            path_schedule_element: PathScheduleElement,
+                                                            next_path_schedule_element: PathScheduleElement):
+        scheduled_at = path_schedule_element.scheduled_at
+        position = (path_schedule_element.cell_pin.r, path_schedule_element.cell_pin.c)
+        direction = path_schedule_element.cell_pin.d
+        next_position = next_path_schedule_element.cell_pin.r, next_path_schedule_element.cell_pin.c
+        next_direction = next_path_schedule_element.cell_pin.d
+
+        # add intial do nothing if we do not enter immediately
+        if scheduled_at > 0:
+            action = ActionPlanElement(0,
+                                       WalkingElement(
+                                           position=None,
+                                           direction=direction,
+                                           next_action=RailEnvActions.DO_NOTHING))
+            action_plan[agent_id].append(action)
+        # add action to enter the grid
+        action = ActionPlanElement(scheduled_at,
+                                   WalkingElement(
+                                       position=None,
+                                       direction=direction,
+                                       next_action=RailEnvActions.MOVE_FORWARD))
+        action_plan[agent_id].append(action)
+
+        next_action = get_action_for_move(position,
+                                          direction,
+                                          next_position,
+                                          next_direction,
+                                          self.env.rail)
+
+        # now, we have a position need to perform the action
+        action = ActionPlanElement(scheduled_at + 1,
+                                   WalkingElement(
+                                       position=position,
+                                       direction=direction,
+                                       next_action=next_action))
+        action_plan[agent_id].append(action)
diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py
new file mode 100644
index 00000000..876a7a87
--- /dev/null
+++ b/tests/test_action_plan.py
@@ -0,0 +1,93 @@
+from flatland.action_plan.action_plan import PathScheduleElement, CellPin, ActionPlanReplayer
+from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_env_shortest_paths import WalkingElement
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
+from flatland.utils.simple_rail import make_simple_rail
+
+
+
+def test_action_plan(rendering: bool = False):
+    """Tests ActionPlanReplayer: does action plan generation and replay work as expected."""
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(seed=77),
+                  number_of_agents=2,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  remove_agents_at_target=True
+                  )
+    env.reset()
+    env.agents[0].initial_position = (3, 0)
+    env.agents[0].target = (3, 8)
+    env.agents[0].initial_direction = Grid4TransitionsEnum.WEST
+    env.agents[1].initial_position = (3, 8)
+    env.agents[1].initial_direction = Grid4TransitionsEnum.WEST
+    env.agents[1].target = (0, 3)
+    env.agents[1].speed_data['speed'] = 0.5  # two
+    env.reset(False, False, False)
+    for handle, agent in enumerate(env.agents):
+        print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target))
+
+    chosen_path_dict = {0: [PathScheduleElement(scheduled_at=0, cell_pin=CellPin(r=3, c=0, d=3)),
+                            PathScheduleElement(scheduled_at=2, cell_pin=CellPin(r=3, c=1, d=1)),
+                            PathScheduleElement(scheduled_at=3, cell_pin=CellPin(r=3, c=2, d=1)),
+                            PathScheduleElement(scheduled_at=14, cell_pin=CellPin(r=3, c=3, d=1)),
+                            PathScheduleElement(scheduled_at=15, cell_pin=CellPin(r=3, c=4, d=1)),
+                            PathScheduleElement(scheduled_at=16, cell_pin=CellPin(r=3, c=5, d=1)),
+                            PathScheduleElement(scheduled_at=17, cell_pin=CellPin(r=3, c=6, d=1)),
+                            PathScheduleElement(scheduled_at=18, cell_pin=CellPin(r=3, c=7, d=1)),
+                            PathScheduleElement(scheduled_at=19, cell_pin=CellPin(r=3, c=8, d=1)),
+                            PathScheduleElement(scheduled_at=20, cell_pin=CellPin(r=3, c=8, d=5))],
+                        1: [PathScheduleElement(scheduled_at=0, cell_pin=CellPin(r=3, c=8, d=3)),
+                            PathScheduleElement(scheduled_at=3, cell_pin=CellPin(r=3, c=7, d=3)),
+                            PathScheduleElement(scheduled_at=5, cell_pin=CellPin(r=3, c=6, d=3)),
+                            PathScheduleElement(scheduled_at=7, cell_pin=CellPin(r=3, c=5, d=3)),
+                            PathScheduleElement(scheduled_at=9, cell_pin=CellPin(r=3, c=4, d=3)),
+                            PathScheduleElement(scheduled_at=11, cell_pin=CellPin(r=3, c=3, d=3)),
+                            PathScheduleElement(scheduled_at=13, cell_pin=CellPin(r=2, c=3, d=0)),
+                            PathScheduleElement(scheduled_at=15, cell_pin=CellPin(r=1, c=3, d=0)),
+                            PathScheduleElement(scheduled_at=17, cell_pin=CellPin(r=0, c=3, d=0)),
+                            PathScheduleElement(scheduled_at=18, cell_pin=CellPin(r=0, c=3, d=5))]}
+    expected_action_plan = [[
+        # take action to enter the grid
+        (0, WalkingElement(position=None, direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
+        # take action to enter the cell properly
+        (1, WalkingElement(position=(3, 0), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
+        (2, WalkingElement(position=(3, 1), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
+        (3, WalkingElement(position=(3, 2), direction=1, next_action=RailEnvActions.STOP_MOVING)),
+        (13, WalkingElement(position=(3, 2), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
+        (14, WalkingElement(position=(3, 3), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
+        (15, WalkingElement(position=(3, 4), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
+        (16, WalkingElement(position=(3, 5), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
+        (17, WalkingElement(position=(3, 6), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
+        (18, WalkingElement(position=(3, 7), direction=1, next_action=RailEnvActions.MOVE_FORWARD)),
+        (19, WalkingElement(position=None, direction=1, next_action=RailEnvActions.STOP_MOVING))
+
+    ], [
+        (0, WalkingElement(position=None, direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
+        (1, WalkingElement(position=(3, 8), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
+        (3, WalkingElement(position=(3, 7), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
+        (5, WalkingElement(position=(3, 6), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
+        (7, WalkingElement(position=(3, 5), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
+        (9, WalkingElement(position=(3, 4), direction=3, next_action=RailEnvActions.MOVE_FORWARD)),
+        (11, WalkingElement(position=(3, 3), direction=3, next_action=RailEnvActions.MOVE_RIGHT)),
+        (13, WalkingElement(position=(2, 3), direction=0, next_action=RailEnvActions.MOVE_FORWARD)),
+        (15, WalkingElement(position=(1, 3), direction=0, next_action=RailEnvActions.MOVE_FORWARD)),
+        (17, WalkingElement(position=None, direction=0, next_action=RailEnvActions.STOP_MOVING)),
+
+    ]]
+
+    MAX_EPISODE_STEPS = 50
+
+    actual_action_plan = ActionPlanReplayer(env, chosen_path_dict)
+    actual_action_plan.print_action_plan()
+    ActionPlanReplayer.compare_action_plans(expected_action_plan, actual_action_plan.action_plan)
+    assert actual_action_plan.action_plan == expected_action_plan, \
+        "expected {}, found {}".format(expected_action_plan, actual_action_plan.action_plan)
+
+    actual_action_plan.replay_verify(MAX_EPISODE_STEPS, env, rendering)
-- 
GitLab