From 80d4a184354d5668695c7e01c8886d67058f67c0 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Thu, 26 Sep 2019 18:00:30 -0400
Subject: [PATCH] fixed errors in random node distribution

---
 examples/flatland_2_0_example.py |  4 ++--
 flatland/envs/rail_generators.py | 10 +++++++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 69959ca8..ab730a0a 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -37,8 +37,8 @@ env = RailEnv(width=50,
                                                    min_node_dist=15,  # Minimal distance of nodes
                                                    node_radius=4,  # Proximity of stations to city center
                                                    seed=15,  # Random seed
-                                                   grid_mode=True,
-                                                   connection_points_per_side=3,
+                                                   grid_mode=False,
+                                                   max_connection_points_per_side=2,
                                                    max_nr_connection_directions=4,
                                                    ),
               schedule_generator=sparse_schedule_generator(),
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 178da0c7..f7920c29 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -533,7 +533,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener
 
 
 def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, node_radius=2,
-                          grid_mode=False, connection_points_per_side=4,
+                          grid_mode=False, max_connection_points_per_side=4,
                           max_nr_connection_directions=2,
                           seed=0) -> RailGenerator:
     """
@@ -598,7 +598,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
 
         # Set up connection points for all cities
         connection_points, connection_info = _generate_node_connection_points(node_positions, node_radius,
-                                                                              connection_points_per_side,
+                                                                              max_connection_points_per_side,
                                                                               max_nr_connection_directions)
 
         # Connect the cities through the connection points
@@ -821,7 +821,11 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
 
                 # Connect train station to random nodes
 
-                rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2, replace=False)
+                if len(connection_points[trainstation_node]) > 1:
+                    rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2,
+                                                         replace=False)
+                else:
+                    rand_corner_nodes = [0]
 
                 for corner_node_idx in rand_corner_nodes:
                     connection = connect_nodes(rail_trans, grid_map,
-- 
GitLab