From 9eb5bd553fda6918a091d4381d3f9dfd2a2695a3 Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipamc77@gmail.com>
Date: Tue, 7 Sep 2021 17:40:49 +0530
Subject: [PATCH] sample city positions without retries

---
 flatland/envs/rail_generators.py | 87 +++++++++++++++++++++-----------
 1 file changed, 57 insertions(+), 30 deletions(-)

diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 90dcfb36..9c65c6d6 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -218,7 +218,7 @@ class SparseRailGen(RailGen):
             'city_orientations' : orientation of cities
         """
         if np_random is None:
-            np_random = RandomState()
+            np_random = RandomState(self.seed)
             
         rail_trans = RailEnvTransitions()
         grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
@@ -240,6 +240,7 @@ class SparseRailGen(RailGen):
         # and reduce the number of cities to build to avoid problems
         max_feasible_cities = min(self.max_num_cities,
                                   ((height - 2) // (2 * (city_radius + 1))) * ((width - 2) // (2 * (city_radius + 1))))
+        
         if max_feasible_cities < 2:
             # sys.exit("[ABORT] Cannot fit more than one city in this map, no feasible environment possible! Aborting.")
             raise ValueError("ERROR: Cannot fit more than one city in this map, no feasible environment possible!")
@@ -252,7 +253,6 @@ class SparseRailGen(RailGen):
         else:
             city_positions = self._generate_random_city_positions(max_feasible_cities, city_radius, width, height,
                                                              np_random=np_random)
-
         # reduce num_cities if less were generated in random mode
         num_cities = len(city_positions)
         # If random generation failed just put the cities evenly
@@ -261,12 +261,12 @@ class SparseRailGen(RailGen):
             city_positions = self._generate_evenly_distr_city_positions(max_feasible_cities, city_radius, width,
                                                                    height)
         num_cities = len(city_positions)
-
         # Set up connection points for all cities
         inner_connection_points, outer_connection_points, city_orientations, city_cells = \
             self._generate_city_connection_points(
                 city_positions, city_radius, vector_field, rails_between_cities,
                 rail_pairs_in_city, np_random=np_random)
+        # import pdb; pdb.set_trace()
 
         # Connect the cities through the connection points
         inter_city_lines = self._connect_cities(city_positions, outer_connection_points, city_cells,
@@ -315,27 +315,52 @@ class SparseRailGen(RailGen):
         """
 
         city_positions: IntVector2DArray = []
-        for city_idx in range(num_cities):
-            too_close = True
-            tries = 0
-
-            while too_close:
-                row = city_radius + 1 + np_random.randint(height - 2 * (city_radius + 1))
-                col = city_radius + 1 + np_random.randint(width - 2 * (city_radius + 1))
-                too_close = False
-                # Check distance to cities
-                for city_pos in city_positions:
-                    if self.__class__._are_cities_overlapping((row, col), city_pos, 2 * (city_radius + 1) + 1):
-                        too_close = True
-
-                if not too_close:
-                    city_positions.append((row, col))
-
-                tries += 1
-                if tries > 200:
-                    warnings.warn(
-                        "Could not set all required cities!")
-                    break
+
+        # We track a grid of allowed indexes that can be sampled from for creating a new city
+        # This removes the old sampling method of retrying a random sample on failure
+        allowed_grid = np.zeros((height, width), dtype=np.uint8)
+        city_radius_pad1 = city_radius + 1
+        # Borders have to be not allowed from the start
+        # allowed_grid == 1 indicates locations that are allowed
+        allowed_grid[city_radius_pad1:-city_radius_pad1, city_radius_pad1:-city_radius_pad1] = 1
+        # This tracks the actual city borders
+        city_grid = np.ones((height, width), dtype=np.uint8)
+        for _ in range(num_cities):
+            allowed_indexes = np.where(allowed_grid == 1)
+            num_allowed_points = len(allowed_indexes[0])
+            if num_allowed_points == 0:
+                break
+            # Sample one of the allowed indexes
+            point_index = np_random.randint(num_allowed_points)
+            row = int(allowed_indexes[0][point_index])
+            col = int(allowed_indexes[1][point_index])
+            # # All points in the radius of the allowed point should be 1
+            assert np.all(city_grid[row - city_radius_pad1 : row + city_radius_pad1 + 1, 
+                                    col - city_radius_pad1 : col + city_radius_pad1 + 1]), \
+                                    "Sampling Error, Cities overlap"
+                                    
+                # Need to block city radius and extra margin so that next sampling is correct                                    
+            # Clipping handles the case for negative indexes being generated
+            row_start = max(0, row - 2 * city_radius_pad1)                
+            col_start = max(0, col - 2 * city_radius_pad1)
+            row_end = row + 2 * city_radius_pad1 + 1
+            col_end = col + 2 * city_radius_pad1 + 1
+
+            allowed_grid[row_start : row_end, col_start : col_end] = 0
+            
+            # City grids is needed for redundant assertion check above
+            row_start = max(0, row - city_radius_pad1)                
+            col_start = max(0, col - city_radius_pad1)
+            row_end = row + city_radius_pad1 + 1
+            col_end = col + city_radius_pad1 + 1
+            city_grid[row_start : row_end, col_start : col_end] = 0
+
+            city_positions.append((row, col))
+
+        created_cites = len(city_positions)
+        if created_cites < num_cities:
+            city_warning = f"Could not set all required cities! Created {created_cites}/{num_cities}"
+            warnings.warn(city_warning)
         return city_positions
 
     def _generate_evenly_distr_city_positions(self, num_cities: int, city_radius: int, width: int, height: int
@@ -360,7 +385,6 @@ class SparseRailGen(RailGen):
 
         """
         aspect_ratio = height / width
-
         # Compute max numbe of possible cities per row and col.
         # Respect padding at edges of environment
         # Respect padding between cities
@@ -529,13 +553,13 @@ class SparseRailGen(RailGen):
 
         grid4_directions = [Grid4TransitionsEnum.NORTH, Grid4TransitionsEnum.EAST, Grid4TransitionsEnum.SOUTH,
                             Grid4TransitionsEnum.WEST]
-
+        # import pdb; pdb.set_trace()
         for current_city_idx in np.arange(len(city_positions)):
             closest_neighbours = self._closest_neighbour_in_grid4_directions(current_city_idx, city_positions)
             for out_direction in grid4_directions:
-
+                
                 neighbour_idx = self.get_closest_neighbour_for_direction(closest_neighbours, out_direction)
-
+                
                 for city_out_connection_point in connection_points[current_city_idx][out_direction]:
 
                     min_connection_dist = np.inf
@@ -547,14 +571,17 @@ class SparseRailGen(RailGen):
                             if tmp_dist < min_connection_dist:
                                 min_connection_dist = tmp_dist
                                 neighbour_connection_point = tmp_in_connection_point
-
                     new_line = connect_rail_in_grid_map(grid_map, city_out_connection_point, neighbour_connection_point,
                                                         rail_trans, flip_start_node_trans=False,
                                                         flip_end_node_trans=False, respect_transition_validity=False,
                                                         avoid_rail=True,
                                                         forbidden_cells=city_cells)
+                    if len(new_line) == 0:
+                        warnings.warn("[WARNING] Unable to connect requested stations")                                                    
+                    elif new_line[-1] != neighbour_connection_point or new_line[0] != city_out_connection_point:
+                        warnings.warn("[WARNING] Unable to connect requested stations")
                     all_paths.extend(new_line)
-
+        import pdb; pdb.set_trace()
         return all_paths
 
     def get_closest_neighbour_for_direction(self, closest_neighbours, out_direction):
-- 
GitLab