"""
Collection of environment-specific PredictionBuilder implementations.
"""

u214892's avatar
u214892 committed
5
6
import numpy as np

u214892's avatar
u214892 committed
7
from flatland.core.env_prediction_builder import PredictionBuilder
u214892's avatar
u214892 committed
8
from flatland.core.grid.grid4_utils import get_new_position
u214892's avatar
u214892 committed
9
from flatland.envs.rail_env import RailEnvActions
u214892's avatar
u214892 committed
10
11
12
13
14
15
16
17
18
19


class DummyPredictorForRailEnv(PredictionBuilder):
    """
    DummyPredictorForRailEnv object.

    This object returns predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment. At every step the first
    valid action among forward, left and right (in that order of priority) is taken.
    """

    def get(self, custom_args=None, handle=None):
        """
        Called whenever get_many in the observation build is called.

        Parameters
        -------
        custom_args: dict
            Not used in this dummy implementation.
        handle : int (optional)
            Handle of the agent for which to compute the prediction.
            If None, predictions are computed for all agents of the environment.

        Returns
        -------
        dict
            Returns a dictionary indexed by the agent handle and for each agent a
            np.array of (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here
            The prediction at 0 is the current position, direction etc.

        Raises
        ------
        Exception
            If none of the prioritised actions is valid for an agent that has not reached its target.
        """
        agents = self.env.agents
        # Compare against None explicitly: handle 0 is a valid agent handle but is falsy.
        if handle is not None:
            agents = [self.env.agents[handle]]

        # Actions are tried in this order; the first valid one is taken.
        action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]

        prediction_dict = {}
        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]

            # The simulation below moves the real agent object, so its state must be put back
            # even when the simulation fails with an exception.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        continue

                    for action in action_priorities:
                        _, new_cell_isValid, new_direction, new_position, transition_isValid = \
                            self.env._check_action_on_agent(action, agent)
                        if new_cell_isValid and transition_isValid:
                            # move and change direction to face the new_direction that was
                            # performed
                            agent.position = new_position
                            agent.direction = new_direction
                            prediction[index] = [index, *new_position, new_direction, action]
                            break
                    else:
                        # no action in action_priorities was valid
                        raise Exception("Cannot move further. Something is wrong")
            finally:
                # cleanup: reset initial position
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction

            prediction_dict[agent.handle] = prediction
        return prediction_dict
79
80
81
82


class ShortestPathPredictorForRailEnv(PredictionBuilder):
    """
    ShortestPathPredictorForRailEnv object.

    This object returns shortest-path predictions for agents in the RailEnv environment.
    The prediction acts as if no other agent is in the environment. Where several transitions
    are possible, the one whose neighbouring cell has the smallest distance-map value is taken.
    """

    def __init__(self, max_depth=20):
        # Number of future steps to predict (default 20).
        # NOTE(review): the base-class __init__ is not called here; confirm that
        # PredictionBuilder does not set up any other state that this class needs.
        self.max_depth = max_depth

    def get(self, custom_args=None, handle=None):
        """
        Called whenever get_many in the observation build is called.
        Requires distance_map to extract the shortest path.

        Parameters
        -------
        custom_args: dict
            - distance_map : indexed as
              distance_map[agent handle, position axis 0, position axis 1, direction]
        handle : int (optional)
            Handle of the agent for which to compute the prediction.
            If None, predictions are computed for all agents of the environment.

        Returns
        -------
        dict
            Returns a dictionary indexed by the agent handle and for each agent a
            np.array of (max_depth + 1)x5 elements:
            - time_offset
            - position axis 0
            - position axis 1
            - direction
            - action taken to come here (always 0 for predicted moves, STOP_MOVING when
              the agent is at its target or not moving)
            The prediction at 0 is the current position, direction etc.

        Side effects
        ------------
        Stores the set of predicted (position axis 0, position axis 1, direction) triples of
        each agent in self.env.dev_pred_dict[agent.handle].

        Raises
        ------
        AssertionError
            If custom_args is None or does not contain a 'distance_map' entry.
        Exception
            If an agent stands on a cell without any possible transition.
        """
        agents = self.env.agents
        # Compare against None explicitly: handle 0 is a valid agent handle but is falsy.
        if handle is not None:
            agents = [self.env.agents[handle]]
        assert custom_args is not None, "custom_args with a 'distance_map' entry is required"
        distance_map = custom_args.get('distance_map')
        assert distance_map is not None, "custom_args must contain a 'distance_map' entry"

        prediction_dict = {}
        for agent in agents:
            _agent_initial_position = agent.position
            _agent_initial_direction = agent.direction
            prediction = np.zeros(shape=(self.max_depth + 1, 5))
            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
            visited = set()

            # The simulation below moves the real agent object, so its state must be put back
            # even when the simulation fails with an exception.
            try:
                for index in range(1, self.max_depth + 1):
                    # if we're at the target, stop moving...
                    if agent.position == agent.target:
                        prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
                        visited.add((agent.position[0], agent.position[1], agent.direction))
                        continue
                    if not agent.moving:
                        prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
                        visited.add((agent.position[0], agent.position[1], agent.direction))
                        continue

                    # Take shortest possible path
                    new_position, new_direction = self._shortest_path_step(agent, distance_map)

                    # update the agent's position and direction
                    agent.position = new_position
                    agent.direction = new_direction

                    # prediction is ready
                    prediction[index] = [index, *new_position, new_direction, 0]
                    visited.add((new_position[0], new_position[1], new_direction))
            finally:
                # cleanup: reset initial position
                agent.position = _agent_initial_position
                agent.direction = _agent_initial_direction

            self.env.dev_pred_dict[agent.handle] = visited
            prediction_dict[agent.handle] = prediction

        return prediction_dict

    def _shortest_path_step(self, agent, distance_map):
        """
        Return (new_position, new_direction) of the next step of `agent` from its current state.

        With a single possible transition that transition is taken. With several, the direction
        whose neighbouring cell has the smallest distance_map value is taken; ties (and the case
        where no value is smaller than infinity) resolve to the first such direction.
        """
        cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
        num_transitions = np.sum(cell_transitions)

        if num_transitions == 1:
            new_direction = np.argmax(cell_transitions)
        elif num_transitions > 1:
            new_direction = None
            min_dist = np.inf
            for direction in range(4):
                if cell_transitions[direction] == 1:
                    neighbour_cell = get_new_position(agent.position, direction)
                    target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
                    # Always accept the first candidate so that a direction is chosen even
                    # when every distance is infinite.
                    if new_direction is None or target_dist < min_dist:
                        min_dist = target_dist
                        new_direction = direction
        else:
            raise Exception("No transition possible {}".format(cell_transitions))

        return get_new_position(agent.position, new_direction), new_direction