predictions.py 7.15 KB
Newer Older
u214892's avatar
u214892 committed
1
2
3
4
"""
Collection of environment-specific PredictionBuilder.
"""

u214892's avatar
u214892 committed
5
6
import numpy as np

u214892's avatar
u214892 committed
7
from flatland.core.env_prediction_builder import PredictionBuilder
8
from flatland.envs.distance_map import DistanceMap
9
from flatland.envs.rail_env_action import RailEnvActions
10
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
u214892's avatar
u214892 committed
11
from flatland.utils.ordered_set import OrderedSet
12
from flatland.envs.step_utils.states import TrainState
u214892's avatar
u214892 committed
13
14
15
16
17
18
19
20
21
22


class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and prefers the forward action,
    falling back to left and then right when forward is not a valid move.
    """

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation build is called.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction.
            If None, predictions are computed for all agents.

        Returns
        -------
        dict
            Returns a dictionary indexed by the agent handle and for each agent an np.ndarray of
            (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.
            Agents whose state is not an on-map state are skipped and have no entry.

        Raises
        ------
        Exception
            If none of the candidate actions leads to a valid cell and transition.
        """
        agents = self.env.agents
        # Explicit None check: handle 0 is a valid agent handle but is falsy.
        if handle is not None:
            agents = [self.env.agents[handle]]

        # Candidate actions in order of preference; the first valid one is taken at every step.
        action_priorities = (RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT)

        prediction_dict = {}

        for agent in agents:
            if not agent.state.is_on_map_state():
                # TODO make this generic
                continue

            # Remember the real state: the simulation below moves the agent object itself.
            agent_virtual_position = agent.position
            agent_virtual_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]

            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        continue

                    action_done = False
                    for action in action_priorities:
                        # cell_is_free is deliberately ignored: other agents are not taken into account.
                        cell_is_free, new_cell_isValid, new_direction, new_position, transition_isValid = \
                            self.env._check_action_on_agent(action, agent)
                        if new_cell_isValid and transition_isValid:
                            # move and change direction to face the new_direction that was
                            # performed
                            agent.position = new_position
                            agent.direction = new_direction
                            prediction[index] = [index, *new_position, new_direction, action]
                            action_done = True
                            break
                    if not action_done:
                        raise Exception("Cannot move further. Something is wrong")
            finally:
                # Always restore the agent's real state, even when the simulation fails,
                # so that a failed prediction does not corrupt the environment.
                agent.position = agent_virtual_position
                agent.direction = agent_virtual_direction

            prediction_dict[agent.handle] = prediction

        return prediction_dict
83
84
85
86


class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and follows the shortest path
    obtained from the environment's distance map.
    """

    def __init__(self, max_depth: int = 20):
        super().__init__(max_depth)

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation build is called.
        Requires distance_map to extract the shortest path.
        Does not take into account future positions of other agents!

        If there is no shortest path, the agent just stands still and stops moving.

        As a side effect (used for visualization), the set of predicted
        (position axis 0, position axis 1, direction) tuples of each agent is stored in
        self.env.dev_pred_dict[agent.handle].

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction.
            If None, predictions are computed for all agents.

        Returns
        -------
        dict
            Returns a dictionary indexed by the agent handle and for each agent an np.ndarray of
            (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here (not implemented yet)
            The prediction at 0 is the current position, direction etc.
            For an agent in a state that is neither off-map, on-map nor DONE, every row holds
            the time_offset followed by NaN values.
        """
        agents = self.env.agents
        # Explicit None check: handle 0 is a valid agent handle but is falsy.
        if handle is not None:
            agents = [self.env.agents[handle]]
        distance_map: DistanceMap = self.env.distance_map

        shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth)

        prediction_dict = {}
        for agent in agents:
            if agent.state.is_off_map_state():
                agent_virtual_position = agent.initial_position
            elif agent.state.is_on_map_state():
                agent_virtual_position = agent.position
            elif agent.state == TrainState.DONE:
                agent_virtual_position = agent.target
            else:
                # No position can be derived for this state: mark position, direction and
                # action as unknown (NaN) on every one of the max_depth + 1 rows.
                prediction = np.full((self.max_depth + 1, 5), np.nan)
                prediction[:, 0] = np.arange(self.max_depth + 1)
                prediction_dict[agent.handle] = prediction
                continue

            agent_virtual_direction = agent.direction
            agent_speed = agent.speed_counter.speed
            # Number of time steps the agent needs per cell.
            # NOTE(review): presumably 0 < speed <= 1; a speed above 1 would truncate to 0
            # and make the modulo below fail - confirm against the speed counter.
            times_per_cell = int(np.reciprocal(agent_speed))
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]

            shortest_path = shortest_paths[agent.handle]

            # if there is a shortest path, remove the initial position
            if shortest_path:
                shortest_path = shortest_path[1:]

            new_direction = agent_virtual_direction
            new_position = agent_virtual_position
            visited = OrderedSet()
            for index in range(1, self.max_depth + 1):
                # if we're at the target, stop moving until max_depth is reached
                if new_position == agent.target or not shortest_path:
                    prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
                    # record the same direction as written to the prediction row above
                    visited.add((*new_position, new_direction))
                    continue

                # advance to the next waypoint only once enough time steps have elapsed
                if index % times_per_cell == 0:
                    new_position = shortest_path[0].position
                    new_direction = shortest_path[0].direction

                    shortest_path = shortest_path[1:]

                # prediction is ready
                prediction[index] = [index, *new_position, new_direction, 0]
                visited.add((*new_position, new_direction))

            # TODO: very bad side effects for visualization only: hand the dev_pred_dict back instead of setting on env!
            self.env.dev_pred_dict[agent.handle] = visited
            prediction_dict[agent.handle] = prediction

        return prediction_dict