From 21223866907579d446d9e17fcc5c3059a98cb8de Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Thu, 3 Oct 2019 08:16:49 +0200
Subject: [PATCH] refactor initial variables in sparse_rail_generator

---
 examples/flatland_2_0_example.py              |   6 +-
 flatland/envs/rail_generators.py              | 110 +++++++-----------
 ...est_flatland_envs_sparse_rail_generator.py |  12 +-
 tests/test_flatland_malfunction.py            |   8 +-
 tests/test_global_observation.py              |   2 +-
 5 files changed, 59 insertions(+), 79 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 40b41591..4e97ff11 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -30,11 +30,11 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 
 env = RailEnv(width=40,
               height=40,
-              rail_generator=sparse_rail_generator(num_cities=8,  # Number of cities in map (where train stations are)
+              rail_generator=sparse_rail_generator(max_num_cities=8,  # Number of cities in map (where train stations are)
                                                    seed=1,  # Random seed
                                                    grid_mode=False,
-                                                   max_inter_city_rails=2,
-                                                   max_tracks_in_city=4,
+                                                   max_rails_between_cities=2,
+                                                   max_rails_in_city=4,
                                                    ),
               schedule_generator=sparse_schedule_generator(speed_ration_map),
               number_of_agents=20,
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index d98b9644..b132a5ca 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -4,7 +4,6 @@ import warnings
 from typing import Callable, Tuple, Optional, Dict, List, Any
 
 import msgpack
-import networkx as nx
 import numpy as np
 
 from flatland.core.grid.grid4_utils import get_direction, mirror
@@ -534,88 +533,74 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener
     return generator
 
 
-def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, max_tracks_in_city=4,
-                          seed=0) -> RailGenerator:
+def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_rails_between_cities: int = 4,
+                          max_rails_in_city: int = 4, seed: int = 0) -> RailGenerator:
     """
     Generates railway networks with cities and inner city rails
-    :param num_cities: Number of city centers in the map
-    :param grid_mode: Arange cities in a grid or randomly
-    :param max_inter_city_rails: Maximum number of connecting rails going out from a city
-    :param max_tracks_in_city: maximum number of internal rails
+    :param max_num_cities: Number of city centers in the map
+    :param grid_mode: arrange cities in a grid or randomly
+    :param max_rails_between_cities: Maximum number of connecting rails going out from a city
+    :param max_rails_in_city: maximum number of internal rails
     :param seed: Random seed to initiate rail
     :return: generator
     """
-    G = nx.DiGraph()
 
     DEBUG_PRINT_TIMING = False
 
-    def generator(width, height, num_agents, num_resets=0) -> RailGeneratorProduct:
+    def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
+        np.random.seed(seed + num_resets)
 
         rail_trans = RailEnvTransitions()
         grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
-        rail_array = grid_map.grid
-        rail_array.fill(0)
-        np.random.seed(seed + num_resets)
 
-        # Graph to be able to create correct start/end pairs for schedule
+        city_radius = int(np.ceil((max_rails_in_city + 2) / 2.0)) + 1
 
-        node_radius = int(np.ceil((max_tracks_in_city + 2) / 2.0)) + 1
-        if 3 > max_tracks_in_city:
-            rail_in_city = 3
-        else:
-            rail_in_city = max_tracks_in_city
-        max_inter_city_rails_allowed = max_inter_city_rails
-        if max_inter_city_rails_allowed > rail_in_city:
-            max_inter_city_rails_allowed = rail_in_city
-        # Generate a set of nodes for the sparse network
-        # Try to connect cities to nodes first
-        city_positions = []
-        intersection_positions = []
-
-        # Evenly distribute cities and intersections
+        min_nr_rails_in_city = 3
+        rails_in_city = min_nr_rails_in_city if max_rails_in_city < min_nr_rails_in_city else max_rails_in_city
+        rails_between_cities = rails_in_city if max_rails_between_cities > rails_in_city else max_rails_between_cities
+
+        # Evenly distribute cities
         node_time_start = time.time()
-        node_positions: List[Any] = None
-        nb_nodes = num_cities
         if grid_mode:
-            node_positions, city_cells = _generate_node_positions_grid_mode(nb_nodes, node_radius, height, width)
+            city_positions, city_cells = _generate_evenly_distr_city_positions(max_num_cities, city_radius, width, height)
         else:
-            node_positions, city_cells = _generate_random_node_positions(nb_nodes, node_radius, height, width)
+            city_positions, city_cells = _generate_random_city_positions(max_num_cities, city_radius, width, height)
 
-        # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode
-        nb_nodes = len(node_positions)
+        # reduce num_cities, _num_cities, _num_intersections if less were generated in not_grid_mode
+        num_cities = len(city_positions)
         if DEBUG_PRINT_TIMING:
             print("City position time", time.time() - node_time_start, "Seconds")
+
         # Set up connection points for all cities
         node_connection_time = time.time()
         inner_connection_points, outer_connection_points, connection_info, city_orientations = _generate_node_connection_points(
-            node_positions, node_radius, max_inter_city_rails_allowed,
-            rail_in_city)
+            city_positions, city_radius, rails_between_cities,
+            rails_in_city)
         if DEBUG_PRINT_TIMING:
             print("Connection points", time.time() - node_connection_time)
 
         # Connect the cities through the connection points
         city_connection_time = time.time()
-        inter_city_lines = _connect_cities(node_positions, outer_connection_points, connection_info, city_cells,
+        inter_city_lines = _connect_cities(city_positions, outer_connection_points, connection_info, city_cells,
                                            rail_trans, grid_map)
         if DEBUG_PRINT_TIMING:
             print("City connection time", time.time() - city_connection_time)
         # Build inner cities
         city_build_time = time.time()
-        through_tracks, free_tracks = _build_inner_cities(node_positions, inner_connection_points,
+        through_tracks, free_tracks = _build_inner_cities(city_positions, inner_connection_points,
                                                           outer_connection_points,
-                                                          node_radius,
+                                                          city_radius,
                                                           rail_trans,
                                                           grid_map)
         if DEBUG_PRINT_TIMING:
             print("City build time", time.time() - city_build_time)
         # Populate cities
         train_station_time = time.time()
-        train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, node_radius, free_tracks,
+        train_stations, built_num_trainstation = _set_trainstation_positions(city_positions, city_radius, free_tracks,
                                                                              grid_map)
         if DEBUG_PRINT_TIMING:
             print("Trainstation placing time", time.time() - train_station_time)
 
-
         # Fix all transition elements
         grid_fix_time = time.time()
         _fix_transitions(city_cells, inter_city_lines, grid_map)
@@ -624,7 +609,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
 
         # Generate start target pairs
         schedule_time = time.time()
-        agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, nb_nodes, train_stations,
+        agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, num_cities, train_stations,
                                                                              city_orientations)
         if DEBUG_PRINT_TIMING:
             print("Schedule time", time.time() - schedule_time)
@@ -636,52 +621,50 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
             'city_orientations': city_orientations
         }}
 
-    def _generate_random_node_positions(nb_nodes, node_radius, height, width):
+    def _generate_random_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]):
 
         node_positions = []
         city_cells = []
 
-        for node_idx in range(nb_nodes):
+        for node_idx in range(num_cities):
             to_close = True
             tries = 0
 
             while to_close:
-                x_tmp = node_radius + 1 + np.random.randint(height - 2 * (node_radius + 1))
-                y_tmp = node_radius + 1 + np.random.randint(width - 2 * (node_radius + 1))
+                x_tmp = city_radius + 1 + np.random.randint(height - 2 * (city_radius + 1))
+                y_tmp = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1))
                 to_close = False
                 # Check distance to nodes
                 for node_pos in node_positions:
-                    if _city_overlap((x_tmp, y_tmp), node_pos, 2 * (node_radius + 1) + 1):
+                    if _city_overlap((x_tmp, y_tmp), node_pos, 2 * (city_radius + 1) + 1):
                         to_close = True
 
                 if not to_close:
                     node_positions.append((x_tmp, y_tmp))
-                    city_cells.extend(_city_cells(node_positions[-1], node_radius))
+                    city_cells.extend(_city_cells(node_positions[-1], city_radius))
 
                 tries += 1
                 if tries > 200:
                     warnings.warn(
                         "Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format(
                             len(node_positions),
-                            tries, nb_nodes))
+                            tries, num_cities))
                     break
-        G.add_node(node_idx)
         return node_positions, city_cells
 
-    def _generate_node_positions_grid_mode(nb_nodes, node_radius, height, width):
+    def _generate_evenly_distr_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]):
         nodes_ratio = height / width
-        nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
-        nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row))
-        x_positions = np.linspace(node_radius + 1, height - node_radius - 2, nodes_per_row, dtype=int)
-        y_positions = np.linspace(node_radius + 1, width - node_radius - 2, nodes_per_col, dtype=int)
+        nodes_per_row = int(np.ceil(np.sqrt(num_cities * nodes_ratio)))
+        nodes_per_col = int(np.ceil(num_cities / nodes_per_row))
+        x_positions = np.linspace(city_radius + 1, height - city_radius - 2, nodes_per_row, dtype=int)
+        y_positions = np.linspace(city_radius + 1, width - city_radius - 2, nodes_per_col, dtype=int)
         node_positions = []
         city_cells = []
-        for node_idx in range(nb_nodes):
+        for node_idx in range(num_cities):
             x_tmp = x_positions[node_idx % nodes_per_row]
             y_tmp = y_positions[node_idx // nodes_per_row]
             node_positions.append((x_tmp, y_tmp))
-            city_cells.extend(_city_cells(node_positions[-1], node_radius))
-            G.add_node(node_idx)
+            city_cells.extend(_city_cells(node_positions[-1], city_radius))
         return node_positions, city_cells
 
     def _generate_node_connection_points(node_positions, node_size, max_inter_city_rails_allowed, tracks_in_city=2):
@@ -745,7 +728,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                         rail_trans, grid_map):
         """
         Function to connect the different cities through their connection points
-        :param node_positions: Positions of city centers
+        :param city_positions: Positions of city centers
         :param connection_points: Boarder connection points of cities
         :param connection_info: Number of connection points per direction NESW
         :param rail_trans: Transitions
@@ -778,9 +761,6 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                     new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point,
                                               neighb_connection_point,
                                               city_cells)
-                    G.add_edge(current_node, neighb_idx, direction=out_direction, length=len(new_line))
-                    G.add_edge(neighb_idx, current_node, direction=neighbour_direction, length=len(new_line))
-
                     all_paths.extend(new_line)
 
         return all_paths
@@ -789,7 +769,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                             grid_map):
         """
         Builds inner city tracks. This current version connects all incoming connections to all outgoing connections
-        :param node_positions: Positions of the cities
+        :param city_positions: Positions of the cities
         :param inner_connection_points: Points on city boarder that are used to generate inner city track
         :param outer_connection_points: Points where the city is connected to neighboring cities
         :param rail_trans:
@@ -833,7 +813,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
     def _set_trainstation_positions(node_positions, node_radius, free_tracks, grid_map):
         """
 
-        :param node_positions:
+        :param city_positions:
         :param num_trainstations:
         :return:
         """
@@ -906,8 +886,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
     def _closest_neigh_in_direction(current_node, node_positions):
         """
         Returns indices of closest neighbours in every direction NESW
-        :param current_node: Index of node in node_positions list
-        :param node_positions: list of all points being considered
+        :param current_node: Index of node in city_positions list
+        :param city_positions: list of all points being considered
         :return: list of index of closest neighbours in all directions
         """
         node_dist = []
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index fd6e7b88..b94de82c 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -13,7 +13,7 @@ from flatland.utils.rendertools import RenderTool
 def test_sparse_rail_generator():
     env = RailEnv(width=50,
                   height=50,
-                  rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                  rail_generator=sparse_rail_generator(max_num_cities=10,  # Number of cities in map
                                                        num_intersections=10,  # Number of interesections in map
                                                        num_trainstations=50,  # Number of possible start/targets on map
                                                        min_node_dist=6,  # Minimal distance of nodes
@@ -733,7 +733,7 @@ def test_sparse_rail_generator_deterministic():
 
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=sparse_rail_generator(num_cities=5,
+                  rail_generator=sparse_rail_generator(max_num_cities=5,
                                                        # Number of cities in map (where train stations are)
                                                        num_intersections=4,
                                                        # Number of intersections (no start / target)
@@ -1509,7 +1509,7 @@ def test_rail_env_action_required_info():
                         1. / 4.: 0.25}  # Slow freight train
     env_always_action = RailEnv(width=50,
                                 height=50,
-                                rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                                rail_generator=sparse_rail_generator(max_num_cities=10,  # Number of cities in map
                                                                      num_intersections=10,
                                                                      # Number of interesections in map
                                                                      num_trainstations=50,
@@ -1528,7 +1528,7 @@ def test_rail_env_action_required_info():
     np.random.seed(0)
     env_only_if_action_required = RailEnv(width=50,
                                           height=50,
-                                          rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                                          rail_generator=sparse_rail_generator(max_num_cities=10,  # Number of cities in map
                                                                                num_intersections=10,
                                                                                # Number of interesections in map
                                                                                num_trainstations=50,
@@ -1592,7 +1592,7 @@ def test_rail_env_malfunction_speed_info():
                        }
     env = RailEnv(width=50,
                   height=50,
-                  rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                  rail_generator=sparse_rail_generator(max_num_cities=10,  # Number of cities in map
                                                        num_intersections=10,
                                                        # Number of interesections in map
                                                        num_trainstations=50,
@@ -1640,7 +1640,7 @@ def test_sparse_generator_with_too_man_cities_does_not_break_down():
     RailEnv(width=50,
             height=50,
             rail_generator=sparse_rail_generator(
-                num_cities=100,  # Number of cities in map
+                max_num_cities=100,  # Number of cities in map
                 num_intersections=10,  # Number of interesections in map
                 num_trainstations=50,  # Number of possible start/targets on map
                 min_node_dist=6,  # Minimal distance of nodes
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index cc611503..16f993a8 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -165,7 +165,7 @@ def test_initial_malfunction():
 
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=sparse_rail_generator(num_cities=5,
+                  rail_generator=sparse_rail_generator(max_num_cities=5,
                                                        # Number of cities in map (where train stations are)
                                                        num_intersections=4,
                                                        # Number of intersections (no start / target)
@@ -247,7 +247,7 @@ def test_initial_malfunction_stop_moving():
 
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=sparse_rail_generator(num_cities=5,
+                  rail_generator=sparse_rail_generator(max_num_cities=5,
                                                        # Number of cities in map (where train stations are)
                                                        num_intersections=4,
                                                        # Number of intersections (no start / target)
@@ -339,7 +339,7 @@ def test_initial_malfunction_do_nothing():
 
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=sparse_rail_generator(num_cities=5,
+                  rail_generator=sparse_rail_generator(max_num_cities=5,
                                                        # Number of cities in map (where train stations are)
                                                        num_intersections=4,
                                                        # Number of intersections (no start / target)
@@ -430,7 +430,7 @@ def test_initial_nextmalfunction_not_below_zero():
 
     env = RailEnv(width=25,
                   height=30,
-                  rail_generator=sparse_rail_generator(num_cities=5,
+                  rail_generator=sparse_rail_generator(max_num_cities=5,
                                                        # Number of cities in map (where train stations are)
                                                        num_intersections=4,
                                                        # Number of intersections (no start / target)
diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py
index 7213560f..7f8f62c0 100644
--- a/tests/test_global_observation.py
+++ b/tests/test_global_observation.py
@@ -23,7 +23,7 @@ def test_get_global_observation():
 
     env = RailEnv(width=50,
                   height=50,
-                  rail_generator=sparse_rail_generator(num_cities=25,
+                  rail_generator=sparse_rail_generator(max_num_cities=25,
                                                        # Number of cities in map (where train stations are)
                                                        num_intersections=10,
                                                        # Number of intersections (no start / target)
-- 
GitLab