From 35cbfa77546883d6f224ab8bda965e7730748736 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Thu, 3 Oct 2019 15:41:24 -0400 Subject: [PATCH] backwards compatibility with complex rail generator --- examples/training_example.py | 2 +- flatland/envs/grid4_generators_utils.py | 2 +- flatland/envs/rail_generators.py | 7 +++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/training_example.py b/examples/training_example.py index df93479f..78c0299d 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -16,7 +16,7 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) env = RailEnv(width=20, height=20, - rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2, min_dist=8, max_dist=99999, seed=0), schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObservation, number_of_agents=3) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index dfd9218f..a332cd70 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -19,10 +19,10 @@ def connect_rail( grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, flip_start_node_trans=False, flip_end_node_trans=False, nice=True, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, forbidden_cells=None ) -> IntVector2DArray: """ diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index ee31666c..557f25e4 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -7,6 +7,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_direction, mirror +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.grid_utils import distance_on_rail, direction_to_city, IntVector2DArray, IntVector2D, \ Vec2dOperations from flatland.core.grid.rail_env_grid import RailEnvTransitions @@ -125,7 +126,8 @@ def complex_rail_generator(nr_start_goal=1, # we might as well give up at this point break - new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True, + new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance, + flip_start_node_trans=True, flip_end_node_trans=True, nice=True, forbidden_cells=None) if len(new_path) >= 2: @@ -152,7 +154,8 @@ def complex_rail_generator(nr_start_goal=1, break if not all_ok: break - new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True, + new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance, + flip_start_node_trans=True, flip_end_node_trans=True, nice=True, forbidden_cells=None) -- GitLab