Skip to content
Snippets Groups Projects
Commit 35cbfa77 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

backwards compatibility with complex rail generator

parent af5ed987
No related branches found
No related tags found
No related merge requests found
......@@ -16,7 +16,7 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor
LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
env = RailEnv(width=20,
height=20,
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2, min_dist=8, max_dist=99999, seed=0),
schedule_generator=complex_schedule_generator(),
obs_builder_object=TreeObservation,
number_of_agents=3)
......
......@@ -19,10 +19,10 @@ def connect_rail(
grid_map: GridTransitionMap,
start: IntVector2D,
end: IntVector2D,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
flip_start_node_trans=False,
flip_end_node_trans=False,
nice=True,
a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance,
forbidden_cells=None
) -> IntVector2DArray:
"""
......
......@@ -7,6 +7,7 @@ import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.core.grid.grid4_utils import get_direction, mirror
from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.core.grid.grid_utils import distance_on_rail, direction_to_city, IntVector2DArray, IntVector2D, \
Vec2dOperations
from flatland.core.grid.rail_env_grid import RailEnvTransitions
......@@ -125,7 +126,8 @@ def complex_rail_generator(nr_start_goal=1,
# we might as well give up at this point
break
new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True,
new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance,
flip_start_node_trans=True,
flip_end_node_trans=True, nice=True,
forbidden_cells=None)
if len(new_path) >= 2:
......@@ -152,7 +154,8 @@ def complex_rail_generator(nr_start_goal=1,
break
if not all_ok:
break
new_path = connect_rail(rail_trans, grid_map, start, goal, flip_start_node_trans=True,
new_path = connect_rail(rail_trans, grid_map, start, goal, Vec2d.get_chebyshev_distance,
flip_start_node_trans=True,
flip_end_node_trans=True, nice=True,
forbidden_cells=None)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment