diff --git a/examples/training_example.py b/examples/training_example.py index df93479f5a5ee05abfcb1a98b07ef052bffc2bd4..78c0299d4cee8ae588bcf8e9e7559ff1c8364c26 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -16,7 +16,7 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) env = RailEnv(width=20, height=20, - rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2, min_dist=8, max_dist=99999, seed=0), schedule_generator=complex_schedule_generator(), obs_builder_object=TreeObservation, number_of_agents=3) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index dfd9218f0e575fc8909994d635315e6cab6cd5b3..a332cd70b3c8055856332ba1214be9e1c0df55f8 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -19,10 +19,10 @@ def connect_rail( grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, flip_start_node_trans=False, flip_end_node_trans=False, nice=True, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, forbidden_cells=None ) -> IntVector2DArray: """ diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index ee31666c99082e77855764c7bdee72031034ad22..557f25e4ca956e541ae80be051e7f9fa7e93fc93 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -7,6 +7,7 @@ import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_direction, mirror +from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d from flatland.core.grid.grid_utils import distance_on_rail, direction_to_city, IntVector2DArray, IntVector2D, \ Vec2dOperations from flatland.core.grid.rail_env_grid import RailEnvTransitions @@ -125,7 +126,8 @@ def complex_rail_generator(nr_start_goal=1, # we might as well give up at this point break - new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True, + new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance, + flip_start_node_trans=True, flip_end_node_trans=True, nice=True, forbidden_cells=None) if len(new_path) >= 2: @@ -152,7 +154,8 @@ def complex_rail_generator(nr_start_goal=1, break if not all_ok: break - new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True, + new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance, + flip_start_node_trans=True, flip_end_node_trans=True, nice=True, forbidden_cells=None)