diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index ccfe5a1df8f869c5a45880a1fcc8c72295ddcac9..cfe733a37a817b25a4ebb16a4bdbd4bce1e3dfa8 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -34,10 +34,10 @@ env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) num_trainstations=50, # Number of possible start/targets on map - min_node_dist=8, # Minimal distance of nodes + min_node_dist=5, # Minimal distance of nodes node_radius=3, # Proximity of stations to city center seed=15, # Random seed - grid_mode=False, + grid_mode=True, max_connection_points_per_side=2, max_nr_connection_directions=4 ), diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index f760e5dca0b0d27d214f18365839d6796acc1c31..a0e07fc02eda83cd6de4a12534fa7daf77e58876 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -98,10 +98,11 @@ def a_star(grid_map: GridTransitionMap, # create new node new_node = AStarNode(node_pos, current_node) - # Skip paths through forbidden regions. + # Skip paths through forbidden regions if they are provided if forbidden_cells is not None: if node_pos in forbidden_cells and new_node != start_node and new_node != end_node: continue + children.append(new_node) # loop through children diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 9e060986daa852c1231a725d1ee93469465cc896..47669c9bf68c72f0ef7267b83e1f6e6a9dd626d8 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -576,7 +576,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n node_positions, city_cells = _generate_node_positions_grid_mode(nb_nodes, height, width) else: node_positions, city_cells = _generate_node_positions_not_grid_mode(nb_nodes, height, width) - print(city_cells) # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode nb_nodes = len(node_positions)