predictions.py 7.46 KB
Newer Older
u214892's avatar
u214892 committed
1
2
3
4
"""
Collection of environment-specific PredictionBuilder.
"""

u214892's avatar
u214892 committed
5
6
import numpy as np

u214892's avatar
u214892 committed
7
from flatland.core.env_prediction_builder import PredictionBuilder
u214892's avatar
u214892 committed
8
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
9
from flatland.envs.rail_env import RailEnvActions
u214892's avatar
u214892 committed
10
from flatland.utils.ordered_set import OrderedSet
u214892's avatar
u214892 committed
11
12
13
14
15
16
17
18
19
20


class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment. At every step it tries
    MOVE_FORWARD first and falls back to MOVE_LEFT, then MOVE_RIGHT, taking the first action
    that leads to a valid cell through a valid transition.
    """

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation build is called.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction.
            If None (default), predictions are computed for all agents of the environment.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle; for each agent a np.ndarray of shape
            (max_depth + 1, 5) whose rows contain:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.

        Raises
        ------
        Exception
            If none of the candidate actions is possible from a predicted cell.
        """
        agents = self.env.agents
        # Handle 0 is a valid agent index, so test against None rather than truthiness.
        if handle is not None:
            agents = [self.env.agents[handle]]

        action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
        prediction_dict = {}

        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]

            # The simulation below temporarily moves the real agent object; the finally
            # clause guarantees it is put back even if the prediction fails half-way.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        continue

                    for action in action_priorities:
                        _, new_cell_isValid, new_direction, new_position, transition_isValid = \
                            self.env._check_action_on_agent(action, agent)
                        if new_cell_isValid and transition_isValid:
                            # move and change direction to face the new_direction that was
                            # performed
                            agent.position = new_position
                            agent.direction = new_direction
                            prediction[index] = [index, *new_position, new_direction, action]
                            break
                    else:
                        # no candidate action was applicable
                        raise Exception("Cannot move further. Something is wrong")
            finally:
                # cleanup: reset initial position
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction

            prediction_dict[agent.handle] = prediction
        return prediction_dict
78
79
80
81


class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment: where the current cell
    offers a single transition it is followed, and where it offers several the one whose
    neighbouring cell has the smallest distance-map value for the agent is chosen.
    """

    def __init__(self, max_depth: int = 20):
        super().__init__(max_depth)

    def get(self, handle: int = None):
        """
        Called whenever get_many in the observation build is called.
        Requires distance_map to extract the shortest path.

        As a side effect, the set of predicted (row, column, direction) triples of each
        processed agent is stored in `self.env.dev_pred_dict[agent.handle]`.

        Parameters
        ----------
        handle : int, optional
            Handle of the agent for which to compute the prediction.
            If None (default), predictions are computed for all agents of the environment.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle; for each agent a np.ndarray of shape
            (max_depth + 1, 5) whose rows contain:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here (STOP_MOVING for stopped steps, 0 otherwise)
            The prediction at 0 is the current position, direction etc.

        Raises
        ------
        Exception
            If a cell on the predicted path offers no outgoing transition.
        """
        agents = self.env.agents
        # Handle 0 is a valid agent index, so test against None rather than truthiness.
        if handle is not None:
            agents = [self.env.agents[handle]]
        distance_map = self.env.distance_map
        assert distance_map is not None

        prediction_dict = {}
        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            agent_speed = agent.speed_data["speed"]
            # number of time steps the agent spends on a cell before it may leave it
            times_per_cell = int(np.reciprocal(agent_speed))
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
            new_direction = _agent_initial_direction
            new_position = _agent_initial_position
            visited = OrderedSet()

            # The simulation below temporarily moves the real agent object; the finally
            # clause guarantees it is put back even if the prediction fails half-way.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        visited.add((agent.position[0], agent.position[1], agent.direction))
                        continue
                    if not agent.moving:
                        prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
                        visited.add((agent.position[0], agent.position[1], agent.direction))
                        continue

                    # Take shortest possible path
                    cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)

                    # The agent only changes cell on steps that are a multiple of
                    # times_per_cell; on all other steps the previous position is repeated.
                    if index % times_per_cell == 0:
                        num_transitions = np.sum(cell_transitions)
                        if num_transitions == 1:
                            new_direction = np.argmax(cell_transitions)
                            new_position = get_new_position(agent.position, new_direction)
                        elif num_transitions > 1:
                            min_dist = np.inf
                            # ensures the first allowed direction is taken even when its
                            # distance is not smaller than the initial min_dist (e.g. inf)
                            no_dist_found = True
                            for direction in range(4):
                                if cell_transitions[direction] == 1:
                                    neighbour_cell = get_new_position(agent.position, direction)
                                    target_dist = distance_map.get()[
                                        agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
                                    if target_dist < min_dist or no_dist_found:
                                        min_dist = target_dist
                                        new_direction = direction
                                        no_dist_found = False
                            new_position = get_new_position(agent.position, new_direction)
                        else:
                            raise Exception("No transition possible {}".format(cell_transitions))

                    # update the agent's position and direction
                    agent.position = new_position
                    agent.direction = new_direction

                    # prediction is ready
                    prediction[index] = [index, *new_position, new_direction, 0]
                    visited.add((new_position[0], new_position[1], new_direction))

                self.env.dev_pred_dict[agent.handle] = visited
                prediction_dict[agent.handle] = prediction
            finally:
                # cleanup: reset initial position
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction

        return prediction_dict