From 214b08a14ab39d4882a13566707fcb6dab7de27e Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Fri, 27 Sep 2019 10:13:14 -0400 Subject: [PATCH] code cleanup and added city cells in order to avoid drawing paths through cities --- examples/flatland_2_0_example.py | 2 +- flatland/core/grid/grid4_astar.py | 9 ++- flatland/envs/grid4_generators_utils.py | 12 +++- flatland/envs/rail_generators.py | 82 ++++++++++--------------- 4 files changed, 51 insertions(+), 54 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index f2bb4aba..ce1dda2d 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -33,7 +33,7 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) - num_trainstations=0, # Number of possible start/targets on map + num_trainstations=50, # Number of possible start/targets on map min_node_dist=8, # Minimal distance of nodes node_radius=3, # Proximity of stations to city center seed=15, # Random seed diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index 3b6de032..8b757435 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -37,7 +37,8 @@ class AStarNode: def a_star(grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, nice=True) -> IntVector2DArray: + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, nice=True, + forbidden_cells=None) -> IntVector2DArray: """ Returns a list of tuples as a path from the given start to end. If no path is found, returns path to closest point to end. @@ -90,11 +91,15 @@ def a_star(grid_map: GridTransitionMap, if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0: continue + # Skip paths through forbidden regions. + if forbidden_cells is not None: + if node_pos in forbidden_cells and node_pos != start_node and node_pos != end_node: + continue + # validate positions # if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos) and nice: continue - # create new node new_node = AStarNode(node_pos, current_node) children.append(new_node) diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index 023e96e0..166094aa 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -20,13 +20,15 @@ def connect_basic_operation( flip_start_node_trans=False, flip_end_node_trans=False, nice=True, - a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance, + forbidden_cells=None +) -> IntVector2DArray: """ Creates a new path [start,end] in `grid_map.grid`, based on rail_trans, and returns the path created as a list of positions. """ # in the worst case we will need to do a A* search, so we might as well set that up - path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice) + path: IntVector2DArray = a_star(grid_map, start, end, a_star_distance_function, nice, forbidden_cells) if len(path) < 2: print("No path found", path) return [] @@ -87,6 +89,12 @@ def connect_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function) +def connect_cities(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + start: IntVector2D, end: IntVector2D, forbidden_cells, + a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance) -> IntVector2DArray: + return connect_basic_operation(rail_trans, grid_map, start, end, False, False, False, a_star_distance_function, + forbidden_cells) + def connect_from_nodes(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, start: IntVector2D, end: IntVector2D, a_star_distance_function: IntVector2DDistance = Vec2d.get_manhattan_distance diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 8c66fe39..0ffbe00d 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -9,7 +9,7 @@ from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.core.grid.grid_utils import distance_on_rail, direction_to_point from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap -from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes +from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_cities RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]] RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct] @@ -573,25 +573,9 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n node_positions: List[Any] = None nb_nodes = num_cities if grid_mode: - nodes_ratio = height / width - nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio))) - nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row)) - x_positions = np.linspace(node_radius, height - node_radius - 1, nodes_per_row, dtype=int) - y_positions = np.linspace(node_radius, width - node_radius - 1, nodes_per_col, dtype=int) - city_idx = np.random.choice(np.arange(nb_nodes), num_cities, False) - - node_positions = _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, - nb_nodes, - nodes_per_row, x_positions, - y_positions) - - - + node_positions, city_cells = _generate_node_positions_grid_mode(nb_nodes, height, width) else: - - node_positions = _generate_node_positions_not_grid_mode(city_positions, height, - intersection_positions, - nb_nodes, width) + node_positions = _generate_node_positions_not_grid_mode(nb_nodes, height, width) # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode nb_nodes = len(node_positions) @@ -624,8 +608,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n 'train_stations': train_stations }} - def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes, - width): + def _generate_node_positions_not_grid_mode(nb_nodes, height, width): node_positions = [] for node_idx in range(nb_nodes): @@ -637,22 +620,14 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n y_tmp = node_radius + np.random.randint(width - 2 * node_radius - 1) to_close = False - # Check distance to cities - for node_pos in city_positions: - if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: - to_close = True - - # Check distance to intersections - for node_pos in intersection_positions: + # Check distance to nodes + for node_pos in node_positions: if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: to_close = True if not to_close: node_positions.append((x_tmp, y_tmp)) - if node_idx < num_cities: - city_positions.append((x_tmp, y_tmp)) - else: - intersection_positions.append((x_tmp, y_tmp)) + tries += 1 if tries > 100: warnings.warn( @@ -661,23 +636,21 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n tries, nb_nodes)) break - node_positions = city_positions + intersection_positions return node_positions - def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes, - nodes_per_row, x_positions, y_positions): - + def _generate_node_positions_grid_mode(nb_nodes, height, width): + nodes_ratio = height / width + nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio))) + nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row)) + x_positions = np.linspace(node_radius, height - node_radius - 1, nodes_per_row, dtype=int) + y_positions = np.linspace(node_radius, width - node_radius - 1, nodes_per_col, dtype=int) + node_positions = [] + forbidden_cells = [] for node_idx in range(nb_nodes): - x_tmp = x_positions[node_idx % nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row] - if node_idx in city_idx: - city_positions.append((x_tmp, y_tmp)) - - else: - intersection_positions.append((x_tmp, y_tmp)) - node_positions = city_positions + intersection_positions - return node_positions + node_positions.append((x_tmp, y_tmp)) + return node_positions, forbidden_cells def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2, max_nr_connection_directions=2): @@ -698,8 +671,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n # Store the directions to these neighbours connection_sides_idx = [] idx = 1 - # TODO: Change the way this code works! Check that we get sufficient direction. - # TODO: Check if this works as expected while len(connection_sides_idx) < max_nr_connection_directions and idx < len(neighb_dist): current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]]) print(node_position) @@ -707,12 +678,11 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n connection_sides_idx.append(current_closest_direction) idx += 1 - # set the number of connection points for each direction connections_per_direction = np.zeros(4, dtype=int) for idx in connection_sides_idx: - nr_of_connection_points = max_nr_connection_points # np.random.randint(1, max_nr_connection_points + 1) + nr_of_connection_points = np.random.randint(1, max_nr_connection_points + 1) connections_per_direction[idx] = nr_of_connection_points connection_points_coordinates = [] @@ -775,7 +745,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n if tmp_dist < min_connection_dist: min_connection_dist = tmp_dist neighb_connection_point = tmp_in_connection_point - connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point) + connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, None) boarder_connections.add((tmp_out_connection_point, current_node)) boarder_connections.add((neighb_connection_point, neighb_idx)) direction += 1 @@ -944,4 +914,18 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python return sorted(range(len(seq)), key=seq.__getitem__) + def _city_cells(center, radius): + """ + Function to return all cells within a city + :param center: center coordinates of city + :param radius: radius of city (it is a square) + :return: returns flat list of all cell coordinates in the city + """ + city_cells = [] + for x in range(-radius, radius): + for y in range(-radius, radius): + city_cells.append(center[0] + x, center[1] + y) + + return city_cells + return generator -- GitLab