From fe85e973970e43423f4ea3311b5f296bc83dfa7b Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Fri, 27 Sep 2019 15:06:43 -0400
Subject: [PATCH] added city boarder to as forbidden zone for inner city
 connection. This strictly seperates inner from outer connections. should help
 us to get nicer infrastructures.

---
 examples/flatland_2_0_example.py |  12 +--
 flatland/envs/rail_generators.py | 158 +++++++++++++++----------------
 2 files changed, 81 insertions(+), 89 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index cfe733a3..8b5a01cc 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -33,16 +33,16 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 env = RailEnv(width=50,
               height=50,
               rail_generator=sparse_rail_generator(num_cities=9,  # Number of cities in map (where train stations are)
-                                                   num_trainstations=50,  # Number of possible start/targets on map
-                                                   min_node_dist=5,  # Minimal distance of nodes
-                                                   node_radius=3,  # Proximity of stations to city center
+                                                   num_trainstations=45,  # Number of possible start/targets on map
+                                                   min_node_dist=10,  # Minimal distance of nodes
+                                                   node_radius=4,  # Proximity of stations to city center
                                                    seed=15,  # Random seed
-                                                   grid_mode=True,
+                                                   grid_mode=False,
                                                    max_connection_points_per_side=2,
-                                                   max_nr_connection_directions=4
+                                                   max_nr_connection_directions=2
                                                    ),
               schedule_generator=sparse_schedule_generator(),
-              number_of_agents=50,
+              number_of_agents=15,
               stochastic_data=stochastic_data,  # Malfunction data generator
               obs_builder_object=GlobalObsForRailEnv())
 
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 47669c9b..5626c65a 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -9,7 +9,7 @@ from flatland.core.grid.grid4_utils import get_direction, mirror
 from flatland.core.grid.grid_utils import distance_on_rail, direction_to_point
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_cities
+from flatland.envs.grid4_generators_utils import connect_rail, connect_cities
 
 RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
 RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
@@ -589,7 +589,11 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
         _connect_cities(node_positions, connection_points, connection_info, city_cells, rail_trans, grid_map)
 
         # Build inner cities
-        train_stations, built_num_trainstation = _build_cities(node_positions, connection_points, rail_trans, grid_map)
+        _build_inner_cities(node_positions, connection_points, rail_trans, grid_map)
+
+        # Populate cities
+        train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, city_cells,
+                                                                             num_trainstations, grid_map)
 
         # Adjust the number of agents if you could not build enough trainstations
         if num_agents > built_num_trainstation:
@@ -600,6 +604,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
         _fix_transitions(grid_map)
 
         # Generate start target paris
+        print(train_stations)
         agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, nb_nodes, train_stations)
 
         return grid_map, {'agents_hints': {
@@ -688,25 +693,25 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                 nr_of_connection_points = np.random.randint(1, max_nr_connection_points + 1)
 
                 connections_per_direction[idx] = nr_of_connection_points
-            connection_points_coordinates = []
+            connection_points_coordinates = [[] for i in range(4)]
 
             for direction in range(4):
                 connection_slots = np.arange(connections_per_direction[direction]) - int(
                         connections_per_direction[direction] / 2)
                 for connection_idx in range(connections_per_direction[direction]):
                     if direction == 0:
-                        connection_points_coordinates.append(
-                            (node_position[0] - node_size, node_position[1] + connection_slots[connection_idx]))
+                        tmp_coordinates = (
+                        node_position[0] - node_size, node_position[1] + connection_slots[connection_idx])
                     if direction == 1:
-                        connection_points_coordinates.append(
-                            (node_position[0] + connection_slots[connection_idx], node_position[1] + node_size))
+                        tmp_coordinates = (
+                        node_position[0] + connection_slots[connection_idx], node_position[1] + node_size)
                     if direction == 2:
-                        connection_points_coordinates.append(
-                            (node_position[0] + node_size, node_position[1] + connection_slots[connection_idx]))
+                        tmp_coordinates = (
+                        node_position[0] + node_size, node_position[1] + connection_slots[connection_idx])
                     if direction == 3:
-                        connection_points_coordinates.append(
-                            (node_position[0] + connection_slots[connection_idx], node_position[1] - node_size))
-
+                        tmp_coordinates = (
+                        node_position[0] + connection_slots[connection_idx], node_position[1] - node_size)
+                    connection_points_coordinates[direction].append(tmp_coordinates)
             connection_points.append(connection_points_coordinates)
             connection_info.append(connections_per_direction)
         return connection_points, connection_info
@@ -733,15 +738,13 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
 
                 if neighb_idx is not None:
                     connection_distances = []
-                    for tmp_out_connection_point in connection_points[current_node]:
-                        tmp_dist_to_node = distance_on_rail(tmp_out_connection_point, node_positions[neighb_idx])
-                        connection_distances.append(tmp_dist_to_node)
-                    possible_connection_points = argsort(connection_distances)
-                    for sort_idx in possible_connection_points[:connection_info[current_node][direction]]:
+                    for tmp_out_connection_point in connection_points[current_node][direction]:
                         # Find closest connection point
-                        tmp_out_connection_point = connection_points[current_node][sort_idx]
                         min_connection_dist = np.inf
-                        for tmp_in_connection_point in connection_points[neighb_idx]:
+                        all_neighb_connection_points = [item for sublist in connection_points[neighb_idx] for item in
+                                                        sublist]
+
+                        for tmp_in_connection_point in all_neighb_connection_points:
                             tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point)
                             if tmp_dist < min_connection_dist:
                                 min_connection_dist = tmp_dist
@@ -753,71 +756,51 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                 direction += 1
         return boarder_connections
 
+    def _build_inner_cities(node_positions, connection_points, rail_trans, grid_map):
+        """
+        Builds inner city tracks. This current version connects all incoming connections to all outgoing connections
+        :param node_positions:
+        :param connection_points:
+        :param rail_trans:
+        :param grid_map:
+        :return:
+        """
+        for current_city in range(len(node_positions)):
+            for boarder in range(4):
+                for source in connection_points[current_city][boarder]:
+                    for other_boarder in range(4):
+                        if boarder != other_boarder and len(connection_points[current_city][other_boarder]) > 0:
+                            for target in connection_points[current_city][other_boarder]:
+                                city_boarder = _city_boarder(node_positions[current_city], node_radius)
+                                connect_cities(rail_trans, grid_map, source, target, city_boarder)
+                        else:
+                            continue
 
-    def _build_cities(node_positions, connection_points, rail_trans, grid_map):
-        # Place train stations close to the node
-        # We currently place them uniformly distributed among all cities
-        built_num_trainstation = 0
-        nb_nodes = len(node_positions)
-        height, width = np.shape(grid_map.grid)
-        train_stations = [[] for i in range(nb_nodes)]
-        if nb_nodes > 1:
-
-            for station in range(num_trainstations):
-                spot_found = True
-                reduced_node_radius = node_radius - 1
-                trainstation_node = int(station / num_trainstations * nb_nodes)
-
-                station_x = np.clip(
-                    node_positions[trainstation_node][0] + np.random.randint(-reduced_node_radius, reduced_node_radius),
-                    0,
-                    height - 1)
-                station_y = np.clip(
-                    node_positions[trainstation_node][1] + np.random.randint(-reduced_node_radius, reduced_node_radius),
-                    0,
-                    width - 1)
-                tries = 0
-                while (station_x, station_y) in train_stations[trainstation_node]:
-
-                    station_x = np.clip(
-                        node_positions[trainstation_node][0] + np.random.randint(-reduced_node_radius,
-                                                                                 reduced_node_radius),
-                        0,
-                        height - 1)
-                    station_y = np.clip(
-                        node_positions[trainstation_node][1] + np.random.randint(-reduced_node_radius,
-                                                                                 reduced_node_radius),
-                        0,
-                        width - 1)
-                    tries += 1
-                    if tries > 100:
-                        warnings.warn("Could not set trainstations, please change initial parameters!!!!")
-                        spot_found = False
-                        break
-
-                if spot_found:
-                    train_stations[trainstation_node].append((station_x, station_y))
-
-                # Connect train station to random nodes
+        return
 
-                if len(connection_points[trainstation_node]) > 1:
-                    rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2,
-                                                         replace=False)
-                else:
-                    rand_corner_nodes = [0]
-
-                for corner_node_idx in rand_corner_nodes:
-                    connection = connect_nodes(rail_trans, grid_map,
-                                               connection_points[trainstation_node][corner_node_idx],
-                                               (station_x, station_y))
-                # Check if connection was made
-                if len(connection) == 0:
-                    if len(train_stations[trainstation_node]) > 0:
-                        train_stations[trainstation_node].pop(-1)
-                else:
+    def _set_trainstation_positions(node_positions, city_cells, num_trainstations, grid_map):
+        """
 
-                    built_num_trainstation += 1
-        return train_stations, built_num_trainstation
+        :param node_positions:
+        :param num_trainstations:
+        :return:
+        """
+        nb_nodes = len(node_positions)
+        train_stations = [[] for i in range(nb_nodes)]
+        num_cities = len(node_positions)
+        built_num_trainstations = 0
+        stations_per_city = int(num_trainstations / num_cities)
+        for current_city in range(len(node_positions)):
+            for possible_location in _city_cells(node_positions[current_city], node_radius - 1):
+                cell_type = grid_map.get_full_transitions(*possible_location)
+                nbits = 0
+                while cell_type > 0:
+                    nbits += (cell_type & 1)
+                    cell_type = cell_type >> 1
+                if 1 <= nbits <= 2:
+                    built_num_trainstations += 1
+                    train_stations[current_city].append(possible_location)
+        return train_stations, built_num_trainstations
 
     def _fix_transitions(grid_map):
         """
@@ -924,10 +907,19 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
         :return: returns flat list of all cell coordinates in the city
         """
         city_cells = []
-        for x in range(-radius, radius):
-            for y in range(-radius, radius):
+        for x in range(-radius, radius + 1):
+            for y in range(-radius, radius + 1):
                 city_cells.append((center[0] + x, center[1] + y))
 
         return city_cells
 
+    def _city_boarder(center, radius):
+        city_boarder = []
+        for x in range(-radius, radius + 1):
+            for y in range(-radius, radius + 1):
+                print(x, y, radius)
+                if abs(x) == radius or abs(y) == radius:
+                    city_boarder.append((center[0] + x, center[1] + y))
+        return city_boarder
+
     return generator
-- 
GitLab