predictions.py 7.61 KB
Newer Older
u214892's avatar
u214892 committed
1
2
3
4
"""
Collection of environment-specific PredictionBuilder implementations.
"""

u214892's avatar
u214892 committed
5
6
import numpy as np

u214892's avatar
u214892 committed
7
from flatland.core.env_prediction_builder import PredictionBuilder
u214892's avatar
u214892 committed
8
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
9
from flatland.envs.rail_env import RailEnvActions
u214892's avatar
u214892 committed
10
from flatland.utils.ordered_set import OrderedSet
u214892's avatar
u214892 committed
11
12
13
14
15
16
17
18
19
20


class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment. At every step it takes the first
    feasible action among MOVE_FORWARD, MOVE_LEFT and MOVE_RIGHT (in that order of priority).
    """

    def get(self, handle=None):
        """
        Called whenever get_many in the observation build is called.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction.
            If None (default), predictions are computed for all agents of the environment.

        Returns
        -------
        dict
            Returns a dictionary indexed by the agent handle and for each agent a np.array of
            (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.

        Raises
        ------
        Exception
            If none of the prioritised actions is valid for an agent at some step.
            The agent's position and direction are restored before the exception propagates.
        """
        agents = self.env.agents
        # Compare against None explicitly: handle 0 is a valid agent handle but is falsy.
        if handle is not None:
            agents = [self.env.agents[handle]]

        # The priority order is the same for every agent and every step.
        action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]

        prediction_dict = {}

        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]

            # The simulation below moves the real agent object; the finally clause guarantees
            # that its state is restored even when the simulation raises.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        continue

                    action_done = False
                    for action in action_priorities:
                        _cell_is_free, new_cell_is_valid, new_direction, new_position, transition_is_valid = \
                            self.env._check_action_on_agent(action, agent)
                        # Cell occupancy is deliberately ignored: other agents are not considered.
                        if new_cell_is_valid and transition_is_valid:
                            # move and change direction to face the new_direction that was
                            # performed
                            agent.position = new_position
                            agent.direction = new_direction
                            prediction[index] = [index, *new_position, new_direction, action]
                            action_done = True
                            break
                    if not action_done:
                        raise Exception("Cannot move further. Something is wrong")
            finally:
                # cleanup: reset initial position
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction

            prediction_dict[agent.handle] = prediction
        return prediction_dict
80
81
82
83


class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment and, at every cell with more than
    one possible transition, follows the transition with the smallest distance to the target
    according to the environment's distance map.
    """

    def __init__(self, max_depth=20):
        # Number of future time steps to predict (defaults to 20).
        self.max_depth = max_depth

    def get(self, handle=None):
        """
        Called whenever get_many in the observation build is called.
        Requires the environment's distance_map to extract the shortest path.

        As a side effect, the set of visited (position axis 0, position axis 1, direction) triples of
        each predicted agent is stored in self.env.dev_pred_dict under the agent's handle.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction.
            If None (default), predictions are computed for all agents of the environment.

        Returns
        -------
        dict
            Returns a dictionary indexed by the agent handle and for each agent a np.array of
            (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here (STOP_MOVING for a stopped agent or one at its target, 0 otherwise)
            The prediction at 0 is the current position, direction etc.

        Raises
        ------
        Exception
            If an agent has to move but its current cell offers no transition.
            The agent's position and direction are restored before the exception propagates.
        """
        agents = self.env.agents
        # Compare against None explicitly: handle 0 is a valid agent handle but is falsy.
        if handle is not None:
            agents = [self.env.agents[handle]]
        distance_map = self.env.distance_map
        assert distance_map is not None

        prediction_dict = {}
        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            agent_speed = agent.speed_data["speed"]
            # Number of time steps the agent needs per cell.
            # NOTE(review): assumes 0 < speed <= 1; a larger speed would truncate to 0 here - confirm.
            times_per_cell = int(np.reciprocal(agent_speed))
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
            new_direction = _agent_initial_direction
            new_position = _agent_initial_position
            visited = OrderedSet()

            # The simulation below moves the real agent object; the finally clause guarantees
            # that its state is restored even when the simulation raises.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        visited.add((agent.position[0], agent.position[1], agent.direction))
                        continue
                    if not agent.moving:
                        prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
                        visited.add((agent.position[0], agent.position[1], agent.direction))
                        continue

                    # Take shortest possible path
                    cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)

                    # The agent only changes cell on every times_per_cell-th step; on the other
                    # steps the previous position and direction are repeated.
                    if index % times_per_cell == 0:
                        num_transitions = np.sum(cell_transitions)
                        if num_transitions == 1:
                            new_direction = np.argmax(cell_transitions)
                            new_position = get_new_position(agent.position, new_direction)
                        elif num_transitions > 1:
                            min_dist = np.inf
                            no_dist_found = True
                            for direction in range(4):
                                if cell_transitions[direction] == 1:
                                    neighbour_cell = get_new_position(agent.position, direction)
                                    target_dist = distance_map.get()[
                                        agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
                                    # The first candidate is always accepted, so a direction is
                                    # chosen even when no distance compares smaller than infinity.
                                    if target_dist < min_dist or no_dist_found:
                                        min_dist = target_dist
                                        new_direction = direction
                                        no_dist_found = False
                            new_position = get_new_position(agent.position, new_direction)
                        else:
                            raise Exception("No transition possible {}".format(cell_transitions))

                    # update the agent's position and direction
                    agent.position = new_position
                    agent.direction = new_direction

                    # prediction is ready
                    prediction[index] = [index, *new_position, new_direction, 0]
                    visited.add((new_position[0], new_position[1], new_direction))
            finally:
                # cleanup: reset initial position
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction

            self.env.dev_pred_dict[agent.handle] = visited
            prediction_dict[agent.handle] = prediction

        return prediction_dict