From 214b08a14ab39d4882a13566707fcb6dab7de27e Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Fri, 27 Sep 2019 10:13:14 -0400
Subject: [PATCH] code cleanup and added city cells in order to avoid drawing
 paths through cities

---
 examples/flatland_2_0_example.py        |  2 +-
 flatland/core/grid/grid4_astar.py       |  9 ++-
 flatland/envs/grid4_generators_utils.py | 12 +++-
 flatland/envs/rail_generators.py        | 82 ++++++++++---------------
 4 files changed, 51 insertions(+), 54 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index f2bb4aba..ce1dda2d 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -33,7 +33,7 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 env = RailEnv(width=50,
               height=50,
               rail_generator=sparse_rail_generator(num_cities=9,  # Number of cities in map (where train stations are)
-                                                   num_trainstations=0,  # Number of possible start/targets on map
+                                                   num_trainstations=50,  # Number of possible start/targets on map
                                                    min_node_dist=8,  # Minimal distance of nodes
                                                    node_radius=3,  # Proximity of stations to city center
                                                    seed=15,  # Random seed
diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py
index 3b6de032..8b757435 100644
--- a/flatland/core/grid/grid4_astar.py
+++ b/flatland/core/grid/grid4_astar.py
@@ -37,7 +37,8 @@ class AStarNode:
 
 def a_star(grid_map: GridTransitionMap,
            start: IntVector2D, end: IntVector2D,
-           a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, nice=True) -> IntVector2DArray:
+           a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, nice=True,
+           forbidden_cells=None) -> IntVector2DArray:
     """
     Returns a list of tuples as a path from the given start to end.
     If no path is found, returns path to closest point to end.
@@ -90,11 +91,15 @@ def a_star(grid_map: GridTransitionMap,
             if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0:
                 continue
 
+            # Skip paths through forbidden regions.
+            if forbidden_cells is not None:
+                if node_pos in forbidden_cells and node_pos != start_node and node_pos != end_node:
+                    continue
+
             # validate positions
             #
             if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos) and nice:
                 continue
-
             # create new node
             new_node = AStarNode(node_pos, current_node)
             children.append(new_node)
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index 023e96e0..166094aa 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -20,13 +20,15 @@ def connect_basic_operation(
     flip_start_node_trans=False,
     flip_end_node_trans=False,
     nice=True,
-    a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
+    a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
+    forbidden_cells=None
+) -> IntVector2DArray:
     """
     Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
     returns the path created as a list of positions.
     """
     # in the worst case we will need to do a A* search, so we might as well set that up
-    path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice)
+    path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice, forbidden_cells)
     if len(path) < 2:
         print("No path found", path)
         return []
@@ -87,6 +89,12 @@ def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
     return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function)
 
 
+def connect_cities(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
+                   start: IntVector2D, end: IntVector2D, forbidden_cells,
+                   a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
+    return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function,
+                                   forbidden_cells)
+
 def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                        start: IntVector2D, end: IntVector2D,
                        a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 8c66fe39..0ffbe00d 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -9,7 +9,7 @@ from flatland.core.grid.grid4_utils import get_direction, mirror
 from flatland.core.grid.grid_utils import distance_on_rail, direction_to_point
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes
+from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_cities
 
 RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
 RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
@@ -573,25 +573,9 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
         node_positions: List[Any] = None
         nb_nodes = num_cities
         if grid_mode:
-            nodes_ratio = height / width
-            nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
-            nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row))
-            x_positions = np.linspace(node_radius, height - node_radius - 1, nodes_per_row, dtype=int)
-            y_positions = np.linspace(node_radius, width - node_radius - 1, nodes_per_col, dtype=int)
-            city_idx = np.random.choice(np.arange(nb_nodes), num_cities, False)
-
-            node_positions = _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions,
-                                                                nb_nodes,
-                                                                nodes_per_row, x_positions,
-                                                                y_positions)
-
-
-
+            node_positions, city_cells = _generate_node_positions_grid_mode(nb_nodes, height, width)
         else:
-
-            node_positions = _generate_node_positions_not_grid_mode(city_positions, height,
-                                                                    intersection_positions,
-                                                                    nb_nodes, width)
+            node_positions = _generate_node_positions_not_grid_mode(nb_nodes, height, width)
 
         # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode
         nb_nodes = len(node_positions)
@@ -624,8 +608,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
             'train_stations': train_stations
         }}
 
-    def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes,
-                                               width):
+    def _generate_node_positions_not_grid_mode(nb_nodes, height, width):
 
         node_positions = []
         for node_idx in range(nb_nodes):
@@ -637,22 +620,14 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                 y_tmp = node_radius + np.random.randint(width - 2 * node_radius - 1)
                 to_close = False
 
-                # Check distance to cities
-                for node_pos in city_positions:
-                    if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
-                        to_close = True
-
-                # Check distance to intersections
-                for node_pos in intersection_positions:
+                # Check distance to nodes
+                for node_pos in node_positions:
                     if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
                         to_close = True
 
                 if not to_close:
                     node_positions.append((x_tmp, y_tmp))
-                    if node_idx < num_cities:
-                        city_positions.append((x_tmp, y_tmp))
-                    else:
-                        intersection_positions.append((x_tmp, y_tmp))
+
                 tries += 1
                 if tries > 100:
                     warnings.warn(
@@ -661,23 +636,21 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                             tries, nb_nodes))
                     break
 
-        node_positions = city_positions + intersection_positions
         return node_positions
 
-    def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes,
-                                           nodes_per_row, x_positions, y_positions):
-
+    def _generate_node_positions_grid_mode(nb_nodes, height, width):
+        nodes_ratio = height / width
+        nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
+        nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row))
+        x_positions = np.linspace(node_radius, height - node_radius - 1, nodes_per_row, dtype=int)
+        y_positions = np.linspace(node_radius, width - node_radius - 1, nodes_per_col, dtype=int)
+        node_positions = []
+        forbidden_cells = []
         for node_idx in range(nb_nodes):
-
             x_tmp = x_positions[node_idx % nodes_per_row]
             y_tmp = y_positions[node_idx // nodes_per_row]
-            if node_idx in city_idx:
-                city_positions.append((x_tmp, y_tmp))
-
-            else:
-                intersection_positions.append((x_tmp, y_tmp))
-        node_positions = city_positions + intersection_positions
-        return node_positions
+            node_positions.append((x_tmp, y_tmp))
+        return node_positions, forbidden_cells
 
     def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2,
                                          max_nr_connection_directions=2):
@@ -698,8 +671,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
             # Store the directions to these neighbours
             connection_sides_idx = []
             idx = 1
-            # TODO: Change the way this code works! Check that we get sufficient direction.
-            # TODO: Check if this works as expected
             while len(connection_sides_idx) < max_nr_connection_directions and idx < len(neighb_dist):
                 current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]])
                 print(node_position)
@@ -707,12 +678,11 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                     connection_sides_idx.append(current_closest_direction)
                 idx += 1
 
-
             # set the number of connection points for each direction
             connections_per_direction = np.zeros(4, dtype=int)
 
             for idx in connection_sides_idx:
-                nr_of_connection_points = max_nr_connection_points  # np.random.randint(1, max_nr_connection_points + 1)
+                nr_of_connection_points = np.random.randint(1, max_nr_connection_points + 1)
 
                 connections_per_direction[idx] = nr_of_connection_points
             connection_points_coordinates = []
@@ -775,7 +745,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
                             if tmp_dist < min_connection_dist:
                                 min_connection_dist = tmp_dist
                                 neighb_connection_point = tmp_in_connection_point
-                        connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point)
+                        connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, None)
                         boarder_connections.add((tmp_out_connection_point, current_node))
                         boarder_connections.add((neighb_connection_point, neighb_idx))
                 direction += 1
@@ -944,4 +914,18 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
         # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
         return sorted(range(len(seq)), key=seq.__getitem__)
 
+    def _city_cells(center, radius):
+        """
+        Function to return all cells within a city
+        :param center: center coordinates of city
+        :param radius: radius of city (it is a square)
+        :return: returns flat list of all cell coordinates in the city
+        """
+        city_cells = []
+        for x in range(-radius, radius):
+            for y in range(-radius, radius):
+                city_cells.append(center[0] + x, center[1] + y)
+
+        return city_cells
+
     return generator
-- 
GitLab