rail_env.py 46.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
"""
Definition of the RailEnv environment and related level-generation functions.

Generator functions are functions that take width, height and num_resets as arguments and return
a GridTransitionMap object.
"""
import numpy as np

from flatland.core.env import Environment
from flatland.core.env_observation_builder import TreeObsForRailEnv

spiglerg's avatar
spiglerg committed
12
from flatland.core.transitions import Grid8Transitions, RailEnvTransitions
13
14
15
from flatland.core.transition_map import GridTransitionMap


maljx's avatar
maljx committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class AStarNode():
    """A node class for A* Pathfinding.

    Nodes compare equal -- and hash -- by grid position only, so two nodes
    that reach the same cell through different parents represent the same
    search state.
    """

    def __init__(self, parent=None, pos=None):
        self.parent = parent  # predecessor node on the best known path (None for the start)
        self.pos = pos        # grid position of this node; must be hashable (a tuple)
        self.g = 0            # cost accumulated from the start node
        self.h = 0            # heuristic estimate of the remaining cost
        self.f = 0            # g + h, used to pick the next node to expand

    def __eq__(self, other):
        if not isinstance(other, AStarNode):
            return NotImplemented
        return self.pos == other.pos

    def __hash__(self):
        # Defining __eq__ alone implicitly sets __hash__ to None, which made
        # nodes unusable in sets/dicts; hash consistently with equality.
        return hash(self.pos)

    def __repr__(self):
        return "{}(pos={!r}, g={}, h={}, f={})".format(type(self).__name__, self.pos, self.g, self.h, self.f)

    def update_if_better(self, other):
        """Adopt `other`'s parent and costs if `other` reaches this cell more cheaply (lower g)."""
        if other.g < self.g:
            self.parent = other.parent
            self.g = other.g
            self.h = other.h
            self.f = other.f


37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def get_direction(pos1, pos2):
    """
    Return the direction index (int) of a single step from pos1 to pos2.

    pos1 and pos2 are assumed to be adjacent grid cells. The first
    coordinate is checked before the second:
        0 -> first coordinate decreases
        2 -> first coordinate increases
        1 -> second coordinate increases
        3 -> second coordinate decreases
    Identical positions also yield 0. The result is meant to be used with
    the rail transition helpers.
    """
    step_first = pos2[0] - pos1[0]
    step_second = pos2[1] - pos1[1]
    checks = (
        (step_first < 0, 0),
        (step_first > 0, 2),
        (step_second > 0, 1),
        (step_second < 0, 3),
    )
    for matched, direction in checks:
        if matched:
            return direction
    return 0


def mirror(dir):
    """Return the direction index opposite to `dir` (0 <-> 2, 1 <-> 3)."""
    half_turn = 2
    opposite = dir + half_turn
    return opposite % 4


def validate_new_transition(rail_trans, rail_array, prev_pos, current_pos, new_pos, end_pos):
    """
    Check whether a rail step from `current_pos` to `new_pos` keeps cells valid.

    The transition(s) needed for the step are added (via
    `rail_trans.set_transition`) to the cell value read from `rail_array`;
    the array itself is not written.

    - With `prev_pos` given, both the forward move (incoming heading ->
      step direction) and its reverse are added to the current cell.
    - Without `prev_pos` (first cell of a path), an empty cell gets a
      transition from the mirrored step direction, because of how end points
      are defined; an occupied cell gets the straight step transition.
    - If `new_pos` is `end_pos`, the end cell is checked too: an empty end cell
      gets the turn-around transition (step dir -> mirrored step dir), and an
      occupied one gets the straight step transition.

    Returns True only if `rail_trans.is_valid` accepts every modified cell.
    """
    step_dir = get_direction(current_pos, new_pos)
    # With no previous cell the incoming heading is taken to be the step direction.
    in_dir = step_dir if prev_pos is None else get_direction(prev_pos, current_pos)

    cell = rail_array[current_pos]
    if prev_pos is not None:
        # interior cell: forward path and backwards path
        cell = rail_trans.set_transition(cell, in_dir, step_dir, 1)
        cell = rail_trans.set_transition(cell, mirror(step_dir), mirror(in_dir), 1)
    elif cell == 0:
        # new end point: flip the direction because of how end points are defined
        cell = rail_trans.set_transition(cell, mirror(in_dir), step_dir, 1)
    else:
        # starting on existing rail: must match the existing layout
        cell = rail_trans.set_transition(cell, in_dir, step_dir, 1)

    if new_pos == end_pos:
        end_cell = rail_array[end_pos]
        if end_cell == 0:
            end_cell = rail_trans.set_transition(end_cell, step_dir, mirror(step_dir), 1)
        else:
            end_cell = rail_trans.set_transition(end_cell, step_dir, step_dir, 1)
        if not rail_trans.is_valid(end_cell):
            return False

    return rail_trans.is_valid(cell)


def a_star(rail_trans, rail_array, start, end):
    """
    Search for a rail path from `start` to `end` with A*.

    Candidate steps go to the 4 orthogonal neighbours that lie inside
    `rail_array` and are accepted by `validate_new_transition`. The search
    therefore only follows steps whose resulting cell transitions
    `rail_trans` reports as valid. Each step costs 1, and the heuristic is the
    Manhattan distance to `end`.

    Parameters
    -------
    rail_trans : transitions object
        Used through `validate_new_transition` (`set_transition`, `is_valid`).
    rail_array : numpy.ndarray
        2-D grid of cell transitions; this function only reads it.
    start, end : tuple
        Grid positions (hashable; used as indices into `rail_array`).

    Returns
    -------
    list of tuples
        The positions from `start` to `end` (inclusive) if `end` is reached.
        Otherwise, once the open list is exhausted, the path from `start` to
        the last expanded node is printed and returned. That node is not
        necessarily the reachable cell closest to `end`.
    """
    rail_shape = rail_array.shape
    start_node = AStarNode(None, start)
    end_node = AStarNode(None, end)

    # The open list keeps insertion order so that ties on `f` are broken as
    # before (the first node with the lowest f wins). The dict and set index
    # the open/closed nodes by position, replacing linear list scans with
    # O(1) membership tests.
    open_list = [start_node]
    open_by_pos = {start_node.pos: start_node}
    closed_positions = set()

    def path_to(node):
        """Return the positions from the start node to `node`, following parents."""
        path = []
        while node is not None:
            path.append(node.pos)
            node = node.parent
        return path[::-1]

    current_node = start_node
    while open_list:
        # expand the node with the lowest estimated total cost f
        current_index = min(range(len(open_list)), key=lambda i: open_list[i].f)
        current_node = open_list.pop(current_index)
        del open_by_pos[current_node.pos]
        closed_positions.add(current_node.pos)

        if current_node == end_node:
            return path_to(current_node)

        # the step that led here constrains which transitions are valid next
        prev_pos = current_node.parent.pos if current_node.parent is not None else None

        for d_row, d_col in ((0, -1), (0, 1), (-1, 0), (1, 0)):
            node_pos = (current_node.pos[0] + d_row, current_node.pos[1] + d_col)
            if not (0 <= node_pos[0] < rail_shape[0] and 0 <= node_pos[1] < rail_shape[1]):
                continue
            if node_pos in closed_positions:
                continue
            if not validate_new_transition(rail_trans, rail_array, prev_pos,
                                           current_node.pos, node_pos, end_node.pos):
                continue

            child = AStarNode(current_node, node_pos)
            child.g = current_node.g + 1
            # Manhattan distance: this heuristic avoids diagonal-looking paths
            child.h = abs(node_pos[0] - end_node.pos[0]) + abs(node_pos[1] - end_node.pos[1])
            child.f = child.g + child.h

            open_node = open_by_pos.get(node_pos)
            if open_node is not None:
                open_node.update_if_better(child)
                continue
            open_list.append(child)
            open_by_pos[node_pos] = child

    # no full path found: return the partial path to the last expanded node
    partial = path_to(current_node)
    print("partial:", start, end, partial)
    return partial


222
223
224
225
226
227
228
229
230
231
232
233
234
def connect_rail(rail_trans, rail_array, start, end):
    """
    Lay rail in `rail_array` (in place) along the A* path from `start` to `end`.

    The path comes from `a_star`. When `end` cannot be reached, that is a
    partial path. Paths shorter than two cells leave the array unchanged.
    - The first cell gets, if empty, an end-point transition from the mirrored
      step direction; otherwise it gets the straight step transition.
    - Every following cell gets the forward and backward transitions for the
      turn made there.
    - The final cell gets, if empty, the turn-around transition; otherwise it
      gets the straight step transition.
    Returns None.
    """
    path = a_star(rail_trans, rail_array, start, end)
    if len(path) < 2:
        return
    heading = get_direction(path[0], path[1])
    last_pos = path[-1]

    for step, (here, there) in enumerate(zip(path[:-1], path[1:])):
        step_dir = get_direction(here, there)
        cell = rail_array[here]
        if step > 0:
            # forward path, then backwards path
            cell = rail_trans.set_transition(cell, heading, step_dir, 1)
            cell = rail_trans.set_transition(cell, mirror(step_dir), mirror(heading), 1)
        elif cell == 0:
            # new end-point: flip the direction because of how end points are defined
            cell = rail_trans.set_transition(cell, mirror(heading), step_dir, 1)
        else:
            # joining existing rail
            cell = rail_trans.set_transition(cell, heading, step_dir, 1)
        rail_array[here] = cell

        if there == last_pos:
            end_cell = rail_array[last_pos]
            if end_cell == 0:
                end_cell = rail_trans.set_transition(end_cell, step_dir, mirror(step_dir), 1)
            else:
                end_cell = rail_trans.set_transition(end_cell, step_dir, step_dir, 1)
            rail_array[last_pos] = end_cell

        heading = step_dir


271
272
273
274
275
def distance_on_rail(pos1, pos2):
    """Return the Manhattan (L1) distance between two 2-D grid positions."""
    delta_first = pos1[0] - pos2[0]
    delta_second = pos1[1] - pos2[1]
    return abs(delta_first) + abs(delta_second)


def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0):
    """
    Build a rail generator that connects random [start, goal] pairs with A*.

    Parameters
    -------
    nr_start_goal : int
        Number of [start, goal] pairs to place and connect. A pair for which
        no acceptable positions are found within 9000 random draws is skipped.
    min_dist, max_dist : int
        Inclusive range allowed for the Manhattan distance between a start
        and its goal.
    seed : int
        Base seed: each call of the returned generator reseeds numpy's global
        RNG with `seed + num_resets`.

    Returns
    -------
    function
        Generator function (width, height, num_resets=0) returning a
        GridTransitionMap whose grid is a numpy.uint16 array of shape
        (height, width).
    """

    def generator(width, height, num_resets=0):
        rail_trans = RailEnvTransitions()
        # Rows span the height and columns the width, the same layout that
        # random_rail_generator builds (`height` rows of `width` cells). This
        # array used to be (width, height), which indexed out of range or
        # left cells unused on non-square maps; square maps are unaffected.
        rail_array = np.zeros(shape=(height, width), dtype=np.uint16)

        np.random.seed(seed + num_resets)

        # generate rail array
        # step 1:
        # - generate a list of start and goal positions
        # - use a min/max distance allowed as input for this
        # - validate that start/goals are not placed too close to other start/goals
        #
        # step 2: (optional)
        # - place random elements on rails array
        #   - for instance "train station", etc.
        #
        # step 3:
        # - iterate over all [start, goal] pairs:
        #   - [first X pairs]
        #     - draw a rail from [start,goal]
        #     - if rail crosses existing rail then validate new connection
        #       - possibility that this fails to create a path to goal
        #   - [after X pairs]
        #     - find closest rail from start (Pa) and connect [start, Pa]
        #     - do A* from Pa to find the point on rail (Pb) closest to goal
        #     - connect [Pb, goal]
        #
        # step 4: (optional)
        # - add more rails to map randomly
        #
        # step 5:
        # - return transition map + list of [start, goal] points

        start_goal = []

        def check_all_dist(sg_new):
            """Return False if any point of sg_new is closer than 2 cells to an already placed point."""
            return all(distance_on_rail(p_new, p_old) >= 2
                       for sg in start_goal
                       for p_new in sg_new
                       for p_old in sg)

        sanity_max = 9000
        for _ in range(nr_start_goal):
            for _ in range(sanity_max):
                start = (np.random.randint(0, height), np.random.randint(0, width))
                goal = (np.random.randint(0, height), np.random.randint(0, width))
                # start and goal must both be on empty cells
                if rail_array[goal] != 0 or rail_array[start] != 0:
                    continue
                # check min/max distance
                dist_sg = distance_on_rail(start, goal)
                if dist_sg < min_dist or dist_sg > max_dist:
                    continue
                # keep away from previously placed start/goal points
                if check_all_dist([start, goal]):
                    break
            else:
                # No acceptable pair within sanity_max draws: skip it instead of
                # connecting the last (rejected) candidate.
                continue
            start_goal.append([start, goal])
            connect_rail(rail_trans, rail_array, start, goal)

        print("Created #", len(start_goal), "pairs")

        return_rail = GridTransitionMap(width=width, height=height, transitions=rail_trans)
        return_rail.grid = rail_array
        # TODO: return start_goal
        return return_rail

    return generator


379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
def rail_from_manual_specifications_generator(rail_spec):
    """
    Build a rail generator from a hand-written map of (cell_type, rotation) tuples.

    Parameters
    -------
    rail_spec : list of list of tuples
        Rows of cells; each cell is (cell_type, rotation), where cell_type
        indexes RailEnvTransitions().transitions and rotation is clockwise,
        one of 0, 90, 180, 270.

    Returns
    -------
    function
        Generator function (width, height, num_resets=0). Its width/height
        arguments are ignored: the map has len(rail_spec) rows and
        len(rail_spec[0]) columns. It returns the GridTransitionMap, or
        prints an error and returns [] when a cell type is out of range.
    """

    def generator(width, height, num_resets=0):
        t_utils = RailEnvTransitions()

        n_rows = len(rail_spec)
        n_cols = len(rail_spec[0])
        rail = GridTransitionMap(width=n_cols, height=n_rows, transitions=t_utils)

        for r in range(n_rows):
            row_spec = rail_spec[r]
            for c in range(n_cols):
                cell_spec = row_spec[c]
                cell_type = cell_spec[0]
                if cell_type < 0 or cell_type >= len(t_utils.transitions):
                    print("ERROR - invalid cell type=", cell_type)
                    return []
                base_transition = t_utils.transitions[cell_type]
                rail.set_transitions((r, c), t_utils.rotate_transition(base_transition, cell_spec[1]))

        return rail

    return generator


def rail_from_GridTransitionMap_generator(rail_map):
    """
    Wrap an existing GridTransitionMap (with the correct 16-bit transition
    specifications) in a rail generator.

    Parameters
    -------
    rail_map : GridTransitionMap object
        The object to hand back whenever the generator is called.

    Returns
    -------
    function
        Generator function (width, height, num_resets=0) that ignores its
        arguments and always returns the same `rail_map` object.
    """

    def _same_rail_map(width, height, num_resets=0):
        # size and reset count are irrelevant: the map is fixed
        return rail_map

    return _same_rail_map


441
442
443
444
445
446
447
448
449
450
451
452
453
454
def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
    """
    Build a rail generator that cycles through saved GridTransitionMap files,
    loading one per environment reset.

    Parameters
    -------
    list_of_filenames : list
        Filenames of the saved grids to load.

    Returns
    -------
    function
        Generator function (width, height, num_resets=0) that loads
        list_of_filenames[num_resets % len(list_of_filenames)] into a new
        GridTransitionMap (loaded with override_gridsize=False) and returns it.
        If the loaded grid's dtype is numpy.uint64, its transitions are
        switched to Grid8Transitions.
    """

    def generator(width, height, num_resets=0):
        rail_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions())
        file_index = num_resets % len(list_of_filenames)
        rail_map.load_transition_map(list_of_filenames[file_index], override_gridsize=False)

        # uint64 grids are paired with the Grid8Transitions transition set
        if rail_map.grid.dtype == np.uint64:
            rail_map.transitions = Grid8Transitions()

        return rail_map

    return generator


469
470
471
472
473
474
475
476
477
"""
def generate_rail_from_list_of_manual_specifications(list_of_specifications)
    def generator(width, height, num_resets=0):
        return generate_rail_from_manual_specifications(list_of_specifications)

    return generator
"""


spiglerg's avatar
spiglerg committed
478
def random_rail_generator(cell_type_relative_proportion=(1.0,) * 8):
    """
    Dummy random level generator:
    - fill in cells at random in [width-2, height-2]
    - keep filling cells in among the unfilled ones, such that all transitions
      are legit; if no cell can be filled in without violating some
      transitions, pick one among those that can satisfy most transitions
      (1, 2, 3 or 4), and delete (+mark to be re-filled) the cells that were
      incompatible.
    - keep trying for a total number of insertions
      ((W-2)*(H-2)*10); if cells are still empty after that, empty the
      board and try again from scratch (up to 10 attempts).
    - finally pad the border of the map with dead-ends to avoid border issues.

    Dead-ends are not allowed inside the grid, only at the border; however, if
    no cell type can be inserted in a given cell (because of the neighboring
    transitions), dead-ends are allowed if they solve the problem. This was
    found to turn most un-generatable levels into valid ones.

    Parameters
    -------
    cell_type_relative_proportion : sequence of float
        Relative sampling weight of each non-dead-end cell type, indexed like
        `RailEnvTransitions().transitions`; a type's weight applies to all four
        of its rotations. The default (a tuple, so it cannot be mutated across
        calls) weights all types equally.

    Returns
    -------
    function
        Generator function (width, height, num_resets=0) returning a
        GridTransitionMap whose grid is a numpy.uint16 array of `height` rows
        by `width` columns. `num_resets` is unused; all randomness comes from
        numpy's global RNG.
    """
    # Neighbour table indexed N, E, S, W:
    # (side, direction from that neighbour back into the cell, (d_row, d_col)).
    neighbours = ((0, 2, (-1, 0)),
                  (1, 3, (0, 1)),
                  (2, 0, (1, 0)),
                  (3, 1, (0, -1)))
    # Rotation applied to the base dead-end when its only connected neighbour is on side k.
    dead_end_rotation = {0: 180, 1: 270, 2: 0, 3: 90}

    def generator(width, height, num_resets=0):
        t_utils = RailEnvTransitions()
        transition_probability = cell_type_relative_proportion
        dead_end = int('0010000000000000', 2)
        empty_cell = int('0000000000000000', 2)

        # One (template, transition) entry per rotation of every non-dead-end
        # cell type. A template is the 4-bit OR of `get_transitions` over all
        # four incoming directions, stored as a list of 0/1.
        transitions_templates_ = []
        transition_probabilities = []
        for i in range(len(t_utils.transitions) - 4):  # don't include dead-ends
            all_transitions = 0
            for dir_ in range(4):
                trans = t_utils.get_transitions(t_utils.transitions[i], dir_)
                all_transitions |= (trans[0] << 3) | (trans[1] << 2) | (trans[2] << 1) | trans[3]

            template = [int(x) for x in bin(all_transitions)[2:]]
            template = [0] * (4 - len(template)) + template

            # add all rotations; the template is shifted by one side per 90 degrees
            for rot in [0, 90, 180, 270]:
                transitions_templates_.append((template, t_utils.rotate_transition(t_utils.transitions[i], rot)))
                transition_probabilities.append(transition_probability[i])
                template = [template[-1]] + template[:-1]

        def get_matching_templates(template):
            """Return (transition, weight) pairs agreeing with `template` on every side that is not -1."""
            return [(trans, prob)
                    for (tmpl, trans), prob in zip(transitions_templates_, transition_probabilities)
                    if all(want < 0 or want == have for want, have in zip(template, tmpl))]

        def sample(candidates):
            """Draw one transition from (transition, weight) pairs, proportionally to weight."""
            possible_transitions, possible_probabilities = zip(*candidates)
            total = sum(possible_probabilities)
            return np.random.choice(possible_transitions,
                                    p=[p / total for p in possible_probabilities])

        MAX_INSERTIONS = (width - 2) * (height - 2) * 10
        MAX_ATTEMPTS_FROM_SCRATCH = 10

        attempt_number = 0
        while attempt_number < MAX_ATTEMPTS_FROM_SCRATCH:
            # empty board; only interior cells are filled in this phase
            rail = [[None] * width for _ in range(height)]
            cells_to_fill = [(r, c) for r in range(1, height - 1) for c in range(1, width - 1)]

            num_insertions = 0
            while num_insertions < MAX_INSERTIONS and cells_to_fill:
                cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]]
                cells_to_fill.remove(cell)
                row, col = cell

                # valid_template[side]: 1 if the neighbour on that side has any
                # transition heading back into this cell, 0 if it has none,
                # -1 (unconstrained) if that neighbour is still empty.
                valid_template = [-1, -1, -1, -1]
                for side, back_dir, (d_row, d_col) in neighbours:
                    neigh_trans = rail[row + d_row][col + d_col]
                    if neigh_trans is None:
                        continue
                    max_bit = 0
                    for k in range(4):
                        max_bit |= t_utils.get_transition(neigh_trans, k, back_dir)
                    valid_template[side] = 1 if max_bit else 0

                possible_cell_transitions = get_matching_templates(valid_template)

                if possible_cell_transitions:
                    rail[row][col] = sample(possible_cell_transitions)
                    num_insertions += 1
                elif valid_template.count(1) == 1:
                    # no regular cell fits, but a dead-end towards the single
                    # connected neighbour does
                    side = valid_template.index(1)
                    rail[row][col] = t_utils.rotate_transition(dead_end, dead_end_rotation[side])
                    num_insertions += 1
                else:
                    # can valid transitions be found by removing a single neighbouring cell?
                    bestk = -1
                    besttrans = []
                    for k in range(4):
                        tmp_template = valid_template[:]
                        tmp_template[k] = -1
                        candidates = get_matching_templates(tmp_template)
                        if len(candidates) > len(besttrans):
                            besttrans = candidates
                            bestk = k

                    if bestk >= 0:
                        # clear that neighbour, queue it for re-filling, fill this cell
                        d_row, d_col = neighbours[bestk][2]
                        replace_row, replace_col = row + d_row, col + d_col
                        cells_to_fill.append((replace_row, replace_col))
                        rail[replace_row][replace_col] = None
                        rail[row][col] = sample(besttrans)
                        num_insertions += 1
                    else:
                        print('WARNING: still nothing!')
                        rail[row][col] = empty_cell
                        num_insertions += 1

            # Failure means interior cells are still unfilled. The old test
            # (num_insertions == MAX_INSERTIONS) also rejected boards completed
            # exactly on the last allowed insertion, or with no interior at all.
            if cells_to_fill:
                attempt_number += 1
            else:
                break

        if attempt_number == MAX_ATTEMPTS_FROM_SCRATCH:
            print('ERROR: failed to generate level')

        def has_exit_towards(neigh_trans, dir_bit):
            """Return True if any of the four 4-bit groups of neigh_trans has dir_bit set."""
            if neigh_trans is None:
                return False
            return any((neigh_trans >> ((3 - k) * 4)) & (2 ** 4 - 1) & dir_bit for k in range(4))

        # Finally pad the border of the map with dead-ends to avoid border issues;
        # at most 1 transition in the neigh cell. Rows are padded before columns,
        # and the column pass reads cells written by the row pass.
        for r in range(height):
            # transitions from [r][1] going WEST (bit 0)
            rail[r][0] = (t_utils.rotate_transition(dead_end, 270)
                          if has_exit_towards(rail[r][1], 1) else empty_cell)
            # transitions from [r][-2] going EAST (bit 2)
            rail[r][-1] = (t_utils.rotate_transition(dead_end, 90)
                           if has_exit_towards(rail[r][-2], 1 << 2) else empty_cell)

        for c in range(width):
            # transitions from [1][c] going NORTH (bit 3)
            rail[0][c] = dead_end if has_exit_towards(rail[1][c], 1 << 3) else empty_cell
            # transitions from [-2][c] going SOUTH (bit 1)
            rail[-1][c] = (t_utils.rotate_transition(dead_end, 180)
                           if has_exit_towards(rail[-2][c], 1 << 1) else empty_cell)

        # For display only, wrong levels: any cell still unset becomes empty
        rail = [[empty_cell if value is None else value for value in cells] for cells in rail]
        tmp_rail = np.asarray(rail, dtype=np.uint16)

        return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
        return_rail.grid = tmp_rail
        return return_rail

    return generator
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772


class RailEnv(Environment):
    """
    RailEnv environment class.

    RailEnv is an environment inspired by a (simplified version of) a rail
    network, in which agents (trains) have to navigate to their target
    locations in the shortest time possible, while at the same time cooperating
    to avoid bottlenecks.

    The valid actions in the environment are:
        0: do nothing
        1: turn left and move to the next cell
        2: move to the next cell in front of the agent
        3: turn right and move to the next cell

    Moving forward in a dead-end cell makes the agent turn 180 degrees and step
    to the cell it came from.

    The actions of the agents are executed in order of their handle to prevent
    deadlocks and to allow them to learn relative priorities.

    Rewards (computed in step()): every agent that is not on its target after a
    step receives a step penalty of -1 * alpha; an action that cannot be
    executed costs an additional -2; once all agents are on their targets, a
    global reward of 1 * beta is meant to be added to every agent's reward.
    alpha and beta are currently fixed to 1.0 inside step().

    TODO: allow alpha and beta to be passed as parameters to __init__().
    """

    def __init__(self,
                 width,
                 height,
                 rail_generator=random_rail_generator(),
                 number_of_agents=1,
                 obs_builder_object=None):
        """
        Environment init.

        Parameters
        ----------
        width : int
            The width of the rail map. Potentially in the future,
            a range of widths to sample from.
        height : int
            The height of the rail map. Potentially in the future,
            a range of heights to sample from.
        rail_generator : function
            The rail_generator function is a function that takes the width and
            height of a rail map along with the number of times the env has
            been reset, and returns a GridTransitionMap object.
            Implemented functions are:
                random_rail_generator : generate a random rail of given size
                rail_from_GridTransitionMap_generator(rail_map) : generate a rail from
                                        a GridTransitionMap object
                rail_from_manual_specifications_generator(rail_spec) : generate a rail from
                                        a rail specifications array
                TODO: generate_rail_from_saved_list or from list of ndarray bitmaps ---
            Note that the default generator is created once, when the class is
            defined, and is shared by every environment that relies on it.
        number_of_agents : int
            Number of agents to spawn on the map. Potentially in the future,
            a range of number of agents to sample from.
        obs_builder_object : ObservationBuilder object, optional
            ObservationBuilder-derived object that builds observation vectors
            for each agent. If None (the default), a new
            TreeObsForRailEnv(max_depth=2) is created for this environment.
        """
        self.rail_generator = rail_generator
        self.rail = None
        self.width = width
        self.height = height

        self.number_of_agents = number_of_agents

        # A builder instance used as a default argument would be created once
        # and shared by every RailEnv; _set_env() would then rebind it to the
        # most recently created environment. Build one per environment instead.
        if obs_builder_object is None:
            obs_builder_object = TreeObsForRailEnv(max_depth=2)
        self.obs_builder = obs_builder_object
        self.obs_builder._set_env(self)

        self.actions = [0] * self.number_of_agents
        self.rewards = [0] * self.number_of_agents
        self.done = False

        self.dones = {"__all__": False}
        self.obs_dict = {}
        self.rewards_dict = {}

        self.agents_handles = list(range(self.number_of_agents))

        # Must be initialised *before* reset(), which fills it through
        # fill_valid_positions(). Assigning None after reset() would discard
        # the freshly computed list.
        self.valid_positions = None

        self.num_resets = 0
        self.reset()
        # The reset performed by the constructor is not counted: the next
        # reset() passes num_resets == 0 to the rail generator again.
        self.num_resets = 0

    def get_agent_handles(self):
        """
        Return the list of agent handles.

        The environment's own list object is returned (not a copy), so any
        mutation by the caller is visible to the environment.
        """
        handles = self.agents_handles
        return handles

    def fill_valid_positions(self):
        """
        Recompute ``self.valid_positions``.

        The result is every (row, col) cell, in row-major order, whose rail
        transitions value is greater than 0, i.e. every cell that holds track.
        """
        positions = []
        self.valid_positions = positions
        positions.extend(
            (row, col)
            for row in range(self.height)
            for col in range(self.width)
            if self.rail.get_transitions((row, col)) > 0
        )

    def check_agent_lists(self):
        """
        Assert that every per-agent list has exactly one entry per agent.

        Checks agents_handles, agents_position, agents_direction and
        agents_target (all four are appended to by add_agent) against
        ``self.number_of_agents``.

        Raises
        ------
        AssertionError
            If a list's length differs from ``self.number_of_agents``.
            Note that assertions are skipped when Python runs with -O.
        """
        for lAgents, name in zip(
                [self.agents_handles, self.agents_position,
                 self.agents_direction, self.agents_target],
                ["handles", "positions", "directions", "targets"]):
            assert self.number_of_agents == len(lAgents), "Inconsistent agent list:" + name

    def check_agent_locdirpath(self, iAgent):
        """
        Check whether agent ``iAgent`` can reach its target from its current cell.

        The criterion is the same as in pick_agent_direction. Some orientation
        at the agent's position must allow a transition into a neighbouring
        cell from which ``_path_exists`` finds the agent's target.

        Parameters
        ----------
        iAgent : int
            Index into the per-agent lists (agents_position, agents_target).

        Returns
        -------
        bool
            True if such a move exists, False otherwise.
        """
        position = self.agents_position[iAgent]
        target = self.agents_target[iAgent]
        for direction in range(4):
            moves = self.rail.get_transitions((position[0], position[1], direction))
            for move_index in range(4):
                # After the move the agent faces move_index (step() sets the
                # agent's direction to the movement), so search from there.
                if moves[move_index] and self._path_exists(
                        self._new_position(position, move_index), move_index, target):
                    return True
        return False

    def pick_agent_direction(self, rcPos, rcTarget):
        """
        Pick a random starting direction at ``rcPos`` from which ``rcTarget`` is reachable.

        A direction qualifies if at least one transition allowed while facing
        it leads to a neighbouring cell from which ``_path_exists`` finds the
        target.

        Parameters
        ----------
        rcPos : tuple
            (row, col) of the starting cell.
        rcTarget : tuple
            (row, col) of the target cell.

        Returns
        -------
        int or None
            One of the qualifying directions (0..3), drawn uniformly with
            ``np.random.choice``, or None if no direction qualifies.
        """
        valid_starting_directions = []
        for direction in range(4):
            moves = self.rail.get_transitions((*rcPos, direction))
            for move_index in range(4):
                if not moves[move_index]:
                    continue
                new_position = self._new_position(rcPos, move_index)
                # After moving, the agent faces the direction it moved in
                # (step() sets agents_direction to the movement), so the search
                # must start from that orientation, not the initial one.
                if self._path_exists(new_position, move_index, rcTarget):
                    valid_starting_directions.append(direction)
                    break  # one reachable move is enough for this direction

        if not valid_starting_directions:
            return None
        return valid_starting_directions[np.random.choice(len(valid_starting_directions), 1)[0]]

    def add_agent(self, rcPos=None, rcTarget=None, iDir=None):
        """
        Append a new agent to the environment.

        Parameters
        ----------
        rcPos : tuple, optional
            (row, col) start cell. If None, a random cell from
            ``self.valid_positions`` is used.
        rcTarget : tuple, optional
            (row, col) target cell. If None, the start cell is used as target.
        iDir : int, optional
            Initial direction. If None, it is chosen by pick_agent_direction,
            which may return None when the target is unreachable.

        Returns
        -------
        int
            Index of the new agent in the per-agent lists.
        """
        self.check_agent_lists()

        if rcPos is None:
            if self.valid_positions is None:
                self.fill_valid_positions()
            # np.random.choice(n) yields an index; look up the actual cell.
            rcPos = self.valid_positions[np.random.choice(len(self.valid_positions))]
        rcPos = tuple(rcPos)  # ensure it's a tuple not a list

        if rcTarget is None:
            rcTarget = rcPos  # default: the target is the origin
        rcTarget = tuple(rcTarget)

        iAgent = self.number_of_agents
        handle = max(self.agents_handles + [-1]) + 1  # max(handles) + 1, starting at 0

        self.agents_position.append(rcPos)
        self.agents_handles.append(handle)

        if iDir is None:
            iDir = self.pick_agent_direction(rcPos, rcTarget)
        self.agents_direction.append(iDir)
        self.agents_target.append(rcTarget)

        # step() looks up dones[handle] for every handle, so register it.
        self.dones[handle] = False

        self.number_of_agents += 1
        self.check_agent_lists()
        return iAgent

    def reset(self, regen_rail=True, replace_agents=True):
        """
        Reset the environment, optionally regenerating the rail and the agents.

        Parameters
        ----------
        regen_rail : bool
            If True (or if no rail exists yet), build a new rail by calling
            ``self.rail_generator(width, height, num_resets)``.
        replace_agents : bool
            If True, sample new start positions, targets and starting
            directions for all agents. Sampling is repeated until every agent
            has a starting direction from which its target is reachable.

        Returns
        -------
        dict
            The observation for each agent handle, from ``_get_observations()``.
        """
        if regen_rail or self.rail is None:
            # TODO: Import not only rail information but also start and goal positions
            self.rail = self.rail_generator(self.width, self.height, self.num_resets)

        self.num_resets += 1

        self.dones = {"__all__": False}
        for handle in self.agents_handles:
            self.dones[handle] = False

        # TODO: Possibility to fill valid positions from list of goals and start
        self.fill_valid_positions()

        if replace_agents:
            # NOTE(review): this loop never terminates if no reachable
            # start/target pairs can be sampled, and np.random.choice raises
            # ValueError when there are no valid positions at all.
            re_generate = True
            while re_generate:
                # Positions and targets are sampled independently and with
                # replacement, so several agents may share a cell.
                self.agents_position = [
                    self.valid_positions[i] for i in
                    np.random.choice(len(self.valid_positions), self.number_of_agents)]
                self.agents_target = [
                    self.valid_positions[i] for i in
                    np.random.choice(len(self.valid_positions), self.number_of_agents)]

                # agents_direction must be a direction for which a solution is
                # guaranteed; resample everything if any agent has none.
                self.agents_direction = [0] * self.number_of_agents
                re_generate = False
                for i in range(self.number_of_agents):
                    direction = self.pick_agent_direction(self.agents_position[i], self.agents_target[i])
                    if direction is None:
                        re_generate = True
                        break
                    self.agents_direction[i] = direction

        # Reset the state of the observation builder with the new environment
        self.obs_builder.reset()

        # Return the new observation vectors for each agent
        return self._get_observations()

    def step(self, action_dict):
        """
        Apply one action per agent and advance the environment by one step.

        Agents are processed in the order of ``self.agents_handles``. A moving
        agent vacates its cell before later agents are checked, so earlier
        handles win contested cells.

        Parameters
        ----------
        action_dict : dict
            Maps agent handle -> action (0: do nothing, 1: turn left,
            2: move forward, 3: turn right). Agents that are missing from the
            dict, or already done, are skipped and get no step penalty.

        Returns
        -------
        tuple
            ``(observations, rewards_dict, dones, info)``: the per-handle
            observations from ``_get_observations()``, a dict handle -> reward,
            the dones dict (per handle plus ``"__all__"``) and an empty dict.

        Raises
        ------
        ValueError
            If a processed agent is given an action outside 0..3.
        """
        alpha = 1.0
        beta = 1.0

        invalid_action_penalty = -2
        step_penalty = -1 * alpha
        global_reward = 1 * beta

        # Reset the step rewards
        self.rewards_dict = {handle: 0 for handle in self.agents_handles}

        if self.dones["__all__"]:
            return self._get_observations(), self.rewards_dict, self.dones, {}

        for i, handle in enumerate(self.agents_handles):
            transition_isValid = None

            if handle not in action_dict or self.dones[handle]:
                continue

            action = action_dict[handle]
            if action < 0 or action > 3:
                raise ValueError("illegal action={} for agent with handle={}".format(action, handle))

            if action > 0:
                pos = self.agents_position[i]
                direction = self.agents_direction[i]

                # Number of transition bits set in the current cell; used to
                # detect dead-ends (1 bit) and straight/curved track (2 bits).
                nbits = bin(int(self.rail.get_transitions((pos[0], pos[1])))).count("1")

                movement = direction
                if action == 1:
                    movement = direction - 1
                    if nbits <= 2:
                        # Turning is impossible on a plain track or dead-end.
                        transition_isValid = False
                elif action == 3:
                    movement = direction + 1
                    if nbits <= 2:
                        transition_isValid = False
                movement %= 4

                is_deadend = False
                if action == 2:
                    if nbits == 1:
                        # dead-end; assuming the rail network is consistent,
                        # this should match the direction the agent has come
                        # from. But it's better to check in any case.
                        reverse_direction = (direction + 2) % 4
                        if self.rail.get_transition((pos[0], pos[1], direction), reverse_direction):
                            direction = reverse_direction
                            movement = reverse_direction
                            is_deadend = True
                    elif nbits == 2:
                        # Curve: if going straight is impossible, follow the
                        # first other allowed direction (never turning back).
                        # Bounded to the 3 other directions so an unexpected
                        # cell cannot cause an infinite loop.
                        if not self.rail.get_transition((pos[0], pos[1], direction), movement):
                            reverse_direction = (direction + 2) % 4
                            for offset in (1, 2, 3):
                                curv_dir = (movement + offset) % 4
                                if curv_dir != reverse_direction and self.rail.get_transition(
                                        (pos[0], pos[1], direction), curv_dir):
                                    movement = curv_dir
                                    break

                new_position = self._new_position(pos, movement)
                # Is it a legal move?  1) transition allows the movement in the
                # cell,  2) the new cell is inside the grid and not empty,
                # 3) the cell is free, i.e., no agent is currently in that cell
                new_cell_isValid = (
                    0 <= new_position[0] < self.height and
                    0 <= new_position[1] < self.width and
                    self.rail.get_transitions((new_position[0], new_position[1])) > 0)

                # If transition validity hasn't been checked yet.
                if transition_isValid is None:
                    transition_isValid = self.rail.get_transition(
                        (pos[0], pos[1], direction),
                        movement) or is_deadend

                cell_isFree = new_position not in self.agents_position

                if new_cell_isValid and transition_isValid and cell_isFree:
                    # move and change direction to face the movement that was
                    # performed
                    self.agents_position[i] = new_position
                    self.agents_direction[i] = movement
                else:
                    # the action was not valid, add penalty
                    self.rewards_dict[handle] += invalid_action_penalty

            # if agent is not in target position, add step penalty
            position, target = self.agents_position[i], self.agents_target[i]
            if position[0] == target[0] and position[1] == target[1]:
                self.dones[handle] = True
            else:
                self.rewards_dict[handle] += step_penalty

        # Check for end of episode + add global reward to all rewards!
        num_agents_in_target_position = 0
        for i in range(self.number_of_agents):
            position, target = self.agents_position[i], self.agents_target[i]
            if position[0] == target[0] and position[1] == target[1]:
                num_agents_in_target_position += 1

        if num_agents_in_target_position == self.number_of_agents:
            self.dones["__all__"] = True
            # Keep rewards_dict a handle -> reward mapping (iterating the dict
            # itself would yield handles, not rewards).
            self.rewards_dict = {handle: reward + global_reward
                                 for handle, reward in self.rewards_dict.items()}

        # Reset the step actions (in case some agent doesn't 'register_action'
        # on the next step)
        self.actions = [0] * self.number_of_agents

        return self._get_observations(), self.rewards_dict, self.dones, {}

    def _new_position(self, position, movement):
        """
        Return the (row, col) cell adjacent to ``position`` in direction ``movement``.

        Directions: 0 = NORTH (row - 1), 1 = EAST (col + 1), 2 = SOUTH (row + 1),
        3 = WEST (col - 1). Any other value yields None. No bounds checking is
        done here.
        """
        offsets = {
            0: (-1, 0),  # NORTH
            1: (0, 1),   # EAST
            2: (1, 0),   # SOUTH
            3: (0, -1),  # WEST
        }
        offset = offsets.get(movement)
        if offset is None:
            return None
        d_row, d_col = offset
        return (position[0] + d_row, position[1] + d_col)

    def _path_exists(self, start, direction, end):
        """
        Return 1 if cell ``end`` is reachable from ``start`` while facing ``direction``, else 0.

        The search is depth-first over (cell, orientation) states, following
        the rail transitions. After each move the orientation equals the
        movement direction. At dead-end cells (exactly one transition bit) the
        search may also reverse in place, mirroring the dead-end rule in step().

        Parameters
        ----------
        start : tuple
            (row, col) starting cell.
        direction : int
            Orientation (0..3) at the starting cell.
        end : tuple
            (row, col) target cell.
        """
        visited = set()
        stack = [(start, direction)]
        while stack:
            node = stack.pop()
            position, orientation = node
            if position[0] == end[0] and position[1] == end[1]:
                return 1
            if node in visited:
                continue
            # Cells outside the grid can never be part of a path. Reject them
            # before querying the rail; presumably an out-of-range (e.g.
            # negative) index would otherwise wrap around the grid array.
            if not (0 <= position[0] < self.height and 0 <= position[1] < self.width):
                continue
            visited.add(node)

            moves = self.rail.get_transitions((position[0], position[1], orientation))
            for move_index in range(4):
                if moves[move_index]:
                    stack.append((self._new_position(position, move_index), move_index))

            # If cell is a dead-end, append the same cell with reversed
            # orientation!
            nbits = bin(int(self.rail.get_transitions((position[0], position[1])))).count("1")
            if nbits == 1:
                stack.append((position, (orientation + 2) % 4))

        return 0

    def _get_observations(self):
        """
        Build, store and return one observation per agent.

        ``self.obs_builder.get(handle)`` is queried for every handle in
        ``self.agents_handles``, in list order. The results are collected in a
        fresh dict that is also stored as ``self.obs_dict``.
        """
        observations = {}
        self.obs_dict = observations
        for agent_handle in self.agents_handles:
            observations[agent_handle] = self.obs_builder.get(agent_handle)
        return observations

    def render(self):
        """Render the environment (not implemented yet; does nothing)."""
        # TODO: implement rendering.
        return None