diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 2334450c00233cd86ad4210b3200b0a91988723a..5c112434563eadd66111421c2748dce8c87388c9 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -32,14 +32,14 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) + rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map (where train stations are) num_intersections=0, # Number of intersections (no start / target) - num_trainstations=15, # Number of possible start/targets on map + num_trainstations=50, # Number of possible start/targets on map min_node_dist=10, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center - num_neighb=2, # Number of connections to other cities/intersections + num_neighb=3, # Number of connections to other cities/intersections seed=15, # Random seed - grid_mode=False, + grid_mode=True, enhance_intersection=False ), schedule_generator=sparse_schedule_generator(), diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index 3a75aa81193d2355f71a05d8825bc64da4547f6f..3b6de032aa3034ea6599505f48ab4a632dcf1c28 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -1,5 +1,3 @@ -import numpy as np - from flatland.core.grid.grid_utils import IntVector2D, IntVector2DDistance from flatland.core.grid.grid_utils import IntVector2DArray from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d @@ -39,7 +37,7 @@ class AStarNode: def a_star(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, nice=True) -> IntVector2DArray: """ Returns a list of tuples as a path from the given start to end. If no path is found, returns path to closest point to end. @@ -93,7 +91,8 @@ def a_star(grid_map: GridTransitionMap, continue # validate positions - if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos): + # + if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos) and nice: continue # create new node diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index fce3ffdf320a3c38d7f0551151ffdc8debe6ab5d..72b59d8c4ea8d911c971645541c056704917745f 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -19,14 +19,16 @@ def connect_basic_operation( end: IntVector2D, flip_start_node_trans=False, flip_end_node_trans=False, + nice=True, a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: """ Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and returns the path created as a list of positions. """ # in the worst case we will need to do a A* search, so we might as well set that up - path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function) + path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice) if len(path) < 2: + print("No path found", path) return [] current_dir = get_direction(path[0], path[1]) end_pos = path[-1] @@ -54,6 +56,7 @@ def connect_basic_operation( new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) grid_map.grid[current_pos] = new_trans + if new_pos == end_pos: # setup end pos setup new_trans_e = grid_map.grid[end_pos] @@ -81,7 +84,7 @@ def connect_rail(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: - return connect_basic_operation(rail_trans, grid_map, start, end, False, False, a_star_distance_function) + return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function) def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 862774319ce58c2625b227b89b77940c12016e89..5472f0c5b8d81a2925261d72cab8f8743a632c30 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -252,6 +252,7 @@ class RailEnv(Environment): rc_pos = (r, c) check = self.rail.cell_neighbours_valid(rc_pos, True) if not check: + self.rail.fix_transitions(rc_pos) warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check)) if replace_agents: diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index c0721f62631e35bdeb8e0db7e4aa9ff260153106..5f1d2f5a3c9acefcb448b6cb4a920f7a50a26346 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -595,16 +595,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 # Chose node connection # Set up list of available nodes to connect to - available_nodes_full = np.arange(nb_nodes) - available_cities = np.arange(_num_cities) - available_intersections = np.arange(_num_cities, nb_nodes) + available_nodes = np.arange(nb_nodes) - # Set up connection points + # Set up connection points for all cities connection_points = _generate_node_connection_points(node_positions, node_radius, max_nr_connection_points=8) + # Start at some node - current_node = np.random.randint(len(available_nodes_full)) + current_node = np.random.randint(len(available_nodes)) node_stack = [current_node] - open_nodes = np.copy(available_nodes_full) + open_nodes = np.copy(available_nodes) allowed_connections = num_neighb i = 0 boarder_connections = set() @@ -617,22 +616,6 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 delete_idx = np.where(open_nodes == current_node) open_nodes = np.delete(open_nodes, delete_idx, 0) - # Priority city to intersection connections - if current_node < _num_cities and len(available_intersections) > 0: - available_nodes = available_intersections - delete_idx = np.where(available_cities == current_node) - # available_cities = np.delete(available_cities, delete_idx, 0) - - # Priority intersection to city connections - elif current_node >= _num_cities and len(available_cities) > 0: - available_nodes = available_cities - delete_idx = np.where(available_intersections == current_node) - # available_intersections = np.delete(available_intersections, delete_idx, 0) - - # If no options possible connect to whatever node is still available - else: - available_nodes = available_nodes_full - # Sort available neighbors according to their distance. node_dist = [] for av_node in available_nodes: @@ -764,6 +747,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 for tbd in to_be_deleted: boarder_connections.remove(tbd) + print(boarder_connections) # Fix all nodes with illegal transition maps flat_trainstation_list = [item for sublist in train_stations for item in sublist] for cell_to_fix in flat_trainstation_list: