From 95ba74c8c9158bf526f8c83d97825596ce9cd3da Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Tue, 24 Sep 2019 18:25:09 -0400
Subject: [PATCH] first commit of improved level generator. Now cities are not
 connected from the center but rather from the boarders. NESW

---
 .../Simple_Realistic_Railway_Generator.py     | 612 ++++++++++++++++++
 examples/flatland_2_0_example.py              |  21 +-
 .../simple_example_city_railway_generator.py  |   2 +-
 flatland/core/grid/grid_utils.py              |   2 +-
 flatland/envs/rail_generators.py              | 135 ++--
 5 files changed, 710 insertions(+), 62 deletions(-)
 create mode 100644 examples/Simple_Realistic_Railway_Generator.py

diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py
new file mode 100644
index 00000000..41506ba9
--- /dev/null
+++ b/examples/Simple_Realistic_Railway_Generator.py
@@ -0,0 +1,612 @@
+import copy
+import os
+import time
+import warnings
+
+import numpy as np
+
+from flatland.core.grid.grid4_utils import mirror
+from flatland.core.grid.rail_env_grid import RailEnvTransitions
+from flatland.core.transition_map import GridTransitionMap
+from flatland.envs.grid4_generators_utils import connect_from_nodes, connect_nodes, connect_rail
+from flatland.envs.observations import GlobalObsForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.utils.rendertools import RenderTool
+
+
+class Vec2dOperations:
+    def subtract_pos(nodeA, nodeB):
+        """
+        vector operation : nodeA - nodeB
+
+        :param nodeA: tuple with coordinate (x,y) or 2d vector
+        :param nodeB: tuple with coordinate (x,y) or 2d vector
+        :return:
+            -------
+        tuple with coordinate (x,y) or 2d vector
+        """
+        return (nodeA[0] - nodeB[0], nodeA[1] - nodeB[1])
+
+    def add_pos(nodeA, nodeB):
+        """
+        vector operation : nodeA + nodeB
+
+        :param nodeA: tuple with coordinate (x,y) or 2d vector
+        :param nodeB: tuple with coordinate (x,y) or 2d vector
+        :return:
+            -------
+        tuple with coordinate (x,y) or 2d vector
+        """
+        return (nodeA[0] + nodeB[0], nodeA[1] + nodeB[1])
+
+    def make_orthogonal_pos(node):
+        """
+        vector operation : rotates the 2D vector +90°
+
+        :param node: tuple with coordinate (x,y) or 2d vector
+        :return:
+            -------
+        tuple with coordinate (x,y) or 2d vector
+        """
+        return (node[1], -node[0])
+
+    def get_norm_pos(node):
+        """
+        calculates the euclidean norm of the 2d vector
+
+        :param node: tuple with coordinate (x,y) or 2d vector
+        :return:
+            -------
+        tuple with coordinate (x,y) or 2d vector
+        """
+        return np.sqrt(node[0] * node[0] + node[1] * node[1])
+
+    def normalize_pos(node):
+        """
+        normalize the 2d vector = v/|v|
+
+        :param node: tuple with coordinate (x,y) or 2d vector
+        :return:
+            -------
+        tuple with coordinate (x,y) or 2d vector
+        """
+        n = Vec2dOperations.get_norm_pos(node)
+        if n > 0.0:
+            n = 1 / n
+        return Vec2dOperations.scale_pos(node, n)
+
+    def scale_pos(node, scalar):
+        """
+         scales the 2d vector = node * scale
+
+         :param node: tuple with coordinate (x,y) or 2d vector
+         :param scale: scalar to scale
+         :return:
+             -------
+         tuple with coordinate (x,y) or 2d vector
+         """
+        return (node[0] * scalar, node[1] * scalar)
+
+    def round_pos(node):
+        """
+         rounds the x and y coordinate and convert them to an integer values
+
+         :param node: tuple with coordinate (x,y) or 2d vector
+         :return:
+             -------
+         tuple with coordinate (x,y) or 2d vector
+         """
+        return (int(np.round(node[0])), int(np.round(node[1])))
+
+    def ceil_pos(node):
+        """
+         ceiling the x and y coordinate and convert them to an integer values
+
+         :param node: tuple with coordinate (x,y) or 2d vector
+         :return:
+             -------
+         tuple with coordinate (x,y) or 2d vector
+         """
+        return (int(np.ceil(node[0])), int(np.ceil(node[1])))
+
+    def bound_pos(node, min_value, max_value):
+        """
+         force the values x and y to be between min_value and max_value
+
+         :param node: tuple with coordinate (x,y) or 2d vector
+         :param min_value: scalar value
+         :param max_value: scalar value
+         :return:
+             -------
+         tuple with coordinate (x,y) or 2d vector
+         """
+        return (max(min_value, min(max_value, node[0])), max(min_value, min(max_value, node[1])))
+
+    def rotate_pos(node, rot_in_degree):
+        """
+         rotate the 2d vector with given angle in degree
+
+         :param node: tuple with coordinate (x,y) or 2d vector
+         :param rot_in_degree:  angle in degree
+         :return:
+             -------
+         tuple with coordinate (x,y) or 2d vector
+         """
+        alpha = rot_in_degree / 180.0 * np.pi
+        x0 = node[0]
+        y0 = node[1]
+        x1 = x0 * np.cos(alpha) - y0 * np.sin(alpha)
+        y1 = x0 * np.sin(alpha) + y0 * np.cos(alpha)
+        return (x1, y1)
+
+
+class Grid_Map_Op:
+    def min_max_cut(min_v, max_v, v):
+        return max(min_v, min(max_v, v))
+
+    def add_rail(width, height, grid_map, pt_from, pt_via, pt_to, bAddRemove=True):
+        gRCTrans = np.array([[-1, 0], [0, 1], [1, 0], [0, -1]])  # NESW in RC
+
+        lrcStroke = [[Grid_Map_Op.min_max_cut(0, height - 1, pt_from[0]),
+                      Grid_Map_Op.min_max_cut(0, width - 1, pt_from[1])],
+                     [Grid_Map_Op.min_max_cut(0, height - 1, pt_via[0]),
+                      Grid_Map_Op.min_max_cut(0, width - 1, pt_via[1])],
+                     [Grid_Map_Op.min_max_cut(0, height - 1, pt_to[0]),
+                      Grid_Map_Op.min_max_cut(0, width - 1, pt_to[1])]]
+
+        rc3Cells = np.array(lrcStroke[:3])  # the 3 cells
+        rcMiddle = rc3Cells[1]  # the middle cell which we will update
+        bDeadend = np.all(lrcStroke[0] == lrcStroke[2])  # deadend means cell 0 == cell 2
+
+        # get the 2 row, col deltas between the 3 cells, eg [[-1,0],[0,1]] = North, East
+        rc2Trans = np.diff(rc3Cells, axis=0)
+
+        # get the direction index for the 2 transitions
+        liTrans = []
+        for rcTrans in rc2Trans:
+            # gRCTrans - rcTrans gives an array of vector differences between our rcTrans
+            # and the 4 directions stored in gRCTrans.
+            # Where the vector difference is zero, we have a match...
+            # np.all detects where the whole row,col vector is zero.
+            # argwhere gives the index of the zero vector, ie the direction index
+            iTrans = np.argwhere(np.all(gRCTrans - rcTrans == 0, axis=1))
+            if len(iTrans) > 0:
+                iTrans = iTrans[0][0]
+                liTrans.append(iTrans)
+
+        # check that we have two transitions
+        if len(liTrans) == 2:
+            # Set the transition
+            # Set the transition
+            # If this transition spans 3 cells, it is not a deadend, so remove any deadends.
+            # The user will need to resolve any conflicts.
+            grid_map.set_transition((*rcMiddle, liTrans[0]),
+                                    liTrans[1],
+                                    bAddRemove,
+                                    remove_deadends=not bDeadend)
+
+            # Also set the reverse transition
+            # use the reversed outbound transition for inbound
+            # and the reversed inbound transition for outbound
+            grid_map.set_transition((*rcMiddle, mirror(liTrans[1])),
+                                    mirror(liTrans[0]), bAddRemove, remove_deadends=not bDeadend)
+
+
+def realistic_rail_generator(num_cities=5,
+                             city_size=10,
+                             allowed_rotation_angles=[0, 90],
+                             max_number_of_station_tracks=4,
+                             nbr_of_switches_per_station_track=2,
+                             connect_max_nbr_of_shortes_city=4,
+                             do_random_connect_stations=False,
+                             seed=0,
+                             print_out_info=True) -> RailGenerator:
+    """
+    This is a level generator which generates a realistic rail configurations
+
+    :param num_cities: Number of city node
+    :param city_size: Length of city measure in cells
+    :param allowed_rotation_angles: Rotate the city (around center)
+    :param max_number_of_station_tracks: max number of tracks per station
+    :param nbr_of_switches_per_station_track: number of switches per track (max)
+    :param connect_max_nbr_of_shortes_city: max number of connecting track between stations
+    :param do_random_connect_stations : if false connect the stations along the grid (top,left -> down,right), else rand
+    :param seed: Random Seed
+    :print_out_info : print debug info
+    :return:
+        -------
+    numpy.ndarray of type numpy.uint16
+        The matrix with the correct 16-bit bitmaps for each cell.
+    """
+
+    def do_generate_city_locations(width, height, intern_city_size, intern_max_number_of_station_tracks):
+
+        X = int(np.floor(max(1, height - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size))
+        Y = int(np.floor(max(1, width - 2 * intern_max_number_of_station_tracks - 1) / intern_city_size))
+
+        max_num_cities = min(num_cities, X * Y)
+
+        cities_at = np.random.choice(X * Y, max_num_cities, False)
+        cities_at = np.sort(cities_at)
+        if print_out_info:
+            print("max nbr of cities with given configuration is:", max_num_cities)
+
+        x = np.floor(cities_at / Y)
+        y = cities_at - x * Y
+        xs = (x * intern_city_size + intern_max_number_of_station_tracks) + intern_city_size / 2
+        ys = (y * intern_city_size + intern_max_number_of_station_tracks) + intern_city_size / 2
+
+        generate_city_locations = [[(int(xs[i]), int(ys[i])), (int(xs[i]), int(ys[i]))] for i in range(len(xs))]
+        return generate_city_locations, max_num_cities
+
+    def do_orient_cities(generate_city_locations, intern_city_size, allowed_rotation_angles):
+        for i in range(len(generate_city_locations)):
+            # station main orientation  (horizontal or vertical
+            rot_angle = np.random.choice(allowed_rotation_angles)
+            add_pos_val = Vec2dOperations.scale_pos(Vec2dOperations.rotate_pos((1, 0), rot_angle),
+                                                    (max(1, (intern_city_size - 3) / 2)))
+            generate_city_locations[i][0] = Vec2dOperations.add_pos(generate_city_locations[i][1], add_pos_val)
+            add_pos_val = Vec2dOperations.scale_pos(Vec2dOperations.rotate_pos((1, 0), 180 + rot_angle),
+                                                    (max(1, (intern_city_size - 3) / 2)))
+            generate_city_locations[i][1] = Vec2dOperations.add_pos(generate_city_locations[i][1], add_pos_val)
+        return generate_city_locations
+
+    def create_stations_from_city_locations(rail_trans, rail_array, generate_city_locations,
+                                            intern_max_number_of_station_tracks):
+        nodes_added = []
+        start_nodes_added = [[] for i in range(len(generate_city_locations))]
+        end_nodes_added = [[] for i in range(len(generate_city_locations))]
+        station_slots = [[] for i in range(len(generate_city_locations))]
+        station_tracks = [[[] for j in range(intern_max_number_of_station_tracks)] for i in range(len(
+            generate_city_locations))]
+
+        station_slots_cnt = 0
+
+        for city_loop in range(len(generate_city_locations)):
+            # Connect train station to the correct node
+            number_of_connecting_tracks = np.random.choice(max(0, intern_max_number_of_station_tracks)) + 1
+            for ct in range(number_of_connecting_tracks):
+                org_start_node = generate_city_locations[city_loop][0]
+                org_end_node = generate_city_locations[city_loop][1]
+
+                ortho_trans = Vec2dOperations.make_orthogonal_pos(
+                    Vec2dOperations.normalize_pos(Vec2dOperations.subtract_pos(org_start_node, org_end_node)))
+                s = (ct - number_of_connecting_tracks / 2.0)
+                start_node = Vec2dOperations.ceil_pos(
+                    Vec2dOperations.add_pos(org_start_node, Vec2dOperations.scale_pos(ortho_trans, s)))
+                end_node = Vec2dOperations.ceil_pos(
+                    Vec2dOperations.add_pos(org_end_node, Vec2dOperations.scale_pos(ortho_trans, s)))
+
+                connection = connect_from_nodes(rail_trans, rail_array, start_node, end_node)
+                if len(connection) > 0:
+                    nodes_added.append(start_node)
+                    nodes_added.append(end_node)
+
+                    start_nodes_added[city_loop].append(start_node)
+                    end_nodes_added[city_loop].append(end_node)
+
+                    # place in the center of path a station slot
+                    station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))])
+                    station_slots_cnt += 1
+
+                    station_tracks[city_loop][ct] = connection
+
+        if print_out_info:
+            print("max nbr of station slots with given configuration is:", station_slots_cnt)
+
+        return nodes_added, station_slots, start_nodes_added, end_nodes_added, station_tracks
+
+    def create_switches_at_stations(width, height, grid_map, station_tracks, nodes_added,
+                                    intern_nbr_of_switches_per_station_track):
+        # generate switch based on switch slot list and connect them
+        for city_loop in range(len(station_tracks)):
+            datas = station_tracks[city_loop]
+            for data_loop in range(len(datas) - 1):
+                data = datas[data_loop]
+                data1 = datas[data_loop + 1]
+                if len(data) > 2 and len(data1) > 2:
+                    for i in np.random.choice(min(len(data1), len(data)) - 2,
+                                              intern_nbr_of_switches_per_station_track):
+                        Grid_Map_Op.add_rail(width, height, grid_map, data[i + 1], data1[i + 1], data1[i + 2], True)
+                        nodes_added.append(data[i + 1])
+                        nodes_added.append(data1[i + 1])
+                        nodes_added.append(data1[i + 2])
+
+        return nodes_added
+
+    def calc_nbr_of_graphs(graph):
+        for i in range(len(graph)):
+            for j in range(len(graph)):
+                a = graph[i]
+                b = graph[j]
+                connected = False
+                if a[0] == b[0] or a[1] == b[0]:
+                    connected = True
+                if a[0] == b[1] or a[1] == b[1]:
+                    connected = True
+
+                if connected:
+                    a = [graph[i][0], graph[i][1], graph[i][2]]
+                    b = [graph[j][0], graph[j][1], graph[j][2]]
+                    graph[i] = (graph[i][0], graph[i][1], min(np.min(a), np.min(b)))
+                    graph[j] = (graph[j][0], graph[j][1], min(np.min(a), np.min(b)))
+                else:
+                    a = [graph[i][0], graph[i][1], graph[i][2]]
+                    graph[i] = (graph[i][0], graph[i][1], np.min(a))
+                    b = [graph[j][0], graph[j][1], graph[j][2]]
+                    graph[j] = (graph[j][0], graph[j][1], np.min(b))
+
+        graph_ids = []
+        for i in range(len(graph)):
+            graph_ids.append(graph[i][2])
+        if print_out_info:
+            print("************* NBR of graphs:", len(np.unique(graph_ids)))
+        return graph, np.unique(graph_ids).astype(int)
+
+    def connect_sub_graphs(rail_trans, rail_array, org_s_nodes, org_e_nodes, city_edges, nodes_added):
+        _, graphids = calc_nbr_of_graphs(city_edges)
+        if len(graphids) > 0:
+            for i in range(len(graphids) - 1):
+                connection = []
+                cnt = 0
+                while len(connection) == 0 and cnt < 100:
+                    s_nodes = copy.deepcopy(org_s_nodes)
+                    e_nodes = copy.deepcopy(org_e_nodes)
+                    start_nodes = s_nodes[graphids[i]]
+                    end_nodes = e_nodes[graphids[i + 1]]
+                    start_node = start_nodes[np.random.choice(len(start_nodes))]
+                    end_node = end_nodes[np.random.choice(len(end_nodes))]
+                    rail_array[start_node] = 0
+                    rail_array[end_node] = 0
+                    connection = connect_rail(rail_trans, rail_array, start_node, end_node)
+                    if len(connection) > 0:
+                        nodes_added.append(start_node)
+                        nodes_added.append(end_node)
+                    cnt += 1
+
+    def connect_stations(rail_trans, rail_array, org_s_nodes, org_e_nodes, nodes_added,
+                         inter_connect_max_nbr_of_shortes_city):
+
+        city_edges = []
+
+        s_nodes = copy.deepcopy(org_s_nodes)
+        e_nodes = copy.deepcopy(org_e_nodes)
+
+        for k in range(inter_connect_max_nbr_of_shortes_city):
+            for city_loop in range(len(s_nodes)):
+                sns = s_nodes[city_loop]
+                for start_node in sns:
+                    min_distance = np.inf
+                    end_node = None
+                    for city_loop_find_shortest in range(len(e_nodes)):
+                        if city_loop_find_shortest == city_loop:
+                            continue
+                        ens = e_nodes[city_loop_find_shortest]
+                        for en in ens:
+                            d = Vec2dOperations.get_norm_pos(Vec2dOperations.subtract_pos(en, start_node))
+                            if d < min_distance:
+                                min_distance = d
+                                end_node = en
+                                cl = city_loop_find_shortest
+
+                    if end_node is not None:
+                        tmp_trans_sn = rail_array[start_node]
+                        tmp_trans_en = rail_array[end_node]
+                        rail_array[start_node] = 0
+                        rail_array[end_node] = 0
+                        connection = connect_rail(rail_trans, rail_array, start_node, end_node)
+                        if len(connection) > 0:
+                            s_nodes[city_loop].remove(start_node)
+                            e_nodes[cl].remove(end_node)
+                            a = (city_loop, cl, np.inf)
+                            if city_loop > cl:
+                                a = (cl, city_loop, np.inf)
+                            if not (a in city_edges):
+                                city_edges.append(a)
+                            nodes_added.append(start_node)
+                            nodes_added.append(end_node)
+                        else:
+                            rail_array[start_node] = tmp_trans_sn
+                            rail_array[end_node] = tmp_trans_en
+
+        connect_sub_graphs(rail_trans, rail_array, org_s_nodes, org_e_nodes, city_edges, nodes_added)
+
+    def connect_random_stations(rail_trans, rail_array, start_nodes_added, end_nodes_added, nodes_added,
+                                inter_connect_max_nbr_of_shortes_city):
+        x = np.arange(len(start_nodes_added))
+        random_city_idx = np.random.choice(x, len(x), False)
+
+        # cyclic connection
+        random_city_idx = np.append(random_city_idx, random_city_idx[0])
+
+        for city_loop in range(len(random_city_idx) - 1):
+            idx_a = random_city_idx[city_loop + 1]
+            idx_b = random_city_idx[city_loop]
+            s_nodes = start_nodes_added[idx_a]
+            e_nodes = end_nodes_added[idx_b]
+
+            max_input_output = max(len(s_nodes), len(e_nodes))
+            max_input_output = min(inter_connect_max_nbr_of_shortes_city, max_input_output)
+
+            if do_random_connect_stations:
+                idx_s_nodes = np.random.choice(np.arange(len(s_nodes)), len(s_nodes), False)
+                idx_e_nodes = np.random.choice(np.arange(len(e_nodes)), len(e_nodes), False)
+            else:
+                idx_s_nodes = np.arange(len(s_nodes))
+                idx_e_nodes = np.arange(len(e_nodes))
+
+            if len(idx_s_nodes) < max_input_output:
+                idx_s_nodes = np.append(idx_s_nodes, np.random.choice(np.arange(len(s_nodes)), max_input_output - len(
+                    idx_s_nodes)))
+            if len(idx_e_nodes) < max_input_output:
+                idx_e_nodes = np.append(idx_e_nodes,
+                                        np.random.choice(np.arange(len(idx_e_nodes)), max_input_output - len(
+                                            idx_e_nodes)))
+
+            if len(idx_s_nodes) > inter_connect_max_nbr_of_shortes_city:
+                idx_s_nodes = np.random.choice(idx_s_nodes, inter_connect_max_nbr_of_shortes_city, False)
+            if len(idx_e_nodes) > inter_connect_max_nbr_of_shortes_city:
+                idx_e_nodes = np.random.choice(idx_e_nodes, inter_connect_max_nbr_of_shortes_city, False)
+
+            for i in range(max_input_output):
+                start_node = s_nodes[idx_s_nodes[i]]
+                end_node = e_nodes[idx_e_nodes[i]]
+                new_trans = rail_array[start_node] = 0
+                new_trans = rail_array[end_node] = 0
+                connection = connect_nodes(rail_trans, rail_array, start_node, end_node)
+                if len(connection) > 0:
+                    nodes_added.append(start_node)
+                    nodes_added.append(end_node)
+
+    def generator(width, height, num_agents, num_resets=0) -> RailGeneratorProduct:
+        rail_trans = RailEnvTransitions()
+        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
+        rail_array = grid_map.grid
+        rail_array.fill(0)
+        np.random.seed(seed + num_resets)
+
+        intern_city_size = city_size
+        if city_size < 3:
+            warnings.warn("min city_size requried to be > 3!")
+            intern_city_size = 3
+        if print_out_info:
+            print("intern_city_size:", intern_city_size)
+
+        intern_max_number_of_station_tracks = max_number_of_station_tracks
+        if max_number_of_station_tracks < 1:
+            warnings.warn("min max_number_of_station_tracks requried to be > 1!")
+            intern_max_number_of_station_tracks = 1
+        if print_out_info:
+            print("intern_max_number_of_station_tracks:", intern_max_number_of_station_tracks)
+
+        intern_nbr_of_switches_per_station_track = nbr_of_switches_per_station_track
+        if nbr_of_switches_per_station_track < 1:
+            warnings.warn("min intern_nbr_of_switches_per_station_track requried to be > 2!")
+            intern_nbr_of_switches_per_station_track = 2
+        if print_out_info:
+            print("intern_nbr_of_switches_per_station_track:", intern_nbr_of_switches_per_station_track)
+
+        inter_connect_max_nbr_of_shortes_city = connect_max_nbr_of_shortes_city
+        if connect_max_nbr_of_shortes_city < 1:
+            warnings.warn("min inter_connect_max_nbr_of_shortes_city requried to be > 1!")
+            inter_connect_max_nbr_of_shortes_city = 1
+        if print_out_info:
+            print("inter_connect_max_nbr_of_shortes_city:", inter_connect_max_nbr_of_shortes_city)
+
+        agent_start_targets_nodes = []
+
+        # ----------------------------------------------------------------------------------
+        # generate city locations
+        generate_city_locations, max_num_cities = do_generate_city_locations(width, height, intern_city_size,
+                                                                             intern_max_number_of_station_tracks)
+
+        # ----------------------------------------------------------------------------------
+        # apply orientation to cities (horizontal, vertical)
+        generate_city_locations = do_orient_cities(generate_city_locations, intern_city_size, allowed_rotation_angles)
+
+        # ----------------------------------------------------------------------------------
+        # generate city topology
+        nodes_added, train_stations, s_nodes, e_nodes, station_tracks = \
+            create_stations_from_city_locations(rail_trans, rail_array,
+                                                generate_city_locations,
+                                                intern_max_number_of_station_tracks)
+        # build switches
+        create_switches_at_stations(width, height, grid_map, station_tracks, nodes_added,
+                                    intern_nbr_of_switches_per_station_track)
+
+        # ----------------------------------------------------------------------------------
+        # connect stations
+        if True:
+            if do_random_connect_stations:
+                connect_random_stations(rail_trans, rail_array, s_nodes, e_nodes, nodes_added,
+                                        inter_connect_max_nbr_of_shortes_city)
+            else:
+                connect_stations(rail_trans, rail_array, s_nodes, e_nodes, nodes_added,
+                                 inter_connect_max_nbr_of_shortes_city)
+
+        # ----------------------------------------------------------------------------------
+        # fix all transition at starting / ending points (mostly add a dead end, if missing)
+        for i in range(len(nodes_added)):
+            grid_map.fix_transitions(nodes_added[i])
+
+        # ----------------------------------------------------------------------------------
+        # Slot availability in node
+        node_available_start = []
+        node_available_target = []
+        for node_idx in range(max_num_cities):
+            node_available_start.append(len(train_stations[node_idx]))
+            node_available_target.append(len(train_stations[node_idx]))
+
+        # Assign agents to slots
+        for agent_idx in range(num_agents):
+            avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
+            avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
+            if len(avail_target_nodes) == 0:
+                num_agents -= 1
+                continue
+            start_node = np.random.choice(avail_start_nodes)
+            target_node = np.random.choice(avail_target_nodes)
+            tries = 0
+            found_agent_pair = True
+            while target_node == start_node:
+                target_node = np.random.choice(avail_target_nodes)
+                tries += 1
+                # Test again with new start node if no pair is found (This code needs to be improved)
+                if (tries + 1) % 10 == 0:
+                    start_node = np.random.choice(avail_start_nodes)
+                if tries > 100:
+                    warnings.warn("Could not set trainstations, removing agent!")
+                    found_agent_pair = False
+                    break
+            if found_agent_pair:
+                node_available_start[start_node] -= 1
+                node_available_target[target_node] -= 1
+                agent_start_targets_nodes.append((start_node, target_node))
+            else:
+                num_agents -= 1
+
+        return grid_map, {'agents_hints': {
+            'num_agents': num_agents,
+            'agent_start_targets_nodes': agent_start_targets_nodes,
+            'train_stations': train_stations
+        }}
+
+    return generator
+
+
+for itrials in range(100):
+    print(itrials, "generate new city")
+    np.random.seed(int(time.time()))
+    env = RailEnv(width=40 + np.random.choice(100),
+                  height=40 + np.random.choice(100),
+                  rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10),
+                                                          city_size=10 + np.random.choice(10),
+                                                          allowed_rotation_angles=[-90, -45, 0, 45, 90],
+                                                          max_number_of_station_tracks=np.random.choice(4) + 4,
+                                                          nbr_of_switches_per_station_track=np.random.choice(4) + 2,
+                                                          connect_max_nbr_of_shortes_city=np.random.choice(4) + 2,
+                                                          do_random_connect_stations=np.random.choice(1) == 0,
+                                                          # Number of cities in map
+                                                          seed=int(time.time()),  # Random seed
+                                                          print_out_info=False
+                                                          ),
+                  schedule_generator=sparse_schedule_generator(),
+                  number_of_agents=1000,
+                  obs_builder_object=GlobalObsForRailEnv())
+
+    # reset to initialize agents_static
+    env_renderer = RenderTool(env, gl="PILSVG", screen_width=1400, screen_height=1000)
+    cnt = 0
+    while cnt < 10:
+        env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
+        cnt += 1
+
+    env_renderer.gl.save_image(
+        os.path.join(
+            "./../render_output/",
+            "flatland_frame_{:04d}_{:04d}.png".format(itrials, 0)
+        ))
+
+    env_renderer.close_window()
diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index da8cacef..a4e24f9d 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -1,10 +1,12 @@
+import time
+
 import numpy as np
 
-from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
-from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.envs.schedule_generators import random_schedule_generator
 from flatland.utils.rendertools import RenderTool
 
 np.random.seed(1)
@@ -30,20 +32,20 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 
 env = RailEnv(width=50,
               height=50,
-              rail_generator=sparse_rail_generator(num_cities=25,  # Number of cities in map (where train stations are)
-                                                   num_intersections=10,  # Number of intersections (no start / target)
+              rail_generator=sparse_rail_generator(num_cities=9,  # Number of cities in map (where train stations are)
+                                                   num_intersections=0,  # Number of intersections (no start / target)
                                                    num_trainstations=50,  # Number of possible start/targets on map
                                                    min_node_dist=3,  # Minimal distance of nodes
-                                                   node_radius=4,  # Proximity of stations to city center
-                                                   num_neighb=4,  # Number of connections to other cities/intersections
+                                                   node_radius=5,  # Proximity of stations to city center
+                                                   num_neighb=3,  # Number of connections to other cities/intersections
                                                    seed=15,  # Random seed
                                                    grid_mode=True,
                                                    enhance_intersection=False
                                                    ),
-              schedule_generator=sparse_schedule_generator(speed_ration_map),
-              number_of_agents=20,
+              schedule_generator=random_schedule_generator(),
+              number_of_agents=0,
               stochastic_data=stochastic_data,  # Malfunction data generator
-              obs_builder_object=TreeObservation)
+              obs_builder_object=GlobalObsForRailEnv())
 
 env_renderer = RenderTool(env, gl="PILSVG", )
 
@@ -111,6 +113,7 @@ for step in range(500):
     # reward and whether their are done
     next_obs, all_rewards, done, _ = env.step(action_dict)
     env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
+    time.sleep(50)
     frame_step += 1
     # Update replay buffer and train agent
     for a in range(env.get_num_agents()):
diff --git a/examples/simple_example_city_railway_generator.py b/examples/simple_example_city_railway_generator.py
index ce96932c..f44ad84f 100644
--- a/examples/simple_example_city_railway_generator.py
+++ b/examples/simple_example_city_railway_generator.py
@@ -9,7 +9,7 @@ from flatland.envs.rail_generators_city_generator import city_generator
 from flatland.envs.schedule_generators import city_schedule_generator
 from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 
-os.mkdir("./../render_output/")
+# os.mkdir("./../render_output/")
 
 for itrials in np.arange(1, 15, 1):
     print(itrials, "generate new city")
diff --git a/flatland/core/grid/grid_utils.py b/flatland/core/grid/grid_utils.py
index fe4a381f..d39fc8a7 100644
--- a/flatland/core/grid/grid_utils.py
+++ b/flatland/core/grid/grid_utils.py
@@ -295,4 +295,4 @@ def coordinate_to_position(depth, coords):
 
 
 def distance_on_rail(pos1, pos2):
-    return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1])
+    return np.sqrt(np.power(pos1[0] - pos2[0], 2) + np.power(pos1[1] - pos2[1], 2))
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 60c606f7..a62f0184 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -570,8 +570,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             nodes_ratio = height / width
             nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
             nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row))
-            x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int)
-            y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int)
+            x_positions = np.linspace(node_radius, height - node_radius - 1, nodes_per_row, dtype=int)
+            y_positions = np.linspace(node_radius, width - node_radius - 1, nodes_per_col, dtype=int)
             city_idx = np.random.choice(np.arange(nb_nodes), num_cities, False)
 
             node_positions = _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions,
@@ -598,11 +598,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
         available_cities = np.arange(_num_cities)
         available_intersections = np.arange(_num_cities, nb_nodes)
 
+        # Set up connection points
+        connection_points = _generate_node_connection_points(node_positions, node_radius, max_nr_connection_points=4)
+        print(connection_points)
         # Start at some node
         current_node = np.random.randint(len(available_nodes_full))
         node_stack = [current_node]
-        allowed_connections = num_neighb
+        allowed_connections = len(connection_points[current_node])
         first_node = True
+        i = 0
         while len(node_stack) > 0:
             current_node = node_stack[0]
             delete_idx = np.where(available_nodes_full == current_node)
@@ -645,37 +649,60 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             for neighb in connected_neighb_idx:
                 if neighb not in node_stack:
                     node_stack.append(neighb)
-                connect_nodes(rail_trans, grid_map, node_positions[current_node], node_positions[neighb])
+
+                dist_from_center = distance_on_rail(node_positions[current_node], node_positions[neighb])
+                for tmp_out_connection_point in connection_points[current_node]:
+                    tmp_dist_to_node = distance_on_rail(tmp_out_connection_point, node_positions[neighb])
+                    # Check if this connection node is on the city side facing the neighbour
+                    print("Current node", current_node, "Neigh", neighb, "Distance", tmp_dist_to_node, dist_from_center)
+                    if tmp_dist_to_node < dist_from_center - 1:
+                        min_connection_dist = np.inf
+
+                        # Find closes connection point
+                        for tmp_in_connection_point in connection_points[neighb]:
+                            tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point)
+                            if tmp_dist < min_connection_dist:
+                                min_connection_dist = tmp_dist
+                                neighb_connection_point = tmp_in_connection_point
+                            center_distance = distance_on_rail(node_positions[current_node], tmp_in_connection_point)
+                        if distance_on_rail(tmp_out_connection_point, neighb_connection_point) < center_distance:
+                            i += 1
+                            connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point)
+
             node_stack.pop(0)
 
         # Place train stations close to the node
         # We currently place them uniformly distributed among all cities
         built_num_trainstation = 0
         train_stations = [[] for i in range(_num_cities)]
-
         if _num_cities > 1:
 
             for station in range(num_trainstations):
                 spot_found = True
+                reduced_node_radius = node_radius - 1
                 trainstation_node = int(station / num_trainstations * _num_cities)
 
-                station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
-                                    0,
-                                    height - 1)
-                station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius),
-                                    0,
-                                    width - 1)
+                station_x = np.clip(
+                    node_positions[trainstation_node][0] + np.random.randint(-reduced_node_radius, reduced_node_radius),
+                    0,
+                    height - 1)
+                station_y = np.clip(
+                    node_positions[trainstation_node][1] + np.random.randint(-reduced_node_radius, reduced_node_radius),
+                    0,
+                    width - 1)
                 tries = 0
                 while (station_x, station_y) in train_stations[trainstation_node] \
                     or (station_x, station_y) == node_positions[trainstation_node] \
                     or rail_array[(station_x, station_y)] != 0:  # noqa: E125
 
                     station_x = np.clip(
-                        node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
+                        node_positions[trainstation_node][0] + np.random.randint(-reduced_node_radius,
+                                                                                 reduced_node_radius),
                         0,
                         height - 1)
                     station_y = np.clip(
-                        node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius),
+                        node_positions[trainstation_node][1] + np.random.randint(-reduced_node_radius,
+                                                                                 reduced_node_radius),
                         0,
                         width - 1)
                     tries += 1
@@ -687,11 +714,17 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                 if spot_found:
                     train_stations[trainstation_node].append((station_x, station_y))
 
-                # Connect train station to the correct node
-                connection = connect_from_nodes(rail_trans, grid_map, node_positions[trainstation_node],
+                # Connect train station to random nodes
+
+                rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2, replace=False)
+                connection_1 = connect_from_nodes(rail_trans, grid_map,
+                                                  connection_points[trainstation_node][rand_corner_nodes[0]],
                                                 (station_x, station_y))
+                connection_2 = connect_from_nodes(rail_trans, grid_map,
+                                                  connection_points[trainstation_node][rand_corner_nodes[1]],
+                                                  (station_x, station_y))
                 # Check if connection was made
-                if len(connection) == 0:
+                if len(connection_1) == 0 and len(connection_2) == 0:
                     if len(train_stations[trainstation_node]) > 0:
                         train_stations[trainstation_node].pop(-1)
                 else:
@@ -702,39 +735,10 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             num_agents = built_num_trainstation
             warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents")
 
-        # Place passing lanes at intersections
-        # We currently place them uniformly distirbuted among all cities
-        if enhance_intersection:
-
-            for intersection in range(_num_intersections):
-                intersect_x_1 = np.clip(intersection_positions[intersection][0] + np.random.randint(1, 3),
-                                        1,
-                                        height - 2)
-                intersect_y_1 = np.clip(intersection_positions[intersection][1] + np.random.randint(-3, 3),
-                                        2,
-                                        width - 2)
-                intersect_x_2 = np.clip(
-                    intersection_positions[intersection][0] + np.random.randint(-3, -1),
-                    1,
-                    height - 2)
-                intersect_y_2 = np.clip(
-                    intersection_positions[intersection][1] + np.random.randint(-3, 3),
-                    1,
-                    width - 2)
-
-                # Connect train station to the correct node
-                connect_nodes(rail_trans, grid_map, (intersect_x_1, intersect_y_1),
-                              (intersect_x_2, intersect_y_2))
-                connect_nodes(rail_trans, grid_map, intersection_positions[intersection],
-                              (intersect_x_1, intersect_y_1))
-                connect_nodes(rail_trans, grid_map, intersection_positions[intersection],
-                              (intersect_x_2, intersect_y_2))
-                grid_map.fix_transitions((intersect_x_1, intersect_y_1))
-                grid_map.fix_transitions((intersect_x_2, intersect_y_2))
-
         # Fix all nodes with illegal transition maps
-        for current_node in node_positions:
-            grid_map.fix_transitions(current_node)
+        flat_list = [item for sublist in connection_points for item in sublist]
+        for cell_to_fix in flat_list:
+            grid_map.fix_transitions(cell_to_fix)
 
         # Generate start and target node directory for all agents.
         # Assure that start and target are not in the same node
@@ -773,9 +777,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                 num_agents -= 1
 
         return grid_map, {'agents_hints': {
-            'num_agents': num_agents,
-            'agent_start_targets_nodes': agent_start_targets_nodes,
-            'train_stations': train_stations
+            'num_agents': num_agents
         }}
 
     def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes,
@@ -820,15 +822,46 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
 
     def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes,
                                            nodes_per_row, x_positions, y_positions):
+
         for node_idx in range(nb_nodes):
 
             x_tmp = x_positions[node_idx % nodes_per_row]
             y_tmp = y_positions[node_idx // nodes_per_row]
             if node_idx in city_idx:
                 city_positions.append((x_tmp, y_tmp))
+
             else:
                 intersection_positions.append((x_tmp, y_tmp))
         node_positions = city_positions + intersection_positions
         return node_positions
 
+    def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2):
+        connection_points = []
+        for node_position in node_positions:
+            n_connection_points = max_nr_connection_points  # np.random.randint(1, max_nr_connection_points)
+            connection_per_direction = n_connection_points // 4
+            connection_point_vector = [connection_per_direction, connection_per_direction, connection_per_direction,
+                                       n_connection_points - 3 * connection_per_direction]
+            print(connection_point_vector)
+            connection_points_coordinates = []
+            for direction in range(4):
+                rnd_points = np.random.choice(np.arange(-node_size, node_size), size=connection_point_vector[direction],
+                                              replace=False)
+                for connection_idx in range(connection_point_vector[direction]):
+                    if direction == 0:
+                        connection_points_coordinates.append(
+                            (node_position[0] - node_size, node_position[1] + rnd_points[connection_idx]))
+                    if direction == 1:
+                        connection_points_coordinates.append(
+                            (node_position[0] + rnd_points[connection_idx], node_position[1] + node_size))
+                    if direction == 2:
+                        connection_points_coordinates.append(
+                            (node_position[0] + node_size, node_position[1] + rnd_points[connection_idx]))
+                    if direction == 3:
+                        connection_points_coordinates.append(
+                            (node_position[0] + rnd_points[connection_idx], node_position[1] - node_size))
+
+            connection_points.append(connection_points_coordinates)
+        return connection_points
+
     return generator
-- 
GitLab