predictions.py 7.14 KB
Newer Older
u214892's avatar
u214892 committed
1
2
3
4
"""
Collection of environment-specific PredictionBuilder implementations.
"""

u214892's avatar
u214892 committed
5
6
import numpy as np

u214892's avatar
u214892 committed
7
from flatland.core.env_prediction_builder import PredictionBuilder
u214892's avatar
u214892 committed
8
from flatland.envs.rail_env import RailEnvActions
u214892's avatar
u214892 committed
9
10
11
12
13
14
15
16
17
18


class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment. At every step it tries
    MOVE_FORWARD, then MOVE_LEFT, then MOVE_RIGHT and takes the first action that leads to a
    valid cell through a valid transition.
    """

    def get(self, custom_args=None, handle=None):
        """
        Called whenever get_many in the observation build is called.

        Parameters
        ----------
        custom_args: dict
            Not used in this dummy implementation.
        handle : int (optional)
            Handle of the agent for which to compute the prediction. If None, predictions are
            computed for all agents of the environment.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle; for each agent an np.array of
            (max_depth + 1) x 5 elements, each row holding:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.

        Raises
        ------
        Exception
            If, at some step, none of the prioritized actions is feasible. The agent's position
            and direction are restored before the exception propagates.
        """
        agents = self.env.agents
        # Compare against None explicitly: handle 0 is a valid agent handle.
        if handle is not None:
            agents = [self.env.agents[handle]]

        # Order in which actions are tried at every simulated step.
        action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]

        prediction_dict = {}
        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
            # The rollout moves the real agent object; the finally-clause guarantees it is put
            # back even if the rollout fails part-way.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        continue
                    for action in action_priorities:
                        _cell_is_free, new_cell_is_valid, new_direction, new_position, transition_is_valid = \
                            self.env._check_action_on_agent(action, agent)
                        if new_cell_is_valid and transition_is_valid:
                            # move and change direction to face the new_direction that was performed
                            agent.position = new_position
                            agent.direction = new_direction
                            prediction[index] = [index, *new_position, new_direction, action]
                            break
                    else:
                        # none of the prioritized actions was feasible from this cell/orientation
                        raise Exception("Cannot move further. Something is wrong")
            finally:
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction
            prediction_dict[agent.handle] = prediction
        return prediction_dict
79
80
81
82


class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment. At every step the agent
    follows the available transition whose neighbouring cell (entered with that orientation)
    has the smallest value in the supplied distance map. The recorded action is always
    MOVE_FORWARD.
    """

    def get(self, custom_args=None, handle=None):
        """
        Called whenever get_many in the observation build is called.
        Requires distance_map to extract the shortest path.

        Parameters
        ----------
        custom_args: dict
            - distance_map : array indexed as [agent handle, row, column, orientation]
              (indexing as used below; presumably distance to the agent's target, verify
              against the observation builder that produces it)
        handle : int (optional)
            Handle of the agent for which to compute the prediction. If None, predictions are
            computed for all agents of the environment.

        Returns
        -------
        dict
            Dictionary indexed by the agent handle; for each agent an np.array of
            (max_depth + 1) x 5 elements, each row holding:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.

        Raises
        ------
        AssertionError
            If custom_args is empty or does not contain a 'distance_map'.
        Exception
            If the agent's current cell offers no transition for its orientation. The agent's
            position and direction are restored before the exception propagates.
        """
        agents = self.env.agents
        # Compare against None explicitly: handle 0 is a valid agent handle.
        if handle is not None:
            agents = [self.env.agents[handle]]
        assert custom_args, "custom_args with a 'distance_map' entry is required"
        distance_map = custom_args.get('distance_map')
        assert distance_map is not None, "custom_args must contain a 'distance_map'"

        prediction_dict = {}
        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
            # The rollout moves the real agent object; the finally-clause guarantees it is put
            # back even if the rollout fails part-way.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        continue
                    if not agent.moving:
                        prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
                        continue
                    # Take shortest possible path
                    new_direction = self._shortest_path_direction(agent, distance_map)
                    new_position = self._new_position(agent.position, new_direction)
                    agent.position = new_position
                    agent.direction = new_direction
                    prediction[index] = [index, *new_position, new_direction, RailEnvActions.MOVE_FORWARD]
            finally:
                # cleanup: reset initial position
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction
            prediction_dict[agent.handle] = prediction

        return prediction_dict

    def _shortest_path_direction(self, agent, distance_map):
        """
        Return the movement direction (0=N, 1=E, 2=S, 3=W) to take from the agent's current
        cell and orientation.

        Every direction allowed by the cell's transitions is scored by the distance-map value
        of the neighbouring cell it leads to, entered with that direction as orientation. The
        smallest value wins; ties go to the lowest direction index. If all candidates are
        unreachable (e.g. all inf), the first allowed direction is returned rather than an
        undefined value.
        """
        cell_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
        candidates = [direction for direction in range(4) if cell_transitions[direction]]
        if not candidates:
            raise Exception("Cannot move further. Something is wrong")
        if len(candidates) == 1:
            return candidates[0]

        def _distance_via(direction):
            # distance_map is looked up at the cell the agent would enter, with the orientation
            # it would have there (= the movement direction).
            row, col = self._new_position(agent.position, direction)
            return distance_map[agent.handle, row, col, direction]

        # min() keeps the first minimal element, matching a strict '<' scan over range(4).
        return min(candidates, key=_distance_via)

    def _new_position(self, position, movement):
        """
        Utility function that converts a compass movement over a 2D grid to new positions (r, c).

        movement: 0 = NORTH, 1 = EAST, 2 = SOUTH, 3 = WEST.
        Raises ValueError for any other movement value.
        """
        if movement == 0:  # NORTH
            return (position[0] - 1, position[1])
        elif movement == 1:  # EAST
            return (position[0], position[1] + 1)
        elif movement == 2:  # SOUTH
            return (position[0] + 1, position[1])
        elif movement == 3:  # WEST
            return (position[0], position[1] - 1)
        raise ValueError("Invalid movement {}; expected 0 (N), 1 (E), 2 (S) or 3 (W)".format(movement))