From 1018ce8cd8f257a39072525cdb3c639905886a42 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Thu, 5 Sep 2019 13:34:20 +0200
Subject: [PATCH] #164 improving stability sparse level generator

---
 flatland/envs/rail_generators.py              | 141 +++++++++++-------
 ...est_flatland_envs_sparse_rail_generator.py |  60 +++++---
 tests/test_flatland_malfunction.py            |   3 +-
 3 files changed, 130 insertions(+), 74 deletions(-)

diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 9d55198b..7515009c 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -1,6 +1,6 @@
 """Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
 import warnings
-from typing import Callable, Tuple, Optional, Dict
+from typing import Callable, Tuple, Optional, Dict, List, Any
 
 import msgpack
 import numpy as np
@@ -560,63 +560,43 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
 
         # Generate a set of nodes for the sparse network
         # Try to connect cities to nodes first
-        node_positions = []
         city_positions = []
         intersection_positions = []
+
         # Evenly distribute cities and intersections
+        node_positions: List[Any] = None
+        nb_nodes = num_cities + num_intersections
         if grid_mode:
-            tot_num_node = num_intersections + num_cities
             nodes_ratio = height / width
-            nodes_per_row = int(np.ceil(np.sqrt(tot_num_node * nodes_ratio)))
-            nodes_per_col = int(np.ceil(tot_num_node / nodes_per_row))
+            nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
+            nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row))
             x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int)
             y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int)
-            city_idx = np.random.choice(np.arange(tot_num_node), num_cities)
+            city_idx = np.random.choice(np.arange(nb_nodes), num_cities)
+
+            node_positions = _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions,
+                                                                nb_nodes,
+                                                                nodes_per_row, x_positions,
+                                                                y_positions)
 
-        for node_idx in range(num_cities + num_intersections):
-            to_close = True
-            tries = 0
 
-            if not grid_mode:
-                while to_close:
-                    x_tmp = node_radius + np.random.randint(height - node_radius)
-                    y_tmp = node_radius + np.random.randint(width - node_radius)
-                    to_close = False
-
-                    # Check distance to cities
-                    for node_pos in city_positions:
-                        if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
-                            to_close = True
-
-                    # Check distance to intersections
-                    for node_pos in intersection_positions:
-                        if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
-                            to_close = True
-
-                    if not to_close:
-                        node_positions.append((x_tmp, y_tmp))
-                        if node_idx < num_cities:
-                            city_positions.append((x_tmp, y_tmp))
-                        else:
-                            intersection_positions.append((x_tmp, y_tmp))
-                    tries += 1
-                    if tries > 100:
-                        warnings.warn("Could not set nodes, please change initial parameters!!!!")
-                        break
-            else:
-                x_tmp = x_positions[node_idx % nodes_per_row]
-                y_tmp = y_positions[node_idx // nodes_per_row]
-                if node_idx in city_idx:
-                    city_positions.append((x_tmp, y_tmp))
-                else:
-                    intersection_positions.append((x_tmp, y_tmp))
-        node_positions = city_positions + intersection_positions
+
+        else:
+
+            node_positions = _generate_node_positions_not_grid_mode(city_positions, height,
+                                                                    intersection_positions,
+                                                                    nb_nodes, width)
+
+        # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode
+        nb_nodes = len(node_positions)
+        _num_cities = len(city_positions)
+        _num_intersections = len(intersection_positions)
 
         # Chose node connection
         # Set up list of available nodes to connect to
-        available_nodes_full = np.arange(num_cities + num_intersections)
-        available_cities = np.arange(num_cities)
-        available_intersections = np.arange(num_cities, num_cities + num_intersections)
+        available_nodes_full = np.arange(nb_nodes)
+        available_cities = np.arange(_num_cities)
+        available_intersections = np.arange(_num_cities, nb_nodes)
 
         # Start at some node
         current_node = np.random.randint(len(available_nodes_full))
@@ -629,13 +609,13 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             available_nodes_full = np.delete(available_nodes_full, delete_idx, 0)
 
             # Priority city to intersection connections
-            if current_node < num_cities and len(available_intersections) > 0:
+            if current_node < _num_cities and len(available_intersections) > 0:
                 available_nodes = available_intersections
                 delete_idx = np.where(available_cities == current_node)
                 available_cities = np.delete(available_cities, delete_idx, 0)
 
             # Priority intersection to city connections
-            elif current_node >= num_cities and len(available_cities) > 0:
+            elif current_node >= _num_cities and len(available_cities) > 0:
                 available_nodes = available_cities
                 delete_idx = np.where(available_intersections == current_node)
                 available_intersections = np.delete(available_intersections, delete_idx, 0)
@@ -669,15 +649,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             node_stack.pop(0)
 
         # Place train stations close to the node
-        # We currently place them uniformly distirbuted among all cities
+        # We currently place them uniformly distributed among all cities
         built_num_trainstation = 0
-        train_stations = [[] for i in range(num_cities)]
+        train_stations = [[] for i in range(_num_cities)]
 
-        if num_cities > 1:
+        if _num_cities > 1:
 
             for station in range(num_trainstations):
                 spot_found = True
-                trainstation_node = int(station / num_trainstations * num_cities)
+                trainstation_node = int(station / num_trainstations * _num_cities)
 
                 station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
                                     0,
@@ -725,7 +705,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
         # We currently place them uniformly distirbuted among all cities
         if enhance_intersection:
 
-            for intersection in range(num_intersections):
+            for intersection in range(_num_intersections):
                 intersect_x_1 = np.clip(intersection_positions[intersection][0] + np.random.randint(1, 3),
                                         1,
                                         height - 2)
@@ -762,7 +742,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
         # Slot availability in node
         node_available_start = []
         node_available_target = []
-        for node_idx in range(num_cities):
+        for node_idx in range(_num_cities):
             node_available_start.append(len(train_stations[node_idx]))
             node_available_target.append(len(train_stations[node_idx]))
 
@@ -797,4 +777,57 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             'train_stations': train_stations
         }}
 
+    def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes,
+                                               width):
+
+        node_positions = []
+        for node_idx in range(nb_nodes):
+            to_close = True
+            tries = 0
+
+            while to_close:
+                x_tmp = node_radius + np.random.randint(height - node_radius)
+                y_tmp = node_radius + np.random.randint(width - node_radius)
+                to_close = False
+
+                # Check distance to cities
+                for node_pos in city_positions:
+                    if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
+                        to_close = True
+
+                # Check distance to intersections
+                for node_pos in intersection_positions:
+                    if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
+                        to_close = True
+
+                if not to_close:
+                    node_positions.append((x_tmp, y_tmp))
+                    if node_idx < num_cities:
+                        city_positions.append((x_tmp, y_tmp))
+                    else:
+                        intersection_positions.append((x_tmp, y_tmp))
+                tries += 1
+                if tries > 100:
+                    warnings.warn(
+                        "Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format(
+                            len(node_positions),
+                            tries, nb_nodes))
+                    break
+
+        node_positions = city_positions + intersection_positions
+        return node_positions
+
+    def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes,
+                                           nodes_per_row, x_positions, y_positions):
+        for node_idx in range(nb_nodes):
+
+            x_tmp = x_positions[node_idx % nodes_per_row]
+            y_tmp = y_positions[node_idx // nodes_per_row]
+            if node_idx in city_idx:
+                city_positions.append((x_tmp, y_tmp))
+            else:
+                intersection_positions.append((x_tmp, y_tmp))
+        node_positions = city_positions + intersection_positions
+        return node_positions
+
     return generator
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index 4645d80a..a0e2b995 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -55,24 +55,25 @@ def test_rail_env_action_required_info():
                                 obs_builder_object=GlobalObsForRailEnv())
     np.random.seed(0)
     env_only_if_action_required = RailEnv(width=50,
-                                   height=50,
-                                   rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
-                                                                        num_intersections=10,
-                                                                        # Number of interesections in map
-                                                                        num_trainstations=50,
-                                                                        # Number of possible start/targets on map
-                                                                        min_node_dist=6,  # Minimal distance of nodes
-                                                                        node_radius=3,
-                                                                        # Proximity of stations to city center
-                                                                        num_neighb=3,
-                                                                        # Number of connections to other cities
-                                                                        seed=5,  # Random seed
-                                                                        grid_mode=False
-                                                                        # Ordered distribution of nodes
-                                                                        ),
-                                   schedule_generator=sparse_schedule_generator(speed_ration_map),
-                                   number_of_agents=10,
-                                   obs_builder_object=GlobalObsForRailEnv())
+                                          height=50,
+                                          rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
+                                                                               num_intersections=10,
+                                                                               # Number of interesections in map
+                                                                               num_trainstations=50,
+                                                                               # Number of possible start/targets on map
+                                                                               min_node_dist=6,
+                                                                               # Minimal distance of nodes
+                                                                               node_radius=3,
+                                                                               # Proximity of stations to city center
+                                                                               num_neighb=3,
+                                                                               # Number of connections to other cities
+                                                                               seed=5,  # Random seed
+                                                                               grid_mode=False
+                                                                               # Ordered distribution of nodes
+                                                                               ),
+                                          schedule_generator=sparse_schedule_generator(speed_ration_map),
+                                          number_of_agents=10,
+                                          obs_builder_object=GlobalObsForRailEnv())
     env_renderer = RenderTool(env_always_action, gl="PILSVG", )
 
     for step in range(100):
@@ -87,7 +88,8 @@ def test_rail_env_action_required_info():
             if step == 0 or info_only_if_action_required['action_required'][a]:
                 action_dict_only_if_action_required.update({a: action})
             else:
-                print("[{}] not action_required {}, speed_data={}".format(step, a, env_always_action.agents[a].speed_data))
+                print("[{}] not action_required {}, speed_data={}".format(step, a,
+                                                                          env_always_action.agents[a].speed_data))
 
         obs_always_action, rewards_always_action, done_always_action, info_always_action = env_always_action.step(
             action_dict_always_action)
@@ -156,3 +158,23 @@ def test_rail_env_malfunction_speed_info():
 
         if done['__all__']:
             break
+
+
+def test_sparse_generator_with_too_man_cities_does_not_break_down():
+    np.random.seed(0)
+
+    RailEnv(width=50,
+            height=50,
+            rail_generator=sparse_rail_generator(
+                num_cities=100,  # Number of cities in map
+                num_intersections=10,  # Number of interesections in map
+                num_trainstations=50,  # Number of possible start/targets on map
+                min_node_dist=6,  # Minimal distance of nodes
+                node_radius=3,  # Proximity of stations to city center
+                num_neighb=3,  # Number of connections to other cities
+                seed=5,  # Random seed
+                grid_mode=False  # Ordered distribution of nodes
+            ),
+            schedule_generator=sparse_schedule_generator(),
+            number_of_agents=10,
+            obs_builder_object=GlobalObsForRailEnv())
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index e60386c9..a63e9722 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -143,4 +143,5 @@ def test_malfunction_process_statistically():
         env.step(action_dict)
 
     # check that generation of malfunctions works as expected
-    assert nb_malfunction == 156
+    # results are different in py36 and py37, therefore no exact test on nb_malfunction
+    assert nb_malfunction > 150
-- 
GitLab