action_plan.py 11.8 KB
Newer Older
u214892's avatar
u214892 committed
1
2
3
4
import pprint
from typing import Dict, List, Optional, NamedTuple

import numpy as np
5

u214892's avatar
u214892 committed
6
7
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.rail_env import RailEnv, RailEnvActions
8
from flatland.envs.rail_env_shortest_paths import get_action_for_move
9
from flatland.envs.rail_trainrun_data_structures import Waypoint, Trainrun, TrainrunWaypoint
u214892's avatar
u214892 committed
10

11
# ---- ActionPlan ---------------
class ActionPlanElement(NamedTuple):
    """A single action to be taken by an agent at a given time step."""

    # time step at which `action` is to be issued
    scheduled_at: int
    # the action to issue at `scheduled_at`
    action: RailEnvActions


# An action plan gathers all the actions to be taken by a single agent at the corresponding time steps.
ActionPlan = List[ActionPlanElement]

# An action plan dict gathers all the actions for every agent identified by the dictionary key = agent_handle.
ActionPlanDict = Dict[int, ActionPlan]
u214892's avatar
u214892 committed
22
23


24
class ControllerFromTrainruns:
    """Takes train runs, derives the actions from it and re-acts them.

    For every agent with a train run, an action plan (a list of
    `ActionPlanElement`) is derived once at construction time. `act()` then
    looks up, per time step, which agents have an action scheduled.

    Attributes
    ----------
    env
        The environment the train runs refer to.
    trainrun_dict
        The train runs, keyed by agent handle.
    action_plan
        The derived action plans, stored as a list indexed by agent handle
        (not as a dict). Handles without a train run hold an empty plan.
    """

    # shared pretty printer used to format assertion messages
    pp = pprint.PrettyPrinter(indent=4)

    def __init__(self,
                 env: RailEnv,
                 trainrun_dict: Dict[int, Trainrun]):
        """
        Derive the action plan of every agent from its train run.

        Parameters
        ----------
        env
            environment whose agents the train runs refer to
        trainrun_dict
            train run per agent handle
        """
        self.env: RailEnv = env
        self.trainrun_dict: Dict[int, Trainrun] = trainrun_dict

        # The plans are looked up by agent handle everywhere else in this class, so
        # store each plan at the index of its handle instead of relying on the
        # iteration order of `trainrun_dict` (which need not be 0, 1, 2, ...).
        number_of_slots = max(trainrun_dict, default=-1) + 1
        action_plan: List[ActionPlan] = [[] for _ in range(number_of_slots)]
        for agent_id, trainrun in trainrun_dict.items():
            action_plan[agent_id] = self._create_action_plan_for_agent(agent_id, trainrun)
        self.action_plan: List[ActionPlan] = action_plan

    def get_waypoint_before_or_at_step(self, agent_id: int, step: int) -> Waypoint:
        """
        Get the way point from which the current position can be extracted.

        Parameters
        ----------
        agent_id
        step

        Returns
        -------
        Waypoint
            The last way point of the agent's train run scheduled at or before `step`.
            Its position is `None` at or before the first scheduled step and at or
            after the last scheduled step of the train run.

        """
        trainrun = self.trainrun_dict[agent_id]
        entry_time_step = trainrun[0].scheduled_at
        # the agent has no position before and at choosing to enter the grid
        # (one tick elapses before the agent enters the grid)
        if step <= entry_time_step:
            return Waypoint(position=None, direction=self.env.agents[agent_id].initial_direction)

        # the agent has no position as soon as the target is reached
        exit_time_step = trainrun[-1].scheduled_at
        if step >= exit_time_step:
            # agent loses position as soon as target cell is reached
            return Waypoint(position=None, direction=trainrun[-1].waypoint.direction)

        # walk along the train run and remember the last way point already entered
        waypoint = None
        for trainrun_waypoint in trainrun:
            if step < trainrun_waypoint.scheduled_at:
                return waypoint
            waypoint = trainrun_waypoint.waypoint
        assert waypoint is not None
        return waypoint

    def get_action_at_step(self, agent_id: int, current_step: int) -> Optional[RailEnvActions]:
        """
        Get the current action if any is defined in the `ActionPlan`.
        ASSUMPTION we assume the env has `remove_agents_at_target=True` and `activate_agents=False`!!

        Parameters
        ----------
        agent_id
        current_step

        Returns
        -------
        RailEnvActions, optional
            The action scheduled exactly at `current_step`, or `None` if there is none
            (including when no plan exists for `agent_id`).

        """
        # agents without a plan have nothing scheduled
        if not 0 <= agent_id < len(self.action_plan):
            return None

        for action_plan_element in self.action_plan[agent_id]:
            scheduled_at = action_plan_element.scheduled_at
            # elements are appended in increasing time order when the plan is created,
            # so nothing later in the plan can match once we are past `current_step`
            if scheduled_at > current_step:
                return None
            if scheduled_at == current_step:
                return action_plan_element.action
        return None

    def act(self, current_step: int) -> Dict[int, RailEnvActions]:
        """
        Get the action dictionary to be replayed at the current step.
        Returns only action where required (no action for done agents or those not at the beginning of the cell).

        ASSUMPTION we assume the env has `remove_agents_at_target=True` and `activate_agents=False`!!

        Parameters
        ----------
        current_step: int

        Returns
        -------
        Dict[int, RailEnvActions]
            The actions scheduled at `current_step`, keyed by agent handle; agents
            without an action at this step are left out.

        """
        action_dict = {}
        for agent_id in range(len(self.env.agents)):
            action: Optional[RailEnvActions] = self.get_action_at_step(agent_id, current_step)
            if action is not None:
                action_dict[agent_id] = action
        return action_dict

    def print_action_plan(self):
        """Pretty-prints `ActionPlanDict` of this `ControllerFromTrainruns`  to stdout."""
        self.__class__.print_action_plan_dict(self.action_plan)

    @staticmethod
    def print_action_plan_dict(action_plan: ActionPlanDict):
        """
        Pretty-prints `ActionPlanDict` to stdout.

        Parameters
        ----------
        action_plan
            either a list of plans indexed by agent handle (the form stored in
            `self.action_plan`) or a mapping from agent handle to plan
        """
        if isinstance(action_plan, dict):
            plans_by_agent = action_plan.items()
        else:
            plans_by_agent = enumerate(action_plan)
        for agent_id, plan in plans_by_agent:
            print("{}: ".format(agent_id))
            for step in plan:
                print("  {}".format(step))

    @staticmethod
    def assert_actions_plans_equal(expected_action_plan: ActionPlanDict, actual_action_plan: ActionPlanDict):
        """
        Assert that two action plans are equal, reporting the first difference found.

        Parameters
        ----------
        expected_action_plan
        actual_action_plan

        Raises
        ------
        AssertionError
            if the number of agents, the number of elements of an agent's plan or
            any single element differs
        """
        assert len(expected_action_plan) == len(actual_action_plan)
        # Access by index `k` (rather than iterating the containers) so that both a
        # list of plans and a dict keyed by 0..n-1 are supported.
        for k in range(len(expected_action_plan)):
            expected_plan = expected_action_plan[k]
            actual_plan = actual_action_plan[k]
            assert len(expected_plan) == len(actual_plan), \
                "len for agent {} should be the same.\n\n  expected ({}) = {}\n\n  actual ({}) = {}".format(
                    k,
                    len(expected_plan),
                    ControllerFromTrainruns.pp.pformat(expected_plan),
                    len(actual_plan),
                    ControllerFromTrainruns.pp.pformat(actual_plan))
            for i, (expected_element, actual_element) in enumerate(zip(expected_plan, actual_plan)):
                assert expected_element == actual_element, \
                    "not the same at agent {} at step {}\n\n  expected = {}\n\n  actual = {}".format(
                        k, i,
                        ControllerFromTrainruns.pp.pformat(expected_element),
                        ControllerFromTrainruns.pp.pformat(actual_element))
        assert expected_action_plan == actual_action_plan, \
            "expected {}, found {}".format(expected_action_plan, actual_action_plan)

    def _create_action_plan_for_agent(self, agent_id, trainrun) -> ActionPlan:
        """
        Derive the action plan of one agent from its train run.

        Parameters
        ----------
        agent_id
            handle of the agent in `self.env.agents`
        trainrun
            sequence of `TrainrunWaypoint` of this agent

        Returns
        -------
        ActionPlan
            the actions of the agent in the order they were derived along the train run
        """
        action_plan = []
        agent = self.env.agents[agent_id]
        # used below as the number of steps between issuing a move and entering the next cell
        minimum_cell_time = agent.speed_counter.max_count
        for path_loop, trainrun_waypoint in enumerate(trainrun):
            position = trainrun_waypoint.waypoint.position

            # nothing left to do once the target cell is reached
            if Vec2d.is_equal(agent.target, position):
                break

            # NOTE: raises IndexError if the train run ends before reaching the agent's target
            next_trainrun_waypoint: TrainrunWaypoint = trainrun[path_loop + 1]
            next_position = next_trainrun_waypoint.waypoint.position

            if path_loop == 0:
                self._add_action_plan_elements_for_first_path_element_of_agent(
                    action_plan,
                    trainrun_waypoint,
                    next_trainrun_waypoint,
                    minimum_cell_time
                )
                continue

            just_before_target = Vec2d.is_equal(agent.target, next_position)

            self._add_action_plan_elements_for_current_path_element(
                action_plan,
                minimum_cell_time,
                trainrun_waypoint,
                next_trainrun_waypoint)

            # add a final element
            if just_before_target:
                self._add_action_plan_elements_for_target_at_path_element_just_before_target(
                    action_plan,
                    minimum_cell_time,
                    trainrun_waypoint,
                    next_trainrun_waypoint)
        return action_plan

    def _add_action_plan_elements_for_current_path_element(self,
                                                           action_plan: ActionPlan,
                                                           minimum_cell_time: int,
                                                           trainrun_waypoint: TrainrunWaypoint,
                                                           next_trainrun_waypoint: TrainrunWaypoint):
        """
        Append the action(s) for moving from `trainrun_waypoint` to `next_trainrun_waypoint`.

        Appends to `action_plan` in place: either the move at the way point's scheduled
        step, or a stop at that step followed by the move `minimum_cell_time` steps
        before the next way point is scheduled.
        """
        scheduled_at = trainrun_waypoint.scheduled_at
        next_entry_value = next_trainrun_waypoint.scheduled_at

        position = trainrun_waypoint.waypoint.position
        direction = trainrun_waypoint.waypoint.direction
        next_position = next_trainrun_waypoint.waypoint.position
        next_direction = next_trainrun_waypoint.waypoint.direction
        next_action = get_action_for_move(position,
                                          direction,
                                          next_position,
                                          next_direction,
                                          self.env.rail)

        # if the next entry is later than minimum_cell_time, then stop here and
        # move minimum_cell_time before the exit
        # we have to do this since agents in the RailEnv are processed in the step() in the order of their handle
        if next_entry_value > scheduled_at + minimum_cell_time:
            action_plan.append(ActionPlanElement(scheduled_at, RailEnvActions.STOP_MOVING))
            action_plan.append(ActionPlanElement(next_entry_value - minimum_cell_time, next_action))
        else:
            action_plan.append(ActionPlanElement(scheduled_at, next_action))

    def _add_action_plan_elements_for_target_at_path_element_just_before_target(self,
                                                                                action_plan: ActionPlan,
                                                                                minimum_cell_time: int,
                                                                                trainrun_waypoint: TrainrunWaypoint,
                                                                                next_trainrun_waypoint: TrainrunWaypoint):
        """
        Append the final stop, scheduled `minimum_cell_time` steps after `trainrun_waypoint`.

        `next_trainrun_waypoint` is not used; it is kept so the signature stays
        parallel to the other `_add_action_plan_elements_*` helpers.
        """
        scheduled_at = trainrun_waypoint.scheduled_at
        action_plan.append(ActionPlanElement(scheduled_at + minimum_cell_time, RailEnvActions.STOP_MOVING))

    def _add_action_plan_elements_for_first_path_element_of_agent(self,
                                                                  action_plan: ActionPlan,
                                                                  trainrun_waypoint: TrainrunWaypoint,
                                                                  next_trainrun_waypoint: TrainrunWaypoint,
                                                                  minimum_cell_time: int):
        """
        Append the actions for the first way point of a train run.

        Appends to `action_plan` in place, in this order: an optional initial
        `DO_NOTHING` at step 0, the `MOVE_FORWARD` to enter the grid, an optional
        `STOP_MOVING` right after entering, and the move towards the next way point.
        """
        scheduled_at = trainrun_waypoint.scheduled_at
        position = trainrun_waypoint.waypoint.position
        direction = trainrun_waypoint.waypoint.direction
        next_position = next_trainrun_waypoint.waypoint.position
        next_direction = next_trainrun_waypoint.waypoint.direction

        # add initial do nothing if we do not enter immediately, actually not necessary
        if scheduled_at > 0:
            action_plan.append(ActionPlanElement(0, RailEnvActions.DO_NOTHING))

        # add action to enter the grid
        action_plan.append(ActionPlanElement(scheduled_at, RailEnvActions.MOVE_FORWARD))

        next_action = get_action_for_move(position,
                                          direction,
                                          next_position,
                                          next_direction,
                                          self.env.rail)

        # if the agent is blocked in the cell, we have to call stop upon entering!
        if next_trainrun_waypoint.scheduled_at > scheduled_at + 1 + minimum_cell_time:
            action_plan.append(ActionPlanElement(scheduled_at + 1, RailEnvActions.STOP_MOVING))

        # execute the action exactly minimum_cell_time before the entry into the next cell
        action_plan.append(ActionPlanElement(next_trainrun_waypoint.scheduled_at - minimum_cell_time, next_action))