env_utils.py 11.3 KB
Newer Older
hagrid67's avatar
hagrid67 committed
1
2
3
4
5
6
"""
Definition of the RailEnv environment and related level-generation functions.

Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
7

8
import numpy as np
hagrid67's avatar
hagrid67 committed
9

10

hagrid67's avatar
hagrid67 committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def get_direction(pos1, pos2):
    """
    Return the direction (int) of the step from pos1 to pos2.

    The positions are expected to be adjacent grid locations. The first
    coordinate is inspected before the second one:

    * 0 if the first coordinate decreases,
    * 2 if the first coordinate increases,
    * 1 if the second coordinate increases,
    * 3 if the second coordinate decreases,
    * 0 if both positions are identical.

    The result can be used with transitions.
    """
    delta_first = pos2[0] - pos1[0]
    delta_second = pos2[1] - pos1[1]

    if delta_first < 0:
        direction = 0
    elif delta_first > 0:
        direction = 2
    elif delta_second > 0:
        direction = 1
    elif delta_second < 0:
        direction = 3
    else:
        # identical positions fall back to direction 0
        direction = 0
    return direction


def mirror(dir):
    """Return the direction opposite to `dir` (shifted by two, modulo four)."""
    flipped = dir + 2
    return flipped % 4


def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos):
    """
    Check whether stepping from current_pos to new_pos yields valid cells.

    The transition that the step would add to the cell at current_pos is
    built with `rail_trans.set_transition` and checked with
    `rail_trans.is_valid`. When new_pos is the end position, the cell at
    end_pos is checked the same way first.

    :param rail_trans: object providing `set_transition` and `is_valid`
    :param rail_array: grid of cell transitions, indexed by position
    :param prev_pos: position visited before current_pos, or None at the start
    :param current_pos: position the step starts from
    :param new_pos: candidate position to step to
    :param end_pos: final position of the path being built
    :return: result of the validity checks (False if the end cell is invalid)
    """
    # direction of the candidate step, and direction used to reach current_pos
    outgoing = get_direction(current_pos, new_pos)
    at_start = prev_pos is None
    incoming = outgoing if at_start else get_direction(prev_pos, current_pos)

    cell = rail_array[current_pos]
    if at_start:
        # an empty start cell is an end point: its entry direction is flipped
        # because of how end points are defined; otherwise match the layout
        entry = mirror(incoming) if cell == 0 else incoming
        cell = rail_trans.set_transition(cell, entry, outgoing, 1)
    else:
        # forward path, then the backwards path
        cell = rail_trans.set_transition(cell, incoming, outgoing, 1)
        cell = rail_trans.set_transition(cell, mirror(outgoing), mirror(incoming), 1)

    if new_pos == end_pos:
        # the end position has to be validated as well
        end_cell = rail_array[end_pos]
        # empty end cell: flipped exit direction (end point); otherwise match layout
        exit_dir = mirror(outgoing) if end_cell == 0 else outgoing
        end_cell = rail_trans.set_transition(end_cell, outgoing, exit_dir, 1)
        if not rail_trans.is_valid(end_cell):
            return False

    return rail_trans.is_valid(cell)


72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def position_to_coordinate(width, position):
    """
    Convert flat positions into coordinate pairs.

    :param width: divisor used to split each flat position
    :param position: iterable of flat positions (each is passed through `int`)
    :return: tuple of `(p % width, p // width)` pairs, one per input position
    """
    return tuple((int(p) % width, int(p) // width) for p in position)


def coordinate_to_position(width, coords):
    """
    Convert coordinate pairs into flat positions.

    :param width: multiplier applied to the second component of each pair
    :param coords: iterable of indexable pairs
    :return: numpy array holding `pair[1] * width + pair[0]` for each pair
    """
    flat = [pair[1] * width + pair[0] for pair in coords]
    return np.array(flat)

u214892's avatar
u214892 committed
97

maljx's avatar
maljx committed
98
99
100
101
102
103
104
105
106
107
108
109
110
class AStarNode():
    """
    A node class for A* Pathfinding.

    Two nodes are equal (and hash alike) when they hold the same `pos`,
    regardless of their parent or cost values, so that a set of nodes holds at
    most one node per position.
    """

    def __init__(self, parent=None, pos=None):
        # node this one was reached from (None for a root node)
        self.parent = parent
        # position of the node; must be hashable, see __hash__
        self.pos = pos
        # cost values, all starting at 0; f is expected to hold g + h
        self.g = 0
        self.h = 0
        self.f = 0

    def __eq__(self, other):
        # Defer to the other operand instead of failing with AttributeError
        # when compared against something that is not a node.
        if not isinstance(other, AStarNode):
            return NotImplemented
        return self.pos == other.pos

    def __hash__(self):
        return hash(self.pos)

    def __repr__(self):
        return "{}(pos={!r}, g={!r}, h={!r}, f={!r})".format(
            type(self).__name__, self.pos, self.g, self.h, self.f)

    def update_if_better(self, other):
        """Take over parent and cost values of `other` if its g is strictly lower."""
        if other.g < self.g:
            self.parent = other.parent
            self.g = other.g
            self.h = other.h
            self.f = other.f


hagrid67's avatar
hagrid67 committed
122
123
124
125
126
127
128
129
def a_star(rail_trans, rail_array, start, end):
    """
    Search a path from start to end with A*.

    The four grid neighbours of a node are considered; a step is only taken
    when it stays inside `rail_array.shape` and `validate_new_transition`
    accepts it. Every step costs 1 and the heuristic is the Manhattan distance
    to `end`.

    :param rail_trans: forwarded to `validate_new_transition`
    :param rail_array: grid; its `shape` bounds the search, the array itself is
        forwarded to `validate_new_transition`
    :param start: start position (hashable pair)
    :param end: end position (hashable pair)
    :return: list of positions from start to end (both included), or an empty
        list if the search runs out of nodes before reaching `end`
    """
    grid_shape = rail_array.shape
    goal = AStarNode(None, end)
    frontier = {AStarNode(None, start)}
    expanded = set()

    while frontier:
        # node with the current shortest estimated path (lowest f)
        best = min(frontier, key=lambda node: node.f)

        # move it from the open set to the closed set
        frontier.remove(best)
        expanded.add(best)

        if best == goal:
            # goal reached: walk the parent links back to the start
            route = []
            step = best
            while step is not None:
                route.append(step.pos)
                step = step.parent
            route.reverse()
            return route

        came_from = None if best.parent is None else best.parent.pos

        # collect the neighbours that are inside the grid and give a valid transition
        successors = []
        for d_first, d_second in ((0, -1), (0, 1), (-1, 0), (1, 0)):
            candidate = (best.pos[0] + d_first, best.pos[1] + d_second)
            inside = 0 <= candidate[0] < grid_shape[0] and 0 <= candidate[1] < grid_shape[1]
            if not inside:
                continue
            if validate_new_transition(rail_trans, rail_array, came_from, best.pos, candidate, goal.pos):
                successors.append(AStarNode(best, candidate))

        for child in successors:
            if child in expanded:
                continue

            child.g = best.g + 1
            # Manhattan distance avoids diagonal paths (a squared Euclidean
            # distance would favour them)
            child.h = abs(child.pos[0] - goal.pos[0]) + abs(child.pos[1] - goal.pos[1])
            child.f = child.g + child.h

            # a position already waiting in the open set keeps its existing node
            if child not in frontier:
                frontier.add(child)

    # open set exhausted: no full path found
    return []
hagrid67's avatar
hagrid67 committed
201
202
203
204
205
206
207
208
209


def connect_rail(rail_trans, rail_array, start, end):
    """
    Creates a new path [start,end] in rail_array, based on rail_trans.

    The path is obtained from `a_star`; the transitions of every cell along
    it are then written into `rail_array` in place.

    :param rail_trans: object providing `set_transition`
    :param rail_array: grid of cell transitions, modified in place
    :param start: start position
    :param end: end position
    :return: the path as a list of positions, or an empty list when the search
        yields fewer than two positions
    """
    path = a_star(rail_trans, rail_array, start, end)
    if len(path) < 2:
        return []

    last_pos = path[-1]
    incoming = get_direction(path[0], path[1])
    for step, (here, there) in enumerate(zip(path, path[1:])):
        outgoing = get_direction(here, there)

        cell = rail_array[here]
        if step > 0:
            # forward path, then the backwards path
            cell = rail_trans.set_transition(cell, incoming, outgoing, 1)
            cell = rail_trans.set_transition(cell, mirror(outgoing), mirror(incoming), 1)
        elif cell == 0:
            # first cell is empty: end-point, direction is flipped because of
            # how end points are defined
            cell = rail_trans.set_transition(cell, mirror(incoming), outgoing, 1)
        else:
            # first cell joins an existing rail
            cell = rail_trans.set_transition(cell, incoming, outgoing, 1)
        rail_array[here] = cell

        if there == last_pos:
            # set up the final cell as well
            end_cell = rail_array[last_pos]
            # empty cell: end-point (flipped exit); otherwise into existing rail
            exit_dir = mirror(outgoing) if end_cell == 0 else outgoing
            rail_array[last_pos] = rail_trans.set_transition(end_cell, outgoing, exit_dir, 1)

        incoming = outgoing
    return path
hagrid67's avatar
hagrid67 committed
247
248
249
250


def distance_on_rail(pos1, pos2):
    """Return the Manhattan distance between two positions."""
    gap_first = abs(pos1[0] - pos2[0])
    gap_second = abs(pos1[1] - pos2[1])
    return gap_first + gap_second
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330


def get_new_position(position, movement):
    """
    Return the position reached from `position` by one step of `movement`.

    :param position: indexable pair
    :param movement: 0 (NORTH), 1 (EAST), 2 (SOUTH) or 3 (WEST)
    :return: the new position as a tuple, or None for any other movement
    """
    if movement == 0:  # NORTH
        first, second = position[0] - 1, position[1]
    elif movement == 1:  # EAST
        first, second = position[0], position[1] + 1
    elif movement == 2:  # SOUTH
        first, second = position[0] + 1, position[1]
    elif movement == 3:  # WEST
        first, second = position[0], position[1] - 1
    else:
        return None
    return (first, second)


def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents):
    """
    Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target).

    Positions and targets are drawn with `np.random.choice` among the cells
    whose transitions are non-zero. The whole draw is repeated until every
    agent has at least one starting direction from which its target is reachable.

    TODO: add extensive documentation, as users may need this function to simplify their custom level generators.

    :param rail: object providing `height`, `width` and `get_transitions`
    :param num_agents: number of agents to place
    :return: three lists (positions, directions, targets), one entry per agent
    """

    def _path_exists(rail, start, direction, end):
        # Depth-first search (nodes are popped from the end of the list):
        # returns 1 if `end` can be reached from (start, direction), else 0.
        seen = set()
        pending = [(start, direction)]
        while pending:
            state = pending.pop()
            cell, heading = state
            if cell[0] == end[0] and cell[1] == end[1]:
                return 1
            if state in seen:
                continue
            seen.add(state)

            transitions = rail.get_transitions((cell[0], cell[1], heading))
            for move_index in range(4):
                if transitions[move_index]:
                    pending.append((get_new_position(cell, move_index), move_index))

            # If cell is a dead-end (exactly one bit set in its transitions),
            # append the same cell with reversed orientation!
            set_bits = 0
            remaining = rail.get_transitions((cell[0], cell[1]))
            while remaining > 0:
                set_bits += (remaining & 1)
                remaining = remaining >> 1
            if set_bits == 1:
                pending.append((cell, (heading + 2) % 4))

        return 0

    # every cell with at least one transition can host an agent or a target
    valid_positions = [
        (r, c)
        for r in range(rail.height)
        for c in range(rail.width)
        if rail.get_transitions((r, c)) > 0]

    while True:
        agents_position = [
            valid_positions[i] for i in
            np.random.choice(len(valid_positions), num_agents)]
        agents_target = [
            valid_positions[i] for i in
            np.random.choice(len(valid_positions), num_agents)]

        # agents_direction must be a direction for which a solution is
        # guaranteed.
        agents_direction = [0] * num_agents
        all_agents_placed = True
        for agent, (origin, target) in enumerate(zip(agents_position, agents_target)):
            # every (direction, move) pair allowed from the agent's cell
            candidate_moves = []
            for heading in range(4):
                transitions = rail.get_transitions((origin[0], origin[1], heading))
                candidate_moves.extend(
                    (heading, move_index) for move_index in range(4) if transitions[move_index])

            # keep each direction once, if the target is reachable after one of its moves
            usable_directions = []
            for heading, move_index in candidate_moves:
                next_position = get_new_position(origin, move_index)
                if heading not in usable_directions and _path_exists(rail, next_position, heading, target):
                    usable_directions.append(heading)

            if len(usable_directions) == 0:
                # remaining agents are still processed; the draw is redone afterwards
                all_agents_placed = False
            else:
                agents_direction[agent] = usable_directions[np.random.choice(len(usable_directions), 1)[0]]

        if all_agents_placed:
            break

    return agents_position, agents_direction, agents_target