"""
Collection of environment-specific PredictionBuilders.
"""

import numpy as np

from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.rail_env_action import RailEnvActions
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
from flatland.utils.ordered_set import OrderedSet


class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and always takes the forward action.
    """

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation builder is called.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction.

        Returns
        -------
        dict
            Returns a dictionary indexed by the agent handle; for each agent, an array of shape (max_depth + 1) x 5 whose rows contain:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.

        """
        agents = self.env.agents
        if handle is not None:
            agents = [self.env.agents[handle]]

        prediction_dict = {}

        for agent in agents:
            if agent.status != RailAgentStatus.ACTIVE:
                # TODO make this generic
                continue
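            # try the candidate actions in order of preference; the first one that
            # yields a valid transition into a valid cell is assumed to be taken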
            action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
            agent_virtual_position = agent.position
            agent_virtual_direction = agent.direction
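            # each prediction row is [time_offset, row, column, direction, action]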
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
            for index in range(1, self.max_depth + 1):
                action_done = False
                # if we're at the target, stop moving...
                if agent.position == agent.target:
                    prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]

                    continue
                for action in action_priorities:
                    cell_is_free, new_cell_isValid, new_direction, new_position, transition_isValid = \
                        self.env._check_action_on_agent(action, agent)
                    if new_cell_isValid and transition_isValid:
                        # move and turn to face the new_direction produced by this action
                        agent.position = new_position
                        agent.direction = new_direction
                        prediction[index] = [index, *new_position, new_direction, action]
                        action_done = True
                        break
                if not action_done:
                    raise Exception("Cannot move further. Something is wrong")
            prediction_dict[agent.handle] = prediction
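            # undo the in-place moves made above so the real agent state is untouched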
            agent.position = agent_virtual_position
            agent.direction = agent_virtual_direction
        return prediction_dict


class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and each agent follows its shortest path to its target.
    """

    def __init__(self, max_depth: int = 20):
        super().__init__(max_depth)

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation builder is called.
        Requires distance_map to extract the shortest path.
        Does not take into account future positions of other agents!

        If there is no shortest path, the agent just stands still and stops moving.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction.

        Returns
        -------
        dict
            Returns a dictionary indexed by the agent handle; for each agent, an array of shape (max_depth + 1) x 5 whose rows contain:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here (not implemented yet)
            The prediction at 0 is the current position, direction etc.
        """
        agents = self.env.agents
        if handle is not None:
            agents = [self.env.agents[handle]]
        distance_map: DistanceMap = self.env.distance_map

        shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth)
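        # shortest_paths maps each agent handle to the list of waypoints (position,
        # direction) along its shortest path, or to an empty/None value if no path exists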

        prediction_dict = {}
        for agent in agents:
            if agent.status == RailAgentStatus.WAITING:
                agent_virtual_position = agent.initial_position
            elif agent.status == RailAgentStatus.READY_TO_DEPART:
                agent_virtual_position = agent.initial_position
            elif agent.status == RailAgentStatus.ACTIVE:
                agent_virtual_position = agent.position
            elif agent.status == RailAgentStatus.DONE:
                agent_virtual_position = agent.target
            else:
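                # any other status (e.g. the agent has already been removed from the
                # grid): emit a placeholder prediction; the None entries are stored
                # as NaN in the float array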

                prediction = np.zeros(shape=(self.max_depth + 1, 5))
                for i in range(self.max_depth):
                    prediction[i] = [i, None, None, None, None]
                prediction_dict[agent.handle] = prediction
                continue

            agent_virtual_direction = agent.direction
            agent_speed = agent.speed_data["speed"]
            times_per_cell = int(np.reciprocal(agent_speed))
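            # e.g. a speed of 0.25 means the agent needs 4 time steps per cell, so its
            # predicted position only advances every `times_per_cell` steps below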
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]

            shortest_path = shortest_paths[agent.handle]

            # if there is a shortest path, remove the initial position
            if shortest_path:
                shortest_path = shortest_path[1:]

            new_direction = agent_virtual_direction
            new_position = agent_virtual_position
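            # cells visited along the prediction, kept only to feed env.dev_pred_dict
            # for rendering (see the TODO below)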
            visited = OrderedSet()
            for index in range(1, self.max_depth + 1):
                # if we're at the target, stop moving until max_depth is reached
                if new_position == agent.target or not shortest_path:
                    prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
                    visited.add((*new_position, new_direction))
                    continue

                if index % times_per_cell == 0:
                    new_position = shortest_path[0].position
                    new_direction = shortest_path[0].direction

                    shortest_path = shortest_path[1:]

                # prediction is ready
                prediction[index] = [index, *new_position, new_direction, 0]
                visited.add((*new_position, new_direction))

            # TODO: very bad side effect, for visualization only: hand the dev_pred_dict back instead of setting it on the env!
            self.env.dev_pred_dict[agent.handle] = visited
            prediction_dict[agent.handle] = prediction

        return prediction_dict