diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index ce1dda2dcf09223f0265ecada213d5d04ff249a2..ccfe5a1df8f869c5a45880a1fcc8c72295ddcac9 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -37,7 +37,7 @@ env = RailEnv(width=50, min_node_dist=8, # Minimal distance of nodes node_radius=3, # Proximity of stations to city center seed=15, # Random seed - grid_mode=True, + grid_mode=False, max_connection_points_per_side=2, max_nr_connection_directions=4 ), diff --git a/flatland/core/grid/grid4_astar.py b/flatland/core/grid/grid4_astar.py index 8b7574354326d9897269eb0c4b4698ae9e22b145..f760e5dca0b0d27d214f18365839d6796acc1c31 100644 --- a/flatland/core/grid/grid4_astar.py +++ b/flatland/core/grid/grid4_astar.py @@ -91,17 +91,17 @@ def a_star(grid_map: GridTransitionMap, if node_pos[0] >= rail_shape[0] or node_pos[0] < 0 or node_pos[1] >= rail_shape[1] or node_pos[1] < 0: continue - # Skip paths through forbidden regions. - if forbidden_cells is not None: - if node_pos in forbidden_cells and node_pos != start_node and node_pos != end_node: - continue - # validate positions # if not grid_map.validate_new_transition(prev_pos, current_node.pos, node_pos, end_node.pos) and nice: continue # create new node new_node = AStarNode(node_pos, current_node) + + # Skip paths through forbidden regions. + if forbidden_cells is not None: + if node_pos in forbidden_cells and new_node != start_node and new_node != end_node: + continue children.append(new_node) # loop through children diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 0ffbe00d31cdfb119dd7a325848092345b9c33a9..9e060986daa852c1231a725d1ee93469465cc896 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -575,7 +575,8 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n if grid_mode: node_positions, city_cells = _generate_node_positions_grid_mode(nb_nodes, height, width) else: - node_positions = _generate_node_positions_not_grid_mode(nb_nodes, height, width) + node_positions, city_cells = _generate_node_positions_not_grid_mode(nb_nodes, height, width) + print(city_cells) # reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode nb_nodes = len(node_positions) @@ -586,7 +587,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n max_nr_connection_directions) # Connect the cities through the connection points - _connect_cities(node_positions, connection_points, connection_info, rail_trans, grid_map) + _connect_cities(node_positions, connection_points, connection_info, city_cells, rail_trans, grid_map) # Build inner cities train_stations, built_num_trainstation = _build_cities(node_positions, connection_points, rail_trans, grid_map) @@ -611,6 +612,8 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n def _generate_node_positions_not_grid_mode(nb_nodes, height, width): node_positions = [] + city_cells = [] + for node_idx in range(nb_nodes): to_close = True tries = 0 @@ -627,6 +630,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n if not to_close: node_positions.append((x_tmp, y_tmp)) + city_cells.extend(_city_cells(node_positions[-1], node_radius)) tries += 1 if tries > 100: @@ -636,7 +640,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n tries, nb_nodes)) break - return node_positions + return node_positions, city_cells def _generate_node_positions_grid_mode(nb_nodes, height, width): nodes_ratio = height / width @@ -645,12 +649,13 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n x_positions = np.linspace(node_radius, height - node_radius - 1, nodes_per_row, dtype=int) y_positions = np.linspace(node_radius, width - node_radius - 1, nodes_per_col, dtype=int) node_positions = [] - forbidden_cells = [] + city_cells = [] for node_idx in range(nb_nodes): x_tmp = x_positions[node_idx % nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row] node_positions.append((x_tmp, y_tmp)) - return node_positions, forbidden_cells + city_cells.extend(_city_cells(node_positions[-1], node_radius)) + return node_positions, city_cells def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2, max_nr_connection_directions=2): @@ -673,7 +678,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n idx = 1 while len(connection_sides_idx) < max_nr_connection_directions and idx < len(neighb_dist): current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]]) - print(node_position) if current_closest_direction not in connection_sides_idx: connection_sides_idx.append(current_closest_direction) idx += 1 @@ -708,7 +712,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n connection_info.append(connections_per_direction) return connection_points, connection_info - def _connect_cities(node_positions, connection_points, connection_info, rail_trans, grid_map): + def _connect_cities(node_positions, connection_points, connection_info, city_cells, rail_trans, grid_map): """ Function to connect the different cities through their connection points :param node_positions: Positions of city centers @@ -724,8 +728,6 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n for nbr_connection_points in connection_info[current_node]: if nbr_connection_points > 0: neighb_idx = _closest_neigh_in_direction(current_node, direction, node_positions) - print(current_node, node_positions[current_node], direction, neighb_idx, - connection_info[current_node], connection_points[current_node]) else: direction += 1 continue @@ -745,7 +747,8 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n if tmp_dist < min_connection_dist: min_connection_dist = tmp_dist neighb_connection_point = tmp_in_connection_point - connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, None) + connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, + city_cells) boarder_connections.add((tmp_out_connection_point, current_node)) boarder_connections.add((neighb_connection_point, neighb_idx)) direction += 1 @@ -924,7 +927,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n city_cells = [] for x in range(-radius, radius): for y in range(-radius, radius): - city_cells.append(center[0] + x, center[1] + y) + city_cells.append((center[0] + x, center[1] + y)) return city_cells