From 743467833ea87162bce7289db86e779749013ec2 Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Thu, 3 Oct 2019 08:48:20 +0200
Subject: [PATCH] refactor _city_cells to not use for loops

---
 flatland/envs/rail_generators.py | 29 +++++++++--------------------
 1 file changed, 9 insertions(+), 20 deletions(-)

diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index b132a5ca..e59c782e 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -641,7 +641,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
 
                 if not to_close:
                     node_positions.append((x_tmp, y_tmp))
-                    city_cells.extend(_city_cells(node_positions[-1], city_radius))
+                    city_cells.extend(_get_cells_in_city(node_positions[-1], city_radius))
 
                 tries += 1
                 if tries > 200:
@@ -664,7 +664,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
             x_tmp = x_positions[node_idx % nodes_per_row]
             y_tmp = y_positions[node_idx // nodes_per_row]
             node_positions.append((x_tmp, y_tmp))
-            city_cells.extend(_city_cells(node_positions[-1], city_radius))
+            city_cells.extend(_get_cells_in_city(node_positions[-1], city_radius))
         return node_positions, city_cells
 
     def _generate_node_connection_points(node_positions, node_size, max_inter_city_rails_allowed, tracks_in_city=2):
@@ -781,8 +781,6 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         for current_city in range(len(node_positions)):
             all_outer_connection_points = [item for sublist in outer_connection_points[current_city] for item in
                                            sublist]
-            city_boarder = _city_boarder(node_positions[current_city], node_radius)
-
             # This part only works if we have keep same number of connection points for both directions
             # Also only works with two connection direction at each city
             for i in range(4):
@@ -911,30 +909,21 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
         return sorted(range(len(seq)), key=seq.__getitem__)
 
-    def _city_cells(center, radius):
+    def _get_cells_in_city(center: Tuple[int], radius: int) -> List[Tuple[int, int]]:
         """
         Function to return all cells within a city
         :param center: center coordinates of city
         :param radius: radius of city (it is a square)
         :return: returns flat list of all cell coordinates in the city
         """
-        city_cells = []
-        for x in range(-radius, radius + 1):
-            for y in range(-radius, radius + 1):
-                city_cells.append((center[0] + x, center[1] + y))
-
-        return city_cells
-
-    def _city_boarder(center, radius):
-        city_boarder = []
-        for x in range(-radius, radius + 1):
-            for y in range(-radius, radius + 1):
-                if abs(x) == radius or abs(y) == radius:
-                    city_boarder.append((center[0] + x, center[1] + y))
-        return city_boarder
+        x_range = np.arange(center[0] - radius, center[0] + radius + 1)
+        y_range = np.arange(center[1] - radius, center[1] + radius + 1)
+        x_values = np.repeat(x_range, len(y_range))
+        y_values = np.tile(y_range, len(x_range))
+        return list(zip(x_values, y_values))
 
     def _city_overlap(center_1, center_2, radius):
-        return (np.abs(center_1[0] - center_2[0]) < radius and np.abs(center_1[1] - center_2[1]) < radius)
+        return np.abs(center_1[0] - center_2[0]) < radius and np.abs(center_1[1] - center_2[1]) < radius
 
     def _track_number(city_position, city_orientation, position):
         """
-- 
GitLab