env_observation_builder.py 12.9 KB
Newer Older
1
2
3
import numpy as np
from collections import deque

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
## TODO: add docstrings, pylint, etc...


class ObservationBuilder:
    def __init__(self, env):
        self.env = env

    def reset(self):
        raise NotImplementedError()

    def get(self, handle):
        raise NotImplementedError()



class TreeObsForRailEnv(ObservationBuilder):
20
21
22
    def __init__(self, env):
        self.env = env

23
    def reset(self):
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
        self.distance_map = np.inf * np.ones(shape=(self.env.number_of_agents,
                                                    self.env.height,
                                                    self.env.width))
        self.max_dist = np.zeros(self.env.number_of_agents)

        for i in range(self.env.number_of_agents):
            self.max_dist[i] = self._distance_map_walker(self.env.agents_target[i], i)


    def _distance_map_walker(self, position, target_nr):
        # Returns max distance to target, from the farthest away node, while filling in distance_map

        for ori in range(4):
            self.distance_map[target_nr, position[0], position[1]] = 0

        # Fill in the (up to) 4 neighboring nodes
        # nodes_queue = []  # list of tuples (row, col, direction, distance);  direction is the direction of movement, meaning that at least a possible orientation of an agent in cell (row,col) allows a movement in direction `direction'
        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))

        # BFS from target `position' to all the reachable nodes in the grid
        # Stop the search if the target position is re-visited, in any direction
        visited = set([(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2), (position[0], position[1], 3)])

        max_distance = 0

        while nodes_queue:
            node = nodes_queue.popleft()

            node_id = (node[0], node[1], node[2])

            #print(node_id, visited, (node_id in visited))
            #print(nodes_queue)

            if node_id not in visited:
                visited.add(node_id)

                # From the list of possible neighbors that have at least a path to the current node, only keep those whose new orientation in the current cell would allow a transition to direction node[2]
                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])

                for n in valid_neighbors:
                    nodes_queue.append(n)

                if len(valid_neighbors)>0:
                    max_distance = max(max_distance, node[3]+1)

        return max_distance


    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
        neighbors = []

        for direction in range(4):
            new_cell = self._new_position(position, (direction+2)%4)

            if new_cell[0]>=0 and new_cell[0]<self.env.height and new_cell[1]>=0 and new_cell[1]<self.env.width:
                # Check if the two cells are connected by a valid transition
                transitionValid = False
                for orientation in range(4):
                    moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
                    if moves[direction]:
                        transitionValid = True
                        break

                if not transitionValid:
                    continue

                # Check if a transition in direction node[2] is possible if an agent lands in the current cell with orientation `direction'; this only applies to cells that are not dead-ends!
                directionMatch = True
                if enforce_target_direction>=0:
                    directionMatch = self.env.rail.get_transition((new_cell[0], new_cell[1], direction), enforce_target_direction)

                # If transition is found to invalid, check if perhaps it is a dead-end, in which case the direction of movement is rotated 180 degrees (moving forward turns the agents and makes it step in the previous cell)
                if not directionMatch:
                    # If cell is a dead-end, append previous node with reversed
                    # orientation!
                    nbits = 0
                    tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
                    while tmp > 0:
                        nbits += (tmp & 1)
                        tmp = tmp >> 1
                    if nbits == 1:
                        # Dead-end!
                        # Check if transition is possible in new_cell with orientation (direction+2)%4 in direction `direction'
                        directionMatch = directionMatch or self.env.rail.get_transition((new_cell[0], new_cell[1], (direction+2)%4), direction)

                if transitionValid and directionMatch:
                    new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1]], current_distance+1)
                    neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
                    self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance

        return neighbors

    def _new_position(self, position, movement):
        if movement == 0:    # NORTH
            return (position[0]-1, position[1])
        elif movement == 1:  # EAST
            return (position[0], position[1] + 1)
        elif movement == 2:  # SOUTH
            return (position[0]+1, position[1])
        elif movement == 3:  # WEST
            return (position[0], position[1] - 1)

126
127
128
129
130
131
132

    def get(self, handle):
        # TODO: compute the observation for agent `handle'
        #raise NotImplementedError()
        return []


133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325



"""

    def get_observation(self, agent):
        # Get the current observation for an agent
        current_position = self.internal_position[agent]
        #target_heading = self._compass(agent).tolist()
        coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
        agent_distance = self.distance_map[agent][coordinate][0]
        # Start tree search
        if current_position == self.target[agent]:
            agent_tree = Node(current_position, [-np.inf, -np.inf, -np.inf, -np.inf, -1])
        else:
            agent_tree = Node(current_position, [0, 0, 0, 0, agent_distance])

        initial_tree_state = Tree_State(agent, current_position, -1, 0, 0)
        self._tree_search(initial_tree_state, agent_tree, agent)
        observation = []
        distance_data = []

        self._flatten_tree(agent_tree, observation, distance_data,  self.max_depth+1)
        # This is probably very slow!!!!
        #max_obs = np.max([i for i in observation if i < np.inf])
        #if max_obs != 0:
        #    observation = np.array(observation)/ max_obs

        #print([i for i in distance_data if i >= 0])
        observation = np.concatenate((observation, distance_data))
        #observation = np.concatenate((observation, np.identity(5)[int(self.last_action[agent])]))
        #return np.clip(observation / self.max_dist[agent], -1, 1)
        return np.clip(observation / 15., -1, 1)




    def _tree_search(self, in_tree_state, parent_node, agent):
        if in_tree_state.depth >= self.max_depth:
            return
        target_distance = np.inf
        other_target = np.inf
        other_agent = np.inf
        coordinate = tuple(np.transpose(self._position_to_coordinate([in_tree_state.position])))
        curr_target_dist = self.distance_map[agent][coordinate][0]
        forbidden_action = (in_tree_state.direction + 2) % 4
        # Update the position
        failed_move = 0
        leaf_distance = in_tree_state.distance
        for child_idx in range(4):
            if child_idx != forbidden_action or in_tree_state.direction == -1:
                tree_state = copy.deepcopy(in_tree_state)
                tree_state.direction = child_idx
                current_position, invalid_move = self._detect_path(tree_state.position, tree_state.direction)
                if tree_state.initial_direction == None:
                    tree_state.initial_direction = child_idx
                if not invalid_move:
                    coordinate = tuple(np.transpose(self._position_to_coordinate([current_position])))
                    curr_target_dist = self.distance_map[agent][coordinate][0]
                    #if tree_state.initial_direction == None:
                    #    tree_state.initial_direction = child_idx
                    tree_state.position = current_position
                    tree_state.distance += 1


                    # Collect information at the current position
                    detection_distance = tree_state.distance
                    if current_position == self.target[tree_state.agent]:
                        target_distance = detection_distance

                    elif current_position in self.target:
                        other_target = detection_distance

                    if current_position in self.internal_position:
                        other_agent = detection_distance

                    tree_state.data[0] = self._min_greater_zero(target_distance, tree_state.data[0])
                    tree_state.data[1] = self._min_greater_zero(other_target, tree_state.data[1])
                    tree_state.data[2] = self._min_greater_zero(other_agent, tree_state.data[2])
                    tree_state.data[3] = tree_state.distance
                    tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])

                    if self._switch_detection(tree_state.position):
                        tree_state.depth += 1
                        new_tree_state = copy.deepcopy(tree_state)
                        new_node = parent_node.insert(tree_state.position, tree_state.data, tree_state.initial_direction)
                        new_tree_state.initial_direction = None
                        new_tree_state.data = [np.inf, np.inf, np.inf, np.inf, np.inf]
                        self._tree_search(new_tree_state, new_node, agent)
                    else:
                        self._tree_search(tree_state, parent_node, agent)
                else:
                    failed_move += 1
            if failed_move == 3 and in_tree_state.direction != -1:
                tree_state.data[4] = self._min_greater_zero(curr_target_dist, tree_state.data[4])
                parent_node.insert(tree_state.position, tree_state.data, tree_state.initial_direction)
                return
        return

    def _flatten_tree(self, node, observation_vector, distance_sensor, depth):
        if depth <= 0:
            return
        if node != None:
            observation_vector.extend(node.data[:-1])
            distance_sensor.extend([node.data[-1]])
        else:
            observation_vector.extend([-np.inf, -np.inf, -np.inf, -np.inf])
            distance_sensor.extend([-np.inf])
        for child_idx in range(4):
            if node != None:
                child = node.children[child_idx]
            else:
                child = None
            self._flatten_tree(child, observation_vector, distance_sensor,  depth -1)



    def _switch_detection(self, position):
        # Hack to detect switches
        # This can later directly be derived from the transition matrix
        paths = 0
        for i in range(4):
            _, invalid_move = self._detect_path(position, i)
            if not invalid_move:
                paths +=1
            if paths >= 3:
                return True
        return False




    def _min_greater_zero(self, x, y):
        if x <= 0 and y <= 0:
            return 0
        if x < 0:
            return y
        if y < 0:
            return x
        return min(x, y)



"""


class Tree_State:
    """
    Keep track of the current state while building the tree
    """
    def __init__(self, agent, position, direction, depth, distance):
        self.agent = agent
        self.position = position
        self.direction = direction
        self.depth = depth
        self.initial_direction = None
        self.distance = distance
        self.data = [np.inf, np.inf, np.inf, np.inf, np.inf]

class Node():
    """
    Define a tree node to get populated during search
    """
    def __init__(self, position, data):
        self.n_children = 4
        self.children = [None]*self.n_children
        self.data = data
        self.position = position

    def insert(self, position, data, child_idx):
        """
        Insert new node with data

        @param data node data object to insert
        """
        new_node = Node(position, data)
        self.children[child_idx] = new_node
        return new_node

    def print_tree(self, i=0, depth = 0):
        """
        Print tree content inorder
        """
        current_i = i
        curr_depth = depth+1
        if i < self.n_children:
            if self.children[i] != None:
                self.children[i].print_tree(depth=curr_depth)
            current_i += 1
            if self.children[i] != None:
                self.children[i].print_tree(i, depth=curr_depth)