From 35cbfa77546883d6f224ab8bda965e7730748736 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Thu, 3 Oct 2019 15:41:24 -0400
Subject: [PATCH] backwards compatibility with complex rail generator

---
 examples/training_example.py            | 2 +-
 flatland/envs/grid4_generators_utils.py | 2 +-
 flatland/envs/rail_generators.py        | 7 +++++--
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/examples/training_example.py b/examples/training_example.py
index df93479f..78c0299d 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -16,7 +16,7 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor
 LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
 env = RailEnv(width=20,
               height=20,
-              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2, min_dist=8, max_dist=99999, seed=0),
               schedule_generator=complex_schedule_generator(),
               obs_builder_object=TreeObservation,
               number_of_agents=3)
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index dfd9218f..a332cd70 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -19,10 +19,10 @@ def connect_rail(
     grid_map: GridTransitionMap,
     start: IntVector2D,
     end: IntVector2D,
+    a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
     flip_start_node_trans=False,
     flip_end_node_trans=False,
     nice=True,
-    a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
     forbidden_cells=None
 ) -> IntVector2DArray:
     """
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index ee31666c..557f25e4 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -7,6 +7,7 @@ import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_direction, mirror
+from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
 from flatland.core.grid.grid_utils import distance_on_rail, direction_to_city, IntVector2DArray, IntVector2D, \
     Vec2dOperations
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
@@ -125,7 +126,8 @@ def complex_rail_generator(nr_start_goal=1,
                 # we might as well give up at this point
                 break
 
-            new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True,
+            new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance,
+                                    flip_start_node_trans=True,
                                     flip_end_node_trans=True, nice=True,
                                     forbidden_cells=None)
             if len(new_path) >= 2:
@@ -152,7 +154,8 @@ def complex_rail_generator(nr_start_goal=1,
                     break
             if not all_ok:
                 break
-            new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True,
+            new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance,
+                                    flip_start_node_trans=True,
                                     flip_end_node_trans=True, nice=True,
                                     forbidden_cells=None)
 
-- 
GitLab