predictions.py 7.05 KB
Newer Older
u214892's avatar
u214892 committed
1
2
3
4
"""
Collection of environment-specific PredictionBuilder implementations.
"""

u214892's avatar
u214892 committed
5
6
import numpy as np

u214892's avatar
u214892 committed
7
from flatland.core.env_prediction_builder import PredictionBuilder
u214892's avatar
u214892 committed
8
from flatland.envs.agent_utils import RailAgentStatus
9
from flatland.envs.distance_map import DistanceMap
u214892's avatar
u214892 committed
10
from flatland.envs.rail_env import RailEnvActions
11
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
u214892's avatar
u214892 committed
12
from flatland.utils.ordered_set import OrderedSet
u214892's avatar
u214892 committed
13
14
15
16
17
18
19
20
21
22


class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and always takes the forward action,
    falling back to turning left and then right when moving forward is not a valid transition.
    """

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation build is called.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction. If ``None`` (the default),
            predictions are computed for all agents. Handle ``0`` selects agent 0 only.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle and, for each ACTIVE agent, an np.array of shape
            (max_depth + 1, 5) whose columns are:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.
            Agents whose status is not ACTIVE are omitted from the dictionary.

        Raises
        ------
        Exception
            If none of the prioritized actions yields a valid transition. The agent's position and
            direction are restored before the exception propagates.
        """
        # Handle 0 is a valid agent handle, so compare against None instead of relying on truthiness.
        agents = self.env.agents
        if handle is not None:
            agents = [self.env.agents[handle]]

        # Try forward first, then left, then right.
        action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
        prediction_dict = {}

        for agent in agents:
            if agent.status != RailAgentStatus.ACTIVE:
                # TODO make this generic
                continue

            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]

            # The rollout temporarily moves the agent itself so that env._check_action_on_agent evaluates
            # the simulated state; the finally clause guarantees the real state is restored even on error.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        continue

                    for action in action_priorities:
                        _, new_cell_is_valid, new_direction, new_position, transition_is_valid = \
                            self.env._check_action_on_agent(action, agent)
                        if new_cell_is_valid and transition_is_valid:
                            # move and change direction to face the new_direction that was performed
                            agent.position = new_position
                            agent.direction = new_direction
                            prediction[index] = [index, *new_position, new_direction, action]
                            break
                    else:
                        # no prioritized action produced a valid move
                        raise Exception("Cannot move further. Something is wrong")
            finally:
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction

            prediction_dict[agent.handle] = prediction
        return prediction_dict
83
84
85
86


class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and follows the path returned by
    ``get_shortest_paths`` on the environment's distance map.
    """

    def __init__(self, max_depth: int = 20):
        super().__init__(max_depth)

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation build is called.
        Requires distance_map to extract the shortest path.

        Does not take into account future positions of other agents!

        If there is no shortest path, the agent just stands still and stops moving.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction. If ``None`` (the default),
            predictions are computed for all agents. Handle ``0`` selects agent 0 only.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle. For agents that are READY_TO_DEPART, ACTIVE or DONE,
            the value is an np.array of shape (max_depth + 1, 5) whose columns are:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here (not fully implemented: RailEnvActions.STOP_MOVING while standing
              still, 0 otherwise)
            The prediction at 0 is the current position, direction etc.
            For agents in any other status the value is ``None``.
        """
        # Handle 0 is a valid agent handle, so compare against None instead of relying on truthiness.
        agents = self.env.agents
        if handle is not None:
            agents = [self.env.agents[handle]]

        distance_map: DistanceMap = self.env.distance_map

        shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth)

        prediction_dict = {}
        for agent in agents:
            # Pick the cell the prediction starts from, depending on the agent's lifecycle status.
            if agent.status == RailAgentStatus.READY_TO_DEPART:
                _agent_initial_position = agent.initial_position
            elif agent.status == RailAgentStatus.ACTIVE:
                _agent_initial_position = agent.position
            elif agent.status == RailAgentStatus.DONE:
                _agent_initial_position = agent.target
            else:
                prediction_dict[agent.handle] = None
                continue

            _agent_initial_direction = agent.direction
            agent_speed = agent.speed_data["speed"]
            # Number of prediction steps spent in each cell; the agent advances one cell every
            # `times_per_cell` steps.
            times_per_cell = int(np.reciprocal(agent_speed))
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]

            shortest_path = shortest_paths[agent.handle]

            # if there is a shortest path, remove the initial position
            if shortest_path:
                shortest_path = shortest_path[1:]

            new_direction = _agent_initial_direction
            new_position = _agent_initial_position
            visited = OrderedSet()
            for index in range(1, self.max_depth + 1):
                # if we're at the target or not moving, stop moving until max_depth is reached
                if new_position == agent.target or not agent.moving or not shortest_path:
                    prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
                    # record the predicted direction (not the agent's current one) so that `visited`
                    # is consistent with the moving branch below
                    visited.add((*new_position, new_direction))
                    continue

                if index % times_per_cell == 0:
                    new_position = shortest_path[0].position
                    new_direction = shortest_path[0].direction

                    shortest_path = shortest_path[1:]

                # prediction is ready
                prediction[index] = [index, *new_position, new_direction, 0]
                visited.add((*new_position, new_direction))

            # TODO: very bad side effects for visualization only: hand the dev_pred_dict back instead of setting on env!
            self.env.dev_pred_dict[agent.handle] = visited
            prediction_dict[agent.handle] = prediction

        return prediction_dict