predictions.py 7.24 KB
Newer Older
u214892's avatar
u214892 committed
1
2
3
4
"""
Collection of environment-specific PredictionBuilder.
"""

u214892's avatar
u214892 committed
5
6
import numpy as np

u214892's avatar
u214892 committed
7
from flatland.core.env_prediction_builder import PredictionBuilder
8
from flatland.envs.distance_map import DistanceMap
9
from flatland.envs.rail_env_action import RailEnvActions
10
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
u214892's avatar
u214892 committed
11
from flatland.utils.ordered_set import OrderedSet
12
from flatland.envs.step_utils.states import TrainState
13
from flatland.envs.step_utils import transition_utils
u214892's avatar
u214892 committed
14
15
16
17
18
19
20
21
22
23


class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and always
    takes the forward action (falling back to left/right when forward is not
    a valid transition).
    """

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation build is called.

        NOTE: this method temporarily mutates ``agent.position`` and
        ``agent.direction`` while rolling the prediction forward, and restores
        both before returning.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the observation vector.
            If None, predictions are computed for all agents.

        Returns
        -------
        np.array
            Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.

        Raises
        ------
        Exception
            If none of the candidate actions yields a valid move from the
            agent's current cell.
        """
        agents = self.env.agents
        # BUG FIX: was `if handle:` — handle 0 is a valid agent but falsy,
        # so asking for agent 0 returned predictions for ALL agents.
        if handle is not None:
            agents = [self.env.agents[handle]]

        prediction_dict = {}

        for agent in agents:
            if not agent.state.is_on_map_state():
                # Agents that are not placed on the grid yet (or already done)
                # get no prediction.
                # TODO make this generic
                continue
            # Try forward first, then left, then right.
            action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
            # Remember the real pose so it can be restored after the rollout.
            agent_virtual_position = agent.position
            agent_virtual_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            # Row 0 is the current position/direction with a dummy action 0.
            prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]
            for index in range(1, self.max_depth + 1):
                action_done = False
                # if we're at the target, stop moving...
                if agent.position == agent.target:
                    prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                    continue
                for action in action_priorities:
                    new_cell_isValid, new_direction, new_position, transition_isValid = \
                        transition_utils.check_action_on_agent(action, self.env.rail, agent.position, agent.direction)
                    if all([new_cell_isValid, transition_isValid]):
                        # move and change direction to face the new_direction that was
                        # performed
                        agent.position = new_position
                        agent.direction = new_direction
                        prediction[index] = [index, *new_position, new_direction, action]
                        action_done = True
                        break
                if not action_done:
                    raise Exception("Cannot move further. Something is wrong")
            prediction_dict[agent.handle] = prediction
            # Restore the agent's real pose mutated during the rollout.
            agent.position = agent_virtual_position
            agent.direction = agent_virtual_direction
        return prediction_dict
84
85
86
87


class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and always takes the forward action.
    """

    def __init__(self, max_depth: int = 20):
        # max_depth: number of future time steps to predict (prediction has
        # max_depth + 1 rows, row 0 being the current state).
        super().__init__(max_depth)

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation build is called.
        Requires distance_map to extract the shortest path.

        Does not take into account future positions of other agents!

        If there is no shortest path, the agent just stands still and stops moving.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the observation vector.
            If None, predictions are computed for all agents.

        Returns
        -------
        np.array
            Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here (not implemented yet)
            The prediction at 0 is the current position, direction etc.
        """
        agents = self.env.agents
        # BUG FIX: was `if handle:` — handle 0 is a valid agent but falsy,
        # so asking for agent 0 returned predictions for ALL agents.
        if handle is not None:
            agents = [self.env.agents[handle]]
        distance_map: DistanceMap = self.env.distance_map

        shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth)

        prediction_dict = {}
        for agent in agents:
            # Resolve the position the prediction starts from, depending on
            # the agent's lifecycle state.
            if agent.state.is_off_map_state():
                agent_virtual_position = agent.initial_position
            elif agent.state.is_on_map_state():
                agent_virtual_position = agent.position
            elif agent.state == TrainState.DONE:
                agent_virtual_position = agent.target
            else:
                # Unknown state: emit an "empty" prediction (None becomes NaN
                # in the float array). NOTE(review): the loop fills only rows
                # 0..max_depth-1, leaving the last row all zeros — looks like
                # an off-by-one, kept as-is to preserve behavior.
                prediction = np.zeros(shape=(self.max_depth + 1, 5))
                for i in range(self.max_depth):
                    prediction[i] = [i, None, None, None, None]
                prediction_dict[agent.handle] = prediction
                continue

            agent_virtual_direction = agent.direction
            agent_speed = agent.speed_counter.speed
            # NOTE(review): int() truncates; for fractional speeds such as 1/3
            # float error could yield 2 instead of 3 — consider round(). Kept
            # as-is to preserve behavior; confirm speeds are exact reciprocals.
            times_per_cell = int(np.reciprocal(agent_speed))
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            # Row 0 is the current position/direction with a dummy action 0.
            prediction[0] = [0, *agent_virtual_position, agent_virtual_direction, 0]

            shortest_path = shortest_paths[agent.handle]

            # if there is a shortest path, remove the initial position
            if shortest_path:
                shortest_path = shortest_path[1:]

            new_direction = agent_virtual_direction
            new_position = agent_virtual_position
            visited = OrderedSet()
            for index in range(1, self.max_depth + 1):
                # if we're at the target, stop moving until max_depth is reached
                if new_position == agent.target or not shortest_path:
                    prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
                    # NOTE(review): uses agent.direction here but new_direction
                    # below — looks inconsistent; visualization-only, kept as-is.
                    visited.add((*new_position, agent.direction))
                    continue

                # Advance along the shortest path only on the steps where a
                # fractional-speed agent completes a cell.
                if index % times_per_cell == 0:
                    new_position = shortest_path[0].position
                    new_direction = shortest_path[0].direction

                    shortest_path = shortest_path[1:]

                # prediction is ready
                prediction[index] = [index, *new_position, new_direction, 0]
                visited.add((*new_position, new_direction))

            # TODO: very bady side effects for visualization only: hand the dev_pred_dict back instead of setting on env!
            self.env.dev_pred_dict[agent.handle] = visited
            prediction_dict[agent.handle] = prediction

        return prediction_dict