predictions.py 6.48 KB
Newer Older
u214892's avatar
u214892 committed
1
2
3
4
"""
Collection of environment-specific PredictionBuilder implementations.
"""

u214892's avatar
u214892 committed
5
6
import numpy as np

u214892's avatar
u214892 committed
7
from flatland.core.env_prediction_builder import PredictionBuilder
8
from flatland.envs.distance_map import DistanceMap
u214892's avatar
u214892 committed
9
from flatland.envs.rail_env import RailEnvActions
10
from flatland.envs.rail_env_shortest_paths import get_shortest_paths
u214892's avatar
u214892 committed
11
from flatland.utils.ordered_set import OrderedSet
u214892's avatar
u214892 committed
12
13
14
15
16
17
18
19
20
21


class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and prefers the forward action,
    falling back to moving left, then right, when forward is not a valid move.
    """

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation builder is called.

        Each agent is simulated step by step: at every step the actions MOVE_FORWARD, MOVE_LEFT and
        MOVE_RIGHT are tried in that order, and the first one that ``self.env._check_action_on_agent``
        reports as a valid cell and valid transition is taken. Once the agent is on its target it
        stays there with action STOP_MOVING. The agent objects are moved temporarily during the
        simulation and are always restored to their initial position/direction afterwards.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction. If ``None`` (the default),
            predictions are computed for all agents.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle; for each agent an np.array of shape
            (max_depth + 1, 5) whose columns are:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.

        Raises
        ------
        Exception
            If none of the candidate actions is valid for an agent at some step.
        """
        agents = self.env.agents
        # Compare against None explicitly: 0 is a valid agent handle.
        if handle is not None:
            agents = [self.env.agents[handle]]

        # Actions tried at every step, in order of preference.
        action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]

        prediction_dict = {}
        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
            # The simulation mutates the real agent object, so its state must be restored
            # even when the simulation raises.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        continue
                    for action in action_priorities:
                        _cell_is_free, new_cell_isValid, new_direction, new_position, transition_isValid = \
                            self.env._check_action_on_agent(action, agent)
                        if new_cell_isValid and transition_isValid:
                            # move and change direction to face the new_direction that was performed
                            agent.position = new_position
                            agent.direction = new_direction
                            prediction[index] = [index, *new_position, new_direction, action]
                            break
                    else:
                        # no candidate action was valid
                        raise Exception("Cannot move further. Something is wrong")
            finally:
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction
            prediction_dict[agent.handle] = prediction
        return prediction_dict
79
80
81
82


class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and follows the shortest path
    obtained from ``get_shortest_paths`` on the environment's distance map.
    """

    def __init__(self, max_depth: int = 20):
        super().__init__(max_depth)

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation builder is called.
        Requires the environment's distance_map to extract the shortest path.

        Does not take into account future positions of other agents!

        If there is no shortest path, the agent just stands still and stops moving. The agent also
        stands still once it is on its target, when ``agent.moving`` is falsy, or once its shortest
        path (truncated to max_depth) is exhausted.

        Side effect: for each predicted agent, the ordered set of predicted
        (row, column, direction) tuples is stored in ``self.env.dev_pred_dict[agent.handle]``.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction. If ``None`` (the default),
            predictions are computed for all agents.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle; for each agent an np.array of shape
            (max_depth + 1, 5) whose columns are:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here (only STOP_MOVING is recorded; 0 otherwise)
            The prediction at 0 is the current position, direction etc.
        """
        agents = self.env.agents
        # Compare against None explicitly: 0 is a valid agent handle.
        if handle is not None:
            agents = [self.env.agents[handle]]
        distance_map: DistanceMap = self.env.distance_map

        shortest_paths = get_shortest_paths(distance_map, max_depth=self.max_depth)

        prediction_dict = {}
        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            agent_speed = agent.speed_data["speed"]
            # number of time steps between two cell advances
            times_per_cell = int(np.reciprocal(agent_speed))
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]

            shortest_path = shortest_paths[agent.handle]

            # if there is a shortest path, remove the initial position
            if shortest_path:
                shortest_path = shortest_path[1:]

            new_direction = _agent_initial_direction
            new_position = _agent_initial_position
            visited = OrderedSet()
            for index in range(1, self.max_depth + 1):
                # if we're at the target, not moving, or without a path, stand still until max_depth is reached
                if new_position == agent.target or not agent.moving or not shortest_path:
                    prediction[index] = [index, *new_position, new_direction, RailEnvActions.STOP_MOVING]
                    # record the predicted direction (not the agent's initial one), as in the moving branch
                    visited.add((*new_position, new_direction))
                    continue

                # slower agents only advance to the next cell every times_per_cell steps
                if index % times_per_cell == 0:
                    new_position = shortest_path[0].position
                    new_direction = shortest_path[0].direction

                    shortest_path = shortest_path[1:]

                # prediction is ready
                prediction[index] = [index, *new_position, new_direction, 0]
                visited.add((*new_position, new_direction))

            # TODO: very bad side effect for visualization only: hand the dev_pred_dict back instead of setting on env!
            self.env.dev_pred_dict[agent.handle] = visited
            prediction_dict[agent.handle] = prediction

        return prediction_dict