rail_generators.py 38.2 KB
Newer Older
u214892's avatar
u214892 committed
1
"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
Erik Nygren's avatar
Erik Nygren committed
2
import sys
3
import warnings
Erik Nygren's avatar
Erik Nygren committed
4
from typing import Callable, Tuple, Optional, Dict, List
5

hagrid67's avatar
hagrid67 committed
6
import numpy as np
7
from numpy.random.mtrand import RandomState
hagrid67's avatar
hagrid67 committed
8

9
from flatland.core.grid.grid4 import Grid4TransitionsEnum
u229589's avatar
u229589 committed
10
from flatland.core.grid.grid4_utils import get_direction, mirror, direction_to_point
11
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
u229589's avatar
u229589 committed
12
from flatland.core.grid.grid_utils import distance_on_rail, IntVector2DArray, IntVector2D, \
u229589's avatar
u229589 committed
13
    Vec2dOperations
u214892's avatar
u214892 committed
14
from flatland.core.grid.rail_env_grid import RailEnvTransitions
15
from flatland.core.transition_map import GridTransitionMap
16
from flatland.envs.grid4_generators_utils import connect_rail_in_grid_map, connect_straight_line_in_grid_map, \
17
    fix_inner_nodes, align_cell_to_city
18
from flatland.envs import persistence
hagrid67's avatar
hagrid67 committed
19

20

u214892's avatar
u214892 committed
21
# A rail generator returns a RailGeneratorProduct: a GridTransitionMap followed by an
# optional dict of extra information (e.g. {'agents_hints': ...} or {'distance_map': ...}),
# or None when there is nothing extra to report.
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]

# A rail generator is called as generator(width, height, num_agents, num_resets=0, np_random=None).
# `...` is used for the arguments because the trailing ones are optional.
RailGenerator = Callable[..., RailGeneratorProduct]
hagrid67's avatar
hagrid67 committed
27

u214892's avatar
u214892 committed
28

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class RailGen(object):
    """Base class for rail generators.

    Work in progress: replaces the bare generator closures with classes/objects, so that the
    generator state lives in named attributes instead of unnamed local variables (closures
    over local variables prevent pickling).

    Subclasses override :meth:`generate`. Instances are callable and forward every call to
    :meth:`generate`.
    """

    def __init__(self, *args, **kwargs):
        """Record any state to be reused in each "generation".

        The base implementation stores nothing and ignores its arguments.
        """

    def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
                 np_random: RandomState = None) -> RailGeneratorProduct:
        """Build the rail network and return ``(grid_map, optionals)``.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method. Failing here is clearer than
            returning None, which callers would otherwise try to unpack.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement generate()")

    def __call__(self, *args, **kwargs) -> RailGeneratorProduct:
        """Delegate to :meth:`generate` so instances can be used like generator functions."""
        return self.generate(*args, **kwargs)





u214892's avatar
u214892 committed
51
def empty_rail_generator() -> RailGenerator:
    """Factory returning an :class:`EmptyRailGen` instance.

    The returned object is callable like a generator function and yields an empty rail map.
    """
    generator = EmptyRailGen()
    return generator

class EmptyRailGen(RailGen):
    """
    Rail generator producing an empty rail map (no transitions in any cell) and no optional
    data such as agent hints. Primarily used by the editor.
    """

    def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
                 np_random: RandomState = None) -> RailGeneratorProduct:
        """Return an empty ``width`` x ``height`` GridTransitionMap together with ``None``.

        ``num_agents``, ``num_resets`` and ``np_random`` are accepted for interface
        compatibility and are not used.
        """
        rail_trans = RailEnvTransitions()
        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
        # Explicitly clear every cell so the map is guaranteed to contain no rails.
        grid_map.grid.fill(0)
        return grid_map, None
68

hagrid67's avatar
hagrid67 committed
69
70
71
72
73
74
75
76

def rail_from_manual_specifications_generator(rail_spec):
    """
    Utility to convert a rail given by manual specification as a map of tuples
    (cell_type, rotation), to a transition map with the correct 16-bit
    transitions specifications.

    Parameters
    ----------
    rail_spec : list of list of tuples
        List (rows) of lists (columns) of tuples, each specifying a rail_spec_of_cell for
        the RailEnv environment as (cell_type, rotation), with rotation being
        clock-wise and in [0, 90, 180, 270]. cell_type is an index into
        ``RailEnvTransitions().transitions``.

    Returns
    -------
    function
        Generator function that returns ``(GridTransitionMap, None)``, where the map holds
        the matrix of correct 16-bit bitmaps for each rail_spec_of_cell. The map size is
        taken from rail_spec (number of rows, length of the first row); the generator's
        width and height arguments are ignored.

    Raises
    ------
    ValueError
        Raised by the returned generator when rail_spec has no rows or when a cell type
        index is out of range.
    """

    def generator(width: int, height: int, num_agents: int, num_resets: int = 0,
                  np_random: RandomState = None) -> RailGeneratorProduct:
        if len(rail_spec) == 0:
            raise ValueError("rail_spec must contain at least one row")

        rail_env_transitions = RailEnvTransitions()
        num_cell_types = len(rail_env_transitions.transitions)

        spec_height = len(rail_spec)
        spec_width = len(rail_spec[0])
        rail = GridTransitionMap(width=spec_width, height=spec_height, transitions=rail_env_transitions)

        for r in range(spec_height):
            for c in range(spec_width):
                rail_spec_of_cell = rail_spec[r][c]
                index_basic_type_of_cell_ = rail_spec_of_cell[0]
                rotation_cell_ = rail_spec_of_cell[1]
                if not 0 <= index_basic_type_of_cell_ < num_cell_types:
                    # Raise instead of returning an empty/partial result, so callers that
                    # unpack (rail, optionals) get a clear error pointing at the bad cell.
                    raise ValueError(
                        f"invalid rail_spec_of_cell type={index_basic_type_of_cell_} at row {r}, column {c}")
                basic_type_of_cell_ = rail_env_transitions.transitions[index_basic_type_of_cell_]
                effective_transition_cell = rail_env_transitions.rotate_transition(basic_type_of_cell_, rotation_cell_)
                rail.set_transitions((r, c), effective_transition_cell)

        return rail, None

    return generator


u214892's avatar
u214892 committed
115
def rail_from_file(filename, load_from_package=None) -> RailGenerator:
    """
    Utility to load a rail network from a file.

    Parameters
    ----------
    filename :
        File generated by env.save() or the editor. Loading is delegated to
        ``persistence.RailEnvPersister.load_env_dict``. The file is documented as a pickle;
        unpickling untrusted data is unsafe, so only load files from trusted sources.
    load_from_package : optional
        Passed through unchanged to ``RailEnvPersister.load_env_dict``.

    Returns
    -------
    function
        Generator function that loads the file on every call and returns
        ``(GridTransitionMap, optionals)``. The map size is taken from the stored grid; the
        generator's width and height arguments are ignored. ``optionals`` is
        ``{'distance_map': distance_map}`` when the file contains a non-empty distance map,
        otherwise None.
    """

    def generator(width: int, height: int, num_agents: int, num_resets: int = 0,
                  np_random: RandomState = None) -> RailGeneratorProduct:
        env_dict = persistence.RailEnvPersister.load_env_dict(filename, load_from_package=load_from_package)
        rail_env_transitions = RailEnvTransitions()

        grid = np.array(env_dict["grid"])
        rail = GridTransitionMap(width=grid.shape[1], height=grid.shape[0], transitions=rail_env_transitions)
        rail.grid = grid

        distance_map = env_dict["distance_map"] if "distance_map" in env_dict else None
        # A stored distance map may be missing, None or empty; only pass on a usable one.
        if distance_map is not None and len(distance_map) > 0:
            return rail, {'distance_map': distance_map}
        return rail, None

    return generator

146
class RailFromGridGen(RailGen):
    """Rail generator that always returns a pre-built rail map.

    Every call returns the same ``rail_map`` object (not a copy) and the same ``optionals``,
    regardless of the requested size, number of agents or random state.
    """

    def __init__(self, rail_map, optionals=None):
        """
        Parameters
        ----------
        rail_map :
            The map returned by every call to :meth:`generate` (a GridTransitionMap, as the
            ``rail_from_grid_transition_map`` factory's name suggests).
        optionals : dict or None
            Extra data returned alongside the map.
        """
        self.rail_map = rail_map
        self.optionals = optionals

    def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
                 np_random: RandomState = None) -> RailGeneratorProduct:
        """Return the stored ``(rail_map, optionals)``; all arguments are ignored."""
        return self.rail_map, self.optionals
154

155

156
157
def rail_from_grid_transition_map(rail_map, optionals=None) -> RailGenerator:
    """Wrap an existing rail map (and optional extra data) in a :class:`RailFromGridGen`."""
    generator = RailFromGridGen(rail_map, optionals)
    return generator
hagrid67's avatar
hagrid67 committed
158
159


160
161
162
163
164
def sparse_rail_generator(*args, **kwargs):
    """Factory that forwards all arguments to the :class:`SparseRailGen` constructor."""
    generator = SparseRailGen(*args, **kwargs)
    return generator

class SparseRailGen(RailGen):

165
    def __init__(self, max_num_cities: int = 2, grid_mode: bool = False, max_rails_between_cities: int = 2,
                 max_rail_pairs_in_city: int = 2, seed=0) -> None:
        """
        Configure a generator of railway networks with cities and inner-city rails.

        Parameters
        ----------
        max_num_cities : int
            Max number of cities to build. The generator tries to reach this number given all
            the other parameters and the map size.
        grid_mode : bool
            How to distribute the cities on the map: evenly in a grid (True) or randomly (False).
        max_rails_between_cities : int
            Max number of rails connecting to a city side. This only limits the number of
            connection points at the city border; the number of tracks drawn in between
            cities can still vary.
        max_rail_pairs_in_city : int
            Max number of pairs of parallel tracks in a city. This represents the number of
            tracks in the train stations.
        seed : int
            Stored on the instance. NOTE(review): ``generate`` draws its randomness from its
            ``np_random`` argument; confirm whether anything reads this value.
        """
        self.max_num_cities = max_num_cities
        self.grid_mode = grid_mode
        self.max_rails_between_cities = max_rails_between_cities
        self.max_rail_pairs_in_city = max_rail_pairs_in_city
        self.seed = seed  # TODO: seed in constructor or generate?
193

194
195

    def generate(self, width: int, height: int, num_agents: int, num_resets: int = 0,
                 np_random: RandomState = None) -> RailGeneratorProduct:
        """
        Generate a sparse railway network made of cities connected by rails.

        Parameters
        ----------
        width: int
            Width of the environment
        height: int
            Height of the environment
        num_agents: int
            Number of agents to be placed within the environment (only passed through in the hints)
        num_resets: int
            Count for how often the environment has been reset (not used here)
        np_random: RandomState, optional
            Source of all random decisions. A new, unseeded RandomState is created when None.

        Returns
        -------
        (grid_map, hints)
            grid_map is the railway infrastructure. hints has the form::

                {'agents_hints': {
                    'num_agents': num_agents,
                    'city_positions': city centre coordinates (row, col),
                    'train_stations': station locations computed by _set_trainstation_positions,
                    'city_orientations': orientation (direction index) of each city}}

        Raises
        ------
        ValueError
            If fewer than two cities fit into the map.
        """
        if np_random is None:
            np_random = RandomState()

        rail_trans = RailEnvTransitions()
        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)

        # Pairs of parallel rails per city; at least one pair is always built.
        rail_pairs_in_city = max(1, self.max_rail_pairs_in_city)
        # Rails leaving a city side: at most one per in-city track and at least one, because the
        # number of outgoing rails is later drawn with randint(1, rails_between_cities + 1).
        rails_between_cities = max(1, min(self.max_rails_between_cities, rail_pairs_in_city * 2))

        # The city radius is half the number of tracks it contains (rounded up so all rails fit),
        # plus padding so that the in-city track segments are not too short.
        city_padding = 2
        city_radius = int(np.ceil((rail_pairs_in_city * 2) / 2)) + city_padding
        # Preferred orientation per cell; -1 means "no preference".
        vector_field = np.zeros(shape=(height, width)) - 1.

        # Reduce the number of cities to what actually fits into the map.
        city_footprint = 2 * (city_radius + 1)
        max_feasible_cities = min(self.max_num_cities,
                                  ((height - 2) // city_footprint) * ((width - 2) // city_footprint))
        if max_feasible_cities < 2:
            raise ValueError("ERROR: Cannot fit more than one city in this map, no feasible environment possible!")

        if self.grid_mode:
            city_positions = self._generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
                                                                        height)
        else:
            city_positions = self._generate_random_city_positions(max_feasible_cities, city_radius, width, height,
                                                                  np_random=np_random)

        # Random placement can yield fewer cities than requested; fall back to an even grid if
        # fewer than two were placed.
        if len(city_positions) < 2:
            warnings.warn("[WARNING] Changing to Grid mode to place at least 2 cities.")
            city_positions = self._generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
                                                                        height)

        # Set up connection points for all cities
        inner_connection_points, outer_connection_points, city_orientations, city_cells = \
            self._generate_city_connection_points(
                city_positions, city_radius, vector_field, rails_between_cities,
                rail_pairs_in_city, np_random=np_random)

        # Connect the cities through their outer connection points
        inter_city_lines = self._connect_cities(city_positions, outer_connection_points, city_cells,
                                                rail_trans, grid_map)

        # Build the rails inside the cities
        free_rails = self._build_inner_cities(city_positions, inner_connection_points,
                                              outer_connection_points,
                                              rail_trans,
                                              grid_map)

        # Populate cities with train stations
        train_stations = self._set_trainstation_positions(city_positions, city_radius, free_rails)

        # Fix all transition elements
        self._fix_transitions(city_cells, inter_city_lines, grid_map, vector_field)

        return grid_map, {'agents_hints': {
            'num_agents': num_agents,
            'city_positions': city_positions,
            'train_stations': train_stations,
            'city_orientations': city_orientations
        }}

293
    def _generate_random_city_positions(self, num_cities: int, city_radius: int, width: int,
                                        height: int, np_random: RandomState = None) -> IntVector2DArray:
        """
        Distribute the cities randomly in the environment while respecting city sizes and guaranteeing that they
        don't overlap.

        Parameters
        ----------
        num_cities: int
            Max number of cities that should be placed
        city_radius: int
            Radius of each city. Cities are squares with edge length 2 * city_radius + 1
        width: int
            Width of the environment
        height: int
            Height of the environment
        np_random: RandomState, optional
            Random number generator; a new, unseeded one is created when None.

        Returns
        -------
        A list of city centre coordinates (row, col). It contains fewer than num_cities entries
        when a city could not be placed within 201 attempts. A warning is emitted in that case.
        """
        if np_random is None:
            np_random = RandomState()

        city_positions: IntVector2DArray = []
        # Minimum centre-to-centre spacing passed to the overlap check.
        min_city_distance = 2 * (city_radius + 1) + 1
        max_row_offset = height - 2 * (city_radius + 1)
        max_col_offset = width - 2 * (city_radius + 1)

        for _ in range(num_cities):
            tries = 0
            while True:
                row = city_radius + 1 + np_random.randint(max_row_offset)
                col = city_radius + 1 + np_random.randint(max_col_offset)
                too_close = any(
                    self.__class__._are_cities_overlapping((row, col), city_pos, min_city_distance)
                    for city_pos in city_positions)
                if not too_close:
                    city_positions.append((row, col))
                    break

                tries += 1
                if tries > 200:
                    # Give up on this city only; the loop still tries to place the remaining ones.
                    warnings.warn(
                        "Could not set all required cities!")
                    break
        return city_positions
340

341
    def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: int, width: int, height: int
                                              ) -> IntVector2DArray:
        """
        Distribute the cities in an evenly spaced grid

        Parameters
        ----------
        num_cities: int
            Max number of cities that should be placed
        city_radius: int
            Radius of each city. Cities are squares with edge length 2 * city_radius + 1
        width: int
            Width of the environment
        height: int
            Height of the environment

        Returns
        -------
        A list of city centre coordinates (row, col), filled column by column. The list is
        empty when num_cities <= 0 or when the map is too small to hold any city.
        """
        if num_cities <= 0 or width <= 0 or height <= 0:
            return []
        aspect_ratio = height / width

        # Max number of cities along each axis, respecting padding at the edges of the
        # environment and between cities. "per_row" counts distinct row coordinates
        # (along the height axis), "per_col" distinct column coordinates (along the width).
        padding = 2
        city_size = 2 * (city_radius + 1)
        max_cities_per_row = int((height - padding) // city_size)
        max_cities_per_col = int((width - padding) // city_size)
        if max_cities_per_row <= 0 or max_cities_per_col <= 0:
            # Not even a single city fits along one of the axes.
            return []

        # Split the cities roughly according to the aspect ratio, capped by what fits.
        cities_per_row = min(int(np.ceil(np.sqrt(num_cities * aspect_ratio))), max_cities_per_row)
        cities_per_col = min(int(np.ceil(num_cities / cities_per_row)), max_cities_per_col)
        num_build_cities = min(num_cities, cities_per_col * cities_per_row)

        row_positions = np.linspace(city_radius + 2, height - (city_radius + 2), cities_per_row, dtype=int)
        col_positions = np.linspace(city_radius + 2, width - (city_radius + 2), cities_per_col, dtype=int)

        return [(row_positions[city_idx % cities_per_row], col_positions[city_idx // cities_per_row])
                for city_idx in range(num_build_cities)]
u229589's avatar
u229589 committed
387

388
    def _generate_city_connection_points(self, city_positions: IntVector2DArray, city_radius: int,
                                         vector_field: IntVector2DArray, rails_between_cities: int,
                                         rail_pairs_in_city: int = 1, np_random: RandomState = None) -> (
        List[List[List[IntVector2D]]],
        List[List[List[IntVector2D]]],
        List[Grid4TransitionsEnum],
        IntVector2DArray):
        """
        Generate the city connection points. Inner connection points are used to generate the parallel paths
        within the city.
        Outer connection points are used to connect different cities together.

        Parameters
        ----------
        city_positions: IntVector2DArray
            Vector that contains all the positions of the cities
        city_radius: int
            Radius of each city. Cities are squares with edge length 2 * city_radius + 1
        vector_field: IntVector2DArray
            Vector field of the size of the environment. It is used to generate preferred orientations for each
            cell. Each cell contains the preferred orientation of the cell; -1 means no preferred orientation.
        rails_between_cities: int
            Upper bound on the number of rails leaving a city side towards other cities (must be >= 1)
        rail_pairs_in_city: int
            Max number of pairs of parallel rails within the city
        np_random: RandomState, optional
            Random number generator; a new, unseeded one is created when None.

        Returns
        -------
        inner_connection_points: List of length number of cities
            For each city, the inner connection points of each border:
            [North_Points, East_Points, South_Points, West_Points]
        outer_connection_points: List of length number of cities
            For each city, the outer connection points (on the border) of each side, in the same order. Only a
            centred subset of the tracks gets an outer point.
        city_orientations: List of length number of cities
            Orientation (direction index 0-3) of each city. This is then used to orient agents according to the
            rails.
        city_cells: List
            Coordinates of all the cells that belong to any city. This is used by other algorithms
            to avoid drawing inter-city rails through cities.
        """
        if np_random is None:
            np_random = RandomState()

        inner_connection_points: List[List[List[IntVector2D]]] = []
        outer_connection_points: List[List[List[IntVector2D]]] = []
        city_orientations: List[Grid4TransitionsEnum] = []
        city_cells: IntVector2DArray = []

        for city_position in city_positions:
            # Orient the city: randomly in grid mode, otherwise facing its closest neighbour.
            if self.grid_mode:
                current_closest_direction = np_random.randint(4)
            else:
                neighb_dist = [Vec2dOperations.get_manhattan_distance(city_position, neighbour_city)
                               for neighbour_city in city_positions]
                closest_neighb_idx = self.__class__.argsort(neighb_dist)
                # Position 0 is (normally) the city itself at distance 0, so use the next closest one.
                # NOTE(review): assumes argsort orders by ascending distance.
                current_closest_direction = direction_to_point(city_position,
                                                               city_positions[closest_neighb_idx[1]])

            # Tracks enter/leave on the side facing that direction and on the opposite side.
            connection_sides_idx = [current_closest_direction, (current_closest_direction + 2) % 4]
            city_orientations.append(current_closest_direction)
            city_cells.extend(self._get_cells_in_city(city_position, city_radius, current_closest_direction,
                                                      vector_field))

            # Number of in-city tracks, always even: (1, ..., rail_pairs_in_city) * 2.
            connections_per_direction = np.zeros(4, dtype=int)
            nr_of_connection_points = np_random.randint(1, rail_pairs_in_city + 1) * 2
            for side in connection_sides_idx:
                connections_per_direction[side] = nr_of_connection_points

            connection_points_coordinates_inner: List[List[IntVector2D]] = [[] for _ in range(4)]
            connection_points_coordinates_outer: List[List[IntVector2D]] = [[] for _ in range(4)]

            # Only a centred subset of the tracks is continued towards other cities.
            number_of_out_rails = np_random.randint(1, min(rails_between_cities, nr_of_connection_points) + 1)
            start_idx = int((nr_of_connection_points - number_of_out_rails) / 2)
            out_rail_slots = range(start_idx, start_idx + number_of_out_rails)

            # The following only depend on the number of connection points, not on the side.
            connection_slots = np.arange(nr_of_connection_points) - start_idx
            # Offset the rails away from the center of the city
            offset_distances = np.arange(nr_of_connection_points) - int(nr_of_connection_points / 2)
            # The clipping offsets one side more than the other to avoid switches at the same locations.
            # The +1 makes sure that all points have at least one offset.
            inner_point_offset = np.abs(offset_distances) + np.clip(offset_distances, 0, 1) + 1

            row, col = city_position[0], city_position[1]
            for direction in range(4):
                for connection_idx in range(connections_per_direction[direction]):
                    slot = connection_slots[connection_idx]
                    offset = inner_point_offset[connection_idx]
                    if direction == 0:  # North border
                        tmp_coordinates = (row - city_radius + offset, col + slot)
                        out_tmp_coordinates = (row - city_radius, col + slot)
                    elif direction == 1:  # East border
                        tmp_coordinates = (row + slot, col + city_radius - offset)
                        out_tmp_coordinates = (row + slot, col + city_radius)
                    elif direction == 2:  # South border
                        tmp_coordinates = (row + city_radius - offset, col + slot)
                        out_tmp_coordinates = (row + city_radius, col + slot)
                    else:  # West border
                        tmp_coordinates = (row + slot, col - city_radius + offset)
                        out_tmp_coordinates = (row + slot, col - city_radius)
                    connection_points_coordinates_inner[direction].append(tmp_coordinates)
                    if connection_idx in out_rail_slots:
                        connection_points_coordinates_outer[direction].append(out_tmp_coordinates)

            inner_connection_points.append(connection_points_coordinates_inner)
            outer_connection_points.append(connection_points_coordinates_outer)
        return inner_connection_points, outer_connection_points, city_orientations, city_cells
501

502
    def _connect_cities(self, city_positions: IntVector2DArray, connection_points: List[List[List[IntVector2D]]],
                        city_cells: IntVector2DArray,
                        rail_trans: RailEnvTransitions, grid_map: GridTransitionMap) -> List[IntVector2DArray]:
        """
        Connects cities together through rails. For each side of each city, every outgoing connection point is
        connected to the nearest (Manhattan distance) connection point of the closest neighbouring city in that
        direction.

        Parameters
        ----------
        city_positions: IntVector2DArray
            All coordinates of the cities
        connection_points: List[List[List[IntVector2D]]]
            Per city, per side [N, E, S, W], the coordinates of the outer connection points
        city_cells: IntVector2DArray
            Coordinates of all the cells contained in any city. This is used to avoid drawing rails through existing
            cities.
        rail_trans: RailEnvTransitions
            Railway transition objects
        grid_map: GridTransitionMap
            The grid map containing the rails. Used to draw new rails

        Returns
        -------
        Returns a list of all the cells (Coordinates) that belong to a rail path. This can be used to access railway
        cells later.
        """
        all_paths: List[IntVector2DArray] = []

        grid4_directions = [Grid4TransitionsEnum.NORTH, Grid4TransitionsEnum.EAST, Grid4TransitionsEnum.SOUTH,
                            Grid4TransitionsEnum.WEST]

        for current_city_idx in range(len(city_positions)):
            closest_neighbours = self._closest_neighbour_in_grid4_directions(current_city_idx, city_positions)
            for out_direction in grid4_directions:
                neighbour_idx = self.get_closest_neighbour_for_direction(closest_neighbours, out_direction)
                out_points = connection_points[current_city_idx][out_direction]
                if not out_points:
                    continue

                # Every connection point of the neighbour, on all sides, in N/E/S/W order.
                neighbour_points = [point
                                    for direction in grid4_directions
                                    for point in connection_points[neighbour_idx][direction]]
                if not neighbour_points:
                    # Nothing to connect to: skip rather than reuse a point chosen for another connection.
                    continue

                for city_out_connection_point in out_points:
                    # min() returns the first closest point, matching a strict "<" scan in this order.
                    neighbour_connection_point = min(
                        neighbour_points,
                        key=lambda point: Vec2dOperations.get_manhattan_distance(city_out_connection_point, point))

                    new_line = connect_rail_in_grid_map(grid_map, city_out_connection_point, neighbour_connection_point,
                                                        rail_trans, flip_start_node_trans=False,
                                                        flip_end_node_trans=False, respect_transition_validity=False,
                                                        avoid_rail=True,
                                                        forbidden_cells=city_cells)
                    all_paths.extend(new_line)

        return all_paths
559

560
    def get_closest_neighbour_for_direction(self, closest_neighbours, out_direction):
        """
        Choose the neighbouring city to connect to when leaving a city in ``out_direction``.

        Each direction covers a 90 degree cone, e.g. North is the cone spanned by the lines going North-West
        and North-East. The neighbour found in ``out_direction`` itself is preferred. If that slot is empty
        (``None``), the counter-clockwise direction is tried, then the clockwise direction, and finally the
        opposite direction.

        Parameters
        ----------
        closest_neighbours: List
            List of length 4 holding, per direction, the index of the closest neighbouring city (or None):
            [North-Neighbour, East-Neighbour, South-Neighbour, West-Neighbour]
        out_direction: int
            Direction we want to get a city index for.
            North: 0, East: 1, South: 2, West: 3

        Returns
        -------
        The first non-None entry, searched in the order out_direction, counter-clockwise, clockwise.
        Failing that, the entry of the opposite direction is returned as-is (it may itself be None).
        """
        preferred = closest_neighbours[out_direction]
        if preferred is not None:
            return preferred

        # Rotate one step counter-clockwise first, then one step clockwise.
        for fallback_direction in ((out_direction - 1) % 4, (out_direction + 1) % 4):
            candidate = closest_neighbours[fallback_direction]
            if candidate is not None:
                return candidate

        # Last resort: the opposite direction, returned even when it is None.
        return closest_neighbours[(out_direction + 2) % 4]

597
    def _build_inner_cities(self, city_positions: IntVector2DArray, inner_connection_points: List[List[List[IntVector2D]]],
                            outer_connection_points: List[List[List[IntVector2D]]], rail_trans: RailEnvTransitions,
                            grid_map: GridTransitionMap) -> List[List[List[IntVector2D]]]:
        """
        Set the parallel tracks within each city.

        For every city, a straight track is drawn between each pair of opposite inner connection points. The
        outer connection points are then linked to the centred subset of those tracks. The docstring of the
        original implementation describes the resulting layout as a center track of the city's length, with
        parallel tracks shortening by 2 per step away from the center, e.g.:

                ---     Left Track
               -----    Center Track
                ---     Right Track

        Parameters
        ----------
        city_positions: IntVector2DArray
            All coordinates of the cities
        inner_connection_points: List[List[List[IntVector2D]]]
            Per city and per direction (NESW), the points on the city border used to generate inner city tracks
        outer_connection_points: List[List[List[IntVector2D]]]
            Per city and per direction (NESW), the points where the city is connected to neighbouring cities
        rail_trans: RailEnvTransitions
            Railway transition objects
        grid_map: GridTransitionMap
            The grid map containing the rails. Used to draw new rails

        Returns
        -------
        A list with one entry per city, each holding the cell lists of the parallel tracks drawn inside that
        city. Cities without any inner connection points get an empty list.
        """
        free_rails: List[List[List[IntVector2D]]] = [[] for _ in range(len(city_positions))]
        for current_city in range(len(city_positions)):
            city_inner_points = inner_connection_points[current_city]

            # This part only works if both opposite borders carry the same number of connection points and the
            # city only has connections on two opposite borders: use the first border that has any points.
            border = next((direction for direction in range(4) if len(city_inner_points[direction]) > 0), None)
            if border is None:
                # No inner connection points: nothing to draw. Without this guard the border index would be
                # undefined for the first city, or silently taken from the previous city.
                continue
            opposite_border = (border + 2) % 4
            city_outer_points = outer_connection_points[current_city]

            nr_of_connection_points = len(city_inner_points[border])
            number_of_out_rails = len(city_outer_points[border])
            # Outer rails attach to the tracks in the middle of the bundle.
            start_idx = int((nr_of_connection_points - number_of_out_rails) / 2)

            # Draw the parallel tracks across the city and remember them as free rails.
            for track_id in range(nr_of_connection_points):
                source = city_inner_points[border][track_id]
                target = city_inner_points[opposite_border][track_id]
                current_track = connect_straight_line_in_grid_map(grid_map, source, target, rail_trans)
                free_rails[current_city].append(current_track)

            # Done in a second pass so that all parallel tracks exist before their end points are adjusted.
            for track_id in range(nr_of_connection_points):
                source = city_inner_points[border][track_id]
                target = city_inner_points[opposite_border][track_id]

                # Connect parallel tracks with each other at their end points.
                fix_inner_nodes(grid_map, source, rail_trans)
                fix_inner_nodes(grid_map, target, rail_trans)

                # Connect outer tracks to the centred inner tracks.
                if start_idx <= track_id < start_idx + number_of_out_rails:
                    source_outer = city_outer_points[border][track_id - start_idx]
                    target_outer = city_outer_points[opposite_border][track_id - start_idx]
                    connect_straight_line_in_grid_map(grid_map, source, source_outer, rail_trans)
                    connect_straight_line_in_grid_map(grid_map, target, target_outer, rail_trans)
        return free_rails
666

667
    def _set_trainstation_positions(self, city_positions: IntVector2DArray, city_radius: int,
                                    free_rails: List[List[List[IntVector2D]]]) -> List[List[Tuple[IntVector2D, int]]]:
        """
        Place one possible start/end position (trainstation) on every parallel track of every city.

        The station is placed on the middle cell of each track, i.e. at index ``int(len(track) / 2)``, and is
        tagged with the index of the track within its city.

        Parameters
        ----------
        city_positions: IntVector2DArray
            All coordinates of the cities; only its length is used here
        city_radius: int
            Radius of each city. Cities are squares with edge length 2 * city_radius + 1 (not used by this method)
        free_rails: List[List[List[IntVector2D]]]
            Per city, the list of track cell lists on which trainstations may be placed

        Returns
        -------
        A List[List[Tuple[IntVector2D, int]]] with, per city, the (cell, track number) pair of every trainstation
        """
        train_stations = []
        for city_idx in range(len(city_positions)):
            stations_in_city = []
            for track_idx, track_cells in enumerate(free_rails[city_idx]):
                middle_cell = track_cells[int(len(track_cells) / 2)]
                stations_in_city.append((middle_cell, track_idx))
            train_stations.append(stations_in_city)
        return train_stations
Erik Nygren's avatar
Erik Nygren committed
695

696
    def _fix_transitions(self, city_cells: IntVector2DArray, inter_city_lines: List[IntVector2DArray],
                         grid_map: GridTransitionMap, vector_field):
        """
        Check and fix transitions of all cells that were modified.

        This is necessary because rails are drawn while ignoring transition validity.

        Parameters
        ----------
        city_cells: IntVector2DArray
            Cells within cities. All of these might have changed and are thus checked
        inter_city_lines: List[IntVector2DArray]
            All cells within rails drawn between cities
        grid_map: GridTransitionMap
            The grid map containing the rails; repaired in place
        vector_field: IntVector2DArray
            Array of the size of the environment, indexed by cell. Each entry holds the preferred orientation
            of that cell, or -1 if there is none. It is passed to ``grid_map.fix_transitions``.
        """
        # Phase 1: collect every invalid cell before touching the map, so validity is judged on the map as
        # drawn rather than on one already partially modified by earlier fixes. A plain list is used instead
        # of a pre-sized buffer so that any number of (possibly duplicate) cells can be recorded.
        rails_to_fix = []
        for cell in city_cells + inter_city_lines:
            if not grid_map.cell_neighbours_valid(cell, True):
                rails_to_fix.append(((int(cell[0]), int(cell[1])), int(vector_field[cell])))

        # Phase 2: repair the collected cells in the order they were found.
        for position, preferred_orientation in rails_to_fix:
            grid_map.fix_transitions(position, preferred_orientation)
732

733
    def _closest_neighbour_in_grid4_directions(self, current_city_idx: int,
                                               city_positions: IntVector2DArray) -> List[Optional[int]]:
        """
        Find the closest city in each of the four directions (NESW) of the current city.

        Distances are Manhattan distances. The direction of each other city is determined by
        ``direction_to_point``.

        Parameters
        ----------
        current_city_idx: int
            Index of current city
        city_positions: IntVector2DArray
            Vector containing the coordinates of all cities

        Returns
        -------
        A list of length 4 with the index of the closest neighbour in every direction NESW. An entry is None if
        no other city lies in that direction.
        """
        current_position = city_positions[current_city_idx]
        city_distances = [Vec2dOperations.get_manhattan_distance(current_position, position)
                          for position in city_positions]
        closest_neighbour: List[Optional[int]] = [None] * 4

        # Visit cities from nearest to farthest: the first city seen in a direction is the closest one there.
        for neighbour in np.argsort(city_distances):
            if neighbour == current_city_idx:
                # Skip the city itself explicitly instead of assuming it sorts first; with duplicate positions
                # it could otherwise be chosen as its own neighbour.
                continue
            direction_to_neighbour = direction_to_point(current_position, city_positions[neighbour])
            if closest_neighbour[direction_to_neighbour] is None:
                closest_neighbour[direction_to_neighbour] = neighbour
                # Early return once all 4 directions have a closest neighbour.
                if None not in closest_neighbour:
                    return closest_neighbour

        return closest_neighbour
767

768
    @staticmethod
    def argsort(seq):
        """
        Return the indices that would sort ``seq`` in ascending order (like ``numpy.argsort``, but for lists).

        The sort is stable, so equal elements keep their original relative order.

        Parameters
        ----------
        seq: List
            List whose element order we want to determine

        Returns
        -------
        A list of indices into ``seq`` such that ``[seq[i] for i in result]`` is sorted ascending
        """
        # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
        order = list(range(len(seq)))
        order.sort(key=lambda position: seq[position])
        return order
784

785
    def _get_cells_in_city(self, center: IntVector2D, radius: int, city_orientation: int,
                           vector_field: IntVector2DArray) -> IntVector2DArray:
        """
        Collect the cells of a square city and record each cell's preferred orientation in ``vector_field``.

        The city covers rows ``center[0] - radius .. center[0] + radius`` and columns
        ``center[1] - radius .. center[1] + radius``, i.e. a square with edge length 2 * radius + 1.
        For every cell, ``vector_field[cell]`` is overwritten with ``align_cell_to_city(center, city_orientation,
        cell)``.

        Example from the original docstring: a city oriented north with a radius of 2 (a 5x5 square) yields:
            |S|S|S|S|S|
            |S|S|S|S|S|
            |S|S|S|S|S|  <-- City center
            |N|N|N|N|N|
            |N|N|N|N|N|

        This is used later to orient the switches to avoid infeasible maps.

        Parameters
        ----------
        center: IntVector2D
            center coordinates of city
        radius: int
            radius of city (it is a square)
        city_orientation: int
            Orientation of city
        vector_field: IntVector2DArray
            Array indexed by cell that receives the preferred orientation of every city cell (modified in place)

        Returns
        -------
        flat list of all (row, column) cell coordinates in the city, in row-major order
        """
        rows = np.arange(center[0] - radius, center[0] + radius + 1)
        cols = np.arange(center[1] - radius, center[1] + radius + 1)
        # Row-major: every column of the first row, then every column of the next row, and so on.
        city_cells = [(row, col) for row in rows for col in cols]
        for cell in city_cells:
            vector_field[cell] = align_cell_to_city(center, city_orientation, cell)
        return city_cells
821

822
    @staticmethod
u229589's avatar
u229589 committed
823
    def _are_cities_overlapping(center_1, center_2, radius):
Erik Nygren's avatar
Erik Nygren committed
824
        """
Erik Nygren's avatar
Erik Nygren committed
825
        Check if two cities overlap. That is we check if two squares with certain edge length and position overlap
Erik Nygren's avatar
Erik Nygren committed
826
827
        Parameters
        ----------
Erik Nygren's avatar
Erik Nygren committed
828
829
830
831
832
833
834
        center_1: (int, int)
            Center of first city
        center_2: (int, int)
            Center of second city

        radius: int
            Radius of each city. Cities are squares with edge length 2 * city_radius + 1
Erik Nygren's avatar
Erik Nygren committed
835
836
837

        Returns
        -------
Erik Nygren's avatar
Erik Nygren committed
838
        Returns True if the cities overlap and False otherwise
Erik Nygren's avatar
Erik Nygren committed
839
        """
840
        return np.abs(center_1[0] - center_2[0]) < radius and np.abs(center_1[1] - center_2[1]) < radius
841