From 672d1169b2a028af0d1fff2425908a003844d0ae Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Wed, 25 Sep 2019 15:14:52 -0400
Subject: [PATCH] fixed how a-start works. we can now force it to ignore
 illegal rail elements. This might be a bit dangerous as you need to fix them
 later one ( done in current code). Maybe better suggestions here!?

---
 examples/flatland_2_0_example.py        |  8 +++----
 flatland/core/grid/grid4_astar.py       |  7 +++----
 flatland/envs/grid4_generators_utils.py |  7 +++++--
 flatland/envs/rail_env.py               |  1 +
 flatland/envs/rail_generators.py        | 28 ++++++-------------------
 5 files changed, 19 insertions(+), 32 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 2334450c..5c112434 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -32,14 +32,14 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
 
 env = RailEnv(width=50,
               height=50,
-              rail_generator=sparse_rail_generator(num_cities=9,  # Number of cities in map (where train stations are)
+              rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map (where train stations are)
                                                    num_intersections=0,  # Number of intersections (no start / target)
-                                                   num_trainstations=15,  # Number of possible start/targets on map
+                                                   num_trainstations=50,  # Number of possible start/targets on map
                                                    min_node_dist=10,  # Minimal distance of nodes
                                                    node_radius=4,  # Proximity of stations to city center
-                                                   num_neighb=2,  # Number of connections to other cities/intersections
+                                                   num_neighb=3,  # Number of connections to other cities/intersections
                                                    seed=15,  # Random seed
-                                                   grid_mode=False,
+                                                   grid_mode=True,
                                                    enhance_intersection=False
                                                    ),
               schedule_generator=sparse_schedule_generator(),
diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py
index 3a75aa81..3b6de032 100644
--- a/flatland/core/grid/grid4_astar.py
+++ b/flatland/core/grid/grid4_astar.py
@@ -1,5 +1,3 @@
-import numpy as np
-
 from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance
 from flatland.core.grid.grid_utils import IntVector2DArray
 from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
@@ -39,7 +37,7 @@ class AStarNode:
 
 def a_star(grid_map: GridTransitionMap,
            start: IntVector2D, end: IntVector2D,
-           a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
+           a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, nice=True) -> IntVector2DArray:
     """
     Returns a list of tuples as a path from the given start to end.
     If no path is found, returns path to closest point to end.
@@ -93,7 +91,8 @@ def a_star(grid_map: GridTransitionMap,
                 continue
 
             # validate positions
-            if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos):
+            #
+            if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos) and nice:
                 continue
 
             # create new node
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index fce3ffdf..72b59d8c 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -19,14 +19,16 @@ def connect_basic_operation(
     end: IntVector2D,
     flip_start_node_trans=False,
     flip_end_node_trans=False,
+    nice=True,
     a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
     """
     Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and
     returns the path created as a list of positions.
     """
     # in the worst case we will need to do a A* search, so we might as well set that up
-    path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function)
+    path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice)
     if len(path) < 2:
+        print("No path found", path)
         return []
     current_dir = get_direction(path[0], path[1])
     end_pos = path[-1]
@@ -54,6 +56,7 @@ def connect_basic_operation(
             new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1)
         grid_map.grid[current_pos] = new_trans
 
+
         if new_pos == end_pos:
             # setup end pos setup
             new_trans_e = grid_map.grid[end_pos]
@@ -81,7 +84,7 @@ def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
 def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
                   start: IntVector2D, end: IntVector2D,
                   a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray:
-    return connect_basic_operation(rail_trans, grid_map, start, end, False, False, a_star_distance_function)
+    return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function)
 
 
 def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 86277431..5472f0c5 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -252,6 +252,7 @@ class RailEnv(Environment):
                     rc_pos = (r, c)
                     check = self.rail.cell_neighbours_valid(rc_pos, True)
                     if not check:
+                        self.rail.fix_transitions(rc_pos)
                         warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check))
 
         if replace_agents:
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index c0721f62..5f1d2f5a 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -595,16 +595,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
 
         # Chose node connection
         # Set up list of available nodes to connect to
-        available_nodes_full = np.arange(nb_nodes)
-        available_cities = np.arange(_num_cities)
-        available_intersections = np.arange(_num_cities, nb_nodes)
+        available_nodes = np.arange(nb_nodes)
 
-        # Set up connection points
+        # Set up connection points for all cities
         connection_points = _generate_node_connection_points(node_positions, node_radius, max_nr_connection_points=8)
+
         # Start at some node
-        current_node = np.random.randint(len(available_nodes_full))
+        current_node = np.random.randint(len(available_nodes))
         node_stack = [current_node]
-        open_nodes = np.copy(available_nodes_full)
+        open_nodes = np.copy(available_nodes)
         allowed_connections = num_neighb
         i = 0
         boarder_connections = set()
@@ -617,22 +616,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
             delete_idx = np.where(open_nodes == current_node)
             open_nodes = np.delete(open_nodes, delete_idx, 0)
 
-            # Priority city to intersection connections
-            if current_node < _num_cities and len(available_intersections) > 0:
-                available_nodes = available_intersections
-                delete_idx = np.where(available_cities == current_node)
-                # available_cities = np.delete(available_cities, delete_idx, 0)
-
-            # Priority intersection to city connections
-            elif current_node >= _num_cities and len(available_cities) > 0:
-                available_nodes = available_cities
-                delete_idx = np.where(available_intersections == current_node)
-                # available_intersections = np.delete(available_intersections, delete_idx, 0)
-
-            # If no options possible connect to whatever node is still available
-            else:
-                available_nodes = available_nodes_full
-
             # Sort available neighbors according to their distance.
             node_dist = []
             for av_node in available_nodes:
@@ -764,6 +747,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
 
         for tbd in to_be_deleted:
             boarder_connections.remove(tbd)
+        print(boarder_connections)
         # Fix all nodes with illegal transition maps
         flat_trainstation_list = [item for sublist in train_stations for item in sublist]
         for cell_to_fix in flat_trainstation_list:
-- 
GitLab