From 65844b60aabe2040b267a50848e2a91e3a62bdcc Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Sat, 28 Sep 2019 16:11:25 -0400
Subject: [PATCH] extended size of rails to get more possible train station
 positions

---
 examples/flatland_2_0_example.py | 2 +-
 flatland/envs/rail_generators.py | 7 ++++---
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 4a614b2f..2d393fd4 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -34,7 +34,7 @@ env = RailEnv(width=50,
                                                    seed=0,  # Random seed
                                                    grid_mode=False,
                                                    max_inter_city_rails=2,
-                                                   max_tracks_in_city=8,
+                                                   max_tracks_in_city=4,
                                                    ),
               schedule_generator=sparse_schedule_generator(),
               number_of_agents=10,
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 74d17b83..725943a4 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -551,7 +551,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
         rail_array = grid_map.grid
         rail_array.fill(0)
         np.random.seed(seed + num_resets)
-        node_radius = int(np.ceil(max_tracks_in_city / 2))
+        node_radius = int(np.ceil(max_tracks_in_city / 2)) + 2
         max_inter_city_rails_allowed = max_inter_city_rails
         if max_inter_city_rails_allowed > max_tracks_in_city:
             max_inter_city_rails_allowed = max_tracks_in_city
@@ -616,8 +616,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
             tries = 0
 
             while to_close:
-                x_tmp = node_radius + 1 + np.random.randint(height - 2 * node_radius)
-                y_tmp = node_radius + 1 + np.random.randint(width - 2 * node_radius)
+                x_tmp = node_radius + 1 + np.random.randint(height - 2 * node_radius - 1)
+                y_tmp = node_radius + 1 + np.random.randint(width - 2 * node_radius - 1)
                 to_close = False
                 # Check distance to nodes
                 for node_pos in node_positions:
@@ -806,6 +806,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
                 if 1 <= nbits <= 2:
                     built_num_trainstations += 1
                     train_stations[current_city].append(possible_location)
+        print(built_num_trainstations)
         return train_stations, built_num_trainstations
 
     def _generate_start_target_pairs(num_agents, nb_nodes, train_stations):
-- 
GitLab