env_utils.py 11.6 KB
Newer Older
hagrid67's avatar
hagrid67 committed
1
2
3
4
5
6
"""
Utility functions for the RailEnv environment and related level-generation functions
(direction helpers, A* rail connection, position/coordinate conversion, agent placement).

Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
7

8
import numpy as np
hagrid67's avatar
hagrid67 committed
9

10
11
from flatland.core.transitions import Grid4TransitionsEnum

12

hagrid67's avatar
hagrid67 committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def get_direction(pos1, pos2):
    """
    Return the direction (int) of the step from pos1 to pos2, usable with transitions.

    Both positions are assumed to be adjacent cells of the grid. A change in the
    first coordinate takes precedence over a change in the second one:

    - 0 if pos2[0] < pos1[0]
    - 2 if pos2[0] > pos1[0]
    - 1 if pos2[1] > pos1[1]
    - 3 if pos2[1] < pos1[1]

    Identical positions yield 0.
    """
    delta_first = pos2[0] - pos1[0]
    delta_second = pos2[1] - pos1[1]
    # (condition, direction) pairs, listed in order of precedence
    candidates = (
        (delta_first < 0, 0),
        (delta_first > 0, 2),
        (delta_second > 0, 1),
        (delta_second < 0, 3),
    )
    for matches, direction in candidates:
        if matches:
            return direction
    return 0


def mirror(dir):
    """Return the direction opposite to ``dir``: a half turn, taken modulo 4."""
    half_turn = dir + 2
    return half_turn % 4


def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos):
    """
    Check whether a path may be extended from current_pos to new_pos.

    The transitions that the step would require are added to a copy of the
    cell value at current_pos (rail_array itself is not written to) and the
    result is checked with ``rail_trans.is_valid``. When new_pos is the end of
    the path, the end cell is checked in the same way.

    :param rail_trans: object providing ``set_transition`` and ``is_valid``
    :param rail_array: grid of cell transition values, indexed by position
    :param prev_pos: position visited before current_pos, or None when
        current_pos is the first cell of the path
    :param current_pos: position the path currently ends at
    :param new_pos: candidate next position (assumed adjacent to current_pos)
    :param end_pos: final target position of the path
    :return: result of ``is_valid`` for the updated current cell, or False as
        soon as the end cell turns out to be invalid
    """
    heading_out = get_direction(current_pos, new_pos)
    at_path_start = prev_pos is None
    # without a previous cell, the incoming heading is taken to be the outgoing one
    heading_in = heading_out if at_path_start else get_direction(prev_pos, current_pos)

    cell = rail_array[current_pos]
    if at_path_start:
        # an empty start cell is an end point: the incoming direction has to be
        # flipped because of how end points are defined; otherwise the step
        # must fit the existing layout
        entry = mirror(heading_in) if cell == 0 else heading_in
        cell = rail_trans.set_transition(cell, entry, heading_out, 1)
    else:
        # forward path through the cell ...
        cell = rail_trans.set_transition(cell, heading_in, heading_out, 1)
        # ... and the same path travelled backwards
        cell = rail_trans.set_transition(cell, mirror(heading_out), mirror(heading_in), 1)

    if new_pos == end_pos:
        # the end cell has to accept the arriving path as well
        end_cell = rail_array[end_pos]
        # empty end cell: end point (flipped direction); otherwise run straight
        # into the existing rail
        leave = mirror(heading_out) if end_cell == 0 else heading_out
        end_cell = rail_trans.set_transition(end_cell, heading_out, leave, 1)
        if not rail_trans.is_valid(end_cell):
            return False

    return rail_trans.is_valid(cell)


74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def position_to_coordinate(width, position):
    """
    Convert flat cell indices into 2D coordinates.

    Inverse of ``coordinate_to_position``: a flat index ``p`` becomes the pair
    ``(p % width, p // width)``.

    :param width: stride of the flat index (the grid width)
    :param position: iterable of flat indices; each entry is passed through ``int()``
    :return: tuple of ``(p % width, p // width)`` pairs, in the order of ``position``
    """
    # build the tuple in one pass instead of concatenating a new tuple per entry
    return tuple((int(p) % width, int(p) // width) for p in position)


def coordinate_to_position(width, coords):
    """
    Convert 2D coordinates into flat cell indices.

    Inverse of ``position_to_coordinate``: a coordinate ``t`` becomes the flat
    index ``t[1] * width + t[0]``.

    :param width: stride of the flat index (the grid width)
    :param coords: iterable of coordinates, each indexable at 0 and 1
    :return: one-dimensional numpy integer array with one flat index per coordinate
    """
    return np.array([int(t[1] * width + t[0]) for t in coords], dtype=int)
100

u214892's avatar
u214892 committed
101

maljx's avatar
maljx committed
102
103
104
105
106
107
108
109
110
111
112
113
114
class AStarNode():
    """
    A node class for A* Pathfinding.

    A node is identified by its position only: two nodes with the same ``pos``
    compare equal and hash alike, whatever their parent or cost values are.
    """

    def __init__(self, parent=None, pos=None):
        # node this one was reached from (None for the first node of a path)
        self.parent = parent
        # position of the node; must be hashable, as it drives __hash__
        self.pos = pos
        # cost values; all start at 0 and are assigned by the search
        self.g = 0
        self.h = 0
        self.f = 0

    def __eq__(self, other):
        # let Python fall back to its default handling for foreign types
        # instead of failing on a missing `pos` attribute
        if not isinstance(other, AStarNode):
            return NotImplemented
        return self.pos == other.pos

    def __hash__(self):
        return hash(self.pos)

    def __repr__(self):
        return "{}(pos={!r}, g={!r}, h={!r}, f={!r})".format(
            type(self).__name__, self.pos, self.g, self.h, self.f)

    def update_if_better(self, other):
        """Take over parent and cost values of ``other`` if its ``g`` is strictly lower."""
        if other.g < self.g:
            self.parent = other.parent
            self.g = other.g
            self.h = other.h
            self.f = other.f


hagrid67's avatar
hagrid67 committed
126
127
128
129
130
131
132
133
def a_star(rail_trans, rail_array, start, end):
    """
    Search a path from start to end on the grid with A*.

    A step into a neighbouring cell is only considered if
    ``validate_new_transition`` accepts it for the path walked so far.

    :param rail_trans: transitions object handed on to ``validate_new_transition``
    :param rail_array: 2D grid of cell transition values; its ``shape`` bounds the search
    :param start: start position as a 2-tuple
    :param end: end position as a 2-tuple
    :return: list of positions from start to end (both included), or an empty
        list if no path is found
    """
    rail_shape = rail_array.shape
    start_node = AStarNode(None, start)
    end_node = AStarNode(None, end)
    open_nodes = {start_node}
    closed_nodes = set()

    while open_nodes:
        # get node with current shortest est. path (lowest f)
        current_node = min(open_nodes, key=lambda node: node.f)

        # pop current off open list, add to closed list
        open_nodes.remove(current_node)
        closed_nodes.add(current_node)

        # found the goal: walk the parent chain back to the start
        if current_node == end_node:
            path = []
            current = current_node
            while current is not None:
                path.append(current.pos)
                current = current.parent
            # return reversed path
            return path[::-1]

        if current_node.parent is not None:
            prev_pos = current_node.parent.pos
        else:
            prev_pos = None

        # generate and process the four neighbouring cells
        for offset in [(0, -1), (0, 1), (-1, 0), (1, 0)]:
            node_pos = (current_node.pos[0] + offset[0], current_node.pos[1] + offset[1])

            # stay inside the grid
            if not (0 <= node_pos[0] < rail_shape[0] and 0 <= node_pos[1] < rail_shape[1]):
                continue

            # validate positions
            if not validate_new_transition(rail_trans, rail_array, prev_pos, current_node.pos, node_pos, end_node.pos):
                continue

            child = AStarNode(current_node, node_pos)

            # already expanded, or already queued? A queued position keeps the
            # parent and cost it was first found with.
            if child in closed_nodes or child in open_nodes:
                continue

            # create the f, g, and h values
            child.g = current_node.g + 1
            # Manhattan distance avoids diagonal paths (a squared Euclidean
            # distance would favour them)
            child.h = abs(child.pos[0] - end_node.pos[0]) + abs(child.pos[1] - end_node.pos[1])
            child.f = child.g + child.h

            # add the child to the open list
            open_nodes.add(child)

    # no full path found
    return []
hagrid67's avatar
hagrid67 committed
205
206
207
208
209
210
211
212
213


def connect_rail(rail_trans, rail_array, start, end):
    """
    Creates a new path [start,end] in rail_array, based on rail_trans.

    The path is searched with ``a_star``; the transitions needed to travel
    along it are then written into ``rail_array`` in place.

    :param rail_trans: transitions object providing ``set_transition``
    :param rail_array: grid of cell transition values, modified in place
    :param start: start position of the new connection
    :param end: end position of the new connection
    :return: list of positions of the path, or an empty list if the search
        yields fewer than two positions (rail_array is left untouched then)
    """
    # in the worst case we will need to do a A* search, so we might as well set that up
    path = a_star(rail_trans, rail_array, start, end)
    if len(path) < 2:
        return []

    current_dir = get_direction(path[0], path[1])
    end_pos = path[-1]
    for index, (current_pos, new_pos) in enumerate(zip(path, path[1:])):
        new_dir = get_direction(current_pos, new_pos)

        new_trans = rail_array[current_pos]
        if index == 0:
            if new_trans == 0:
                # end-point
                # need to flip direction because of how end points are defined
                new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1)
            else:
                # into existing rail
                new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
        else:
            # set the forward path
            new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
            # set the backwards path
            new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
        rail_array[current_pos] = new_trans

        if new_pos == end_pos:
            # set up the end position
            new_trans_e = rail_array[end_pos]
            if new_trans_e == 0:
                # end-point
                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1)
            else:
                # into existing rail
                new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
            rail_array[end_pos] = new_trans_e

        # the outgoing direction is the incoming direction of the next cell
        current_dir = new_dir
    return path
hagrid67's avatar
hagrid67 committed
251
252
253
254


def distance_on_rail(pos1, pos2):
    """Return the Manhattan distance between pos1 and pos2 (first two coordinates)."""
    gap_first = abs(pos1[0] - pos2[0])
    gap_second = abs(pos1[1] - pos2[1])
    return gap_first + gap_second
255
256
257


def get_new_position(position, movement):
    """
    Utility function that converts a compass movement over a 2D grid to new positions (r, c).

    :param position: current position; indexed as (r, c)
    :param movement: value compared against the ``Grid4TransitionsEnum`` members
    :return: the neighbouring position as a tuple: NORTH decrements r, SOUTH
        increments r, EAST increments c, WEST decrements c. ``None`` if
        ``movement`` matches none of the four members.
    """
    if movement == Grid4TransitionsEnum.NORTH:
        return (position[0] - 1, position[1])
    if movement == Grid4TransitionsEnum.EAST:
        return (position[0], position[1] + 1)
    if movement == Grid4TransitionsEnum.SOUTH:
        return (position[0] + 1, position[1])
    if movement == Grid4TransitionsEnum.WEST:
        return (position[0], position[1] - 1)
    # unknown movement: kept as None for backward compatibility
    return None


def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
    """
    Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).

    Start and target cells are drawn (with numpy's global random state) from
    all cells whose transition value is greater than 0. For every agent a
    start direction is then picked at random among the directions from which
    its target can be reached. If some agent has no such direction, all
    positions and targets are drawn again.

    TODO: add extensive documentation, as users may need this function to simplify their custom level generators.

    :param rail: object providing ``height``, ``width`` and ``get_transitions``
    :param num_agents: number of agents to place
    :return: tuple ``(agents_position, agents_direction, agents_target)``, each
        a list with one entry per agent
    """

    def _path_exists(rail, start, direction, end):
        # Depth-first search (explicit stack) over (position, heading) states:
        # check if a path exists between the 2 nodes
        visited = set()
        stack = [(start, direction)]
        while stack:
            node = stack.pop()
            position, heading = node
            if position[0] == end[0] and position[1] == end[1]:
                return 1
            if node in visited:
                continue
            visited.add(node)

            moves = rail.get_transitions((position[0], position[1], heading))
            for move_index in range(4):
                if moves[move_index]:
                    # after the move, the heading is the direction moved in
                    stack.append((get_new_position(position, move_index), move_index))

            # If cell is a dead-end (exactly one transition bit set), append
            # the same cell with reversed orientation!
            nbits = 0
            tmp = rail.get_transitions((position[0], position[1]))
            while tmp > 0:
                nbits += (tmp & 1)
                tmp = tmp >> 1
            if nbits == 1:
                stack.append((position, (heading + 2) % 4))

        return 0

    # every cell with at least one transition can hold an agent or a target
    valid_positions = [
        (r, c)
        for r in range(rail.height)
        for c in range(rail.width)
        if rail.get_transitions((r, c)) > 0
    ]

    re_generate = True
    while re_generate:
        agents_position = [
            valid_positions[i] for i in
            np.random.choice(len(valid_positions), num_agents)]
        agents_target = [
            valid_positions[i] for i in
            np.random.choice(len(valid_positions), num_agents)]

        # agents_direction must be a direction for which a solution is
        # guaranteed.
        agents_direction = [0] * num_agents
        re_generate = False
        for i in range(num_agents):
            position = agents_position[i]

            # all (start direction, first move) pairs allowed by the start cell
            valid_movements = []
            for direction in range(4):
                moves = rail.get_transitions((position[0], position[1], direction))
                for move_index in range(4):
                    if moves[move_index]:
                        valid_movements.append((direction, move_index))

            valid_starting_directions = []
            for start_direction, first_move in valid_movements:
                if start_direction in valid_starting_directions:
                    continue
                new_position = get_new_position(position, first_move)
                # once the first move is done the agent is heading in the
                # direction of that move, so the search starts with that heading
                if _path_exists(rail, new_position, first_move, agents_target[i]):
                    valid_starting_directions.append(start_direction)

            if len(valid_starting_directions) == 0:
                # no break here: the remaining agents are still processed
                # before everything is drawn again
                re_generate = True
            else:
                agents_direction[i] = valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]

    return agents_position, agents_direction, agents_target