diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 0ec268dc1d3910967541ba0bcc357c8cc4638ec4..ada43ce4e25353cc8e6e639b5a3d1ce6b00c4953 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -582,7 +582,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, # Connect the cities through the connection points city_connection_time = time.time() - _connect_cities(node_positions, outer_connection_points, connection_info, city_cells, rail_trans, grid_map) + inter_city_lines = _connect_cities(node_positions, outer_connection_points, connection_info, city_cells, + rail_trans, grid_map) print("City connection time", time.time() - city_connection_time) # Build inner cities city_build_time = time.time() @@ -604,7 +605,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, # Fix all transition elements grid_fix_time = time.time() - _fix_transitions(city_cells, grid_map) + _fix_transitions(city_cells, inter_city_lines, grid_map) print("Grid fix time", time.time() - grid_fix_time) # Generate start target pairs @@ -728,6 +729,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, :param grid_map: Grid map :return: """ + all_paths = [] for current_node in np.arange(len(node_positions)): direction = 0 connected_to_city = [] @@ -758,11 +760,12 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, if tmp_dist < min_connection_dist: min_connection_dist = tmp_dist neighb_connection_point = tmp_in_connection_point - connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, - city_cells) - + new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, + city_cells) + all_paths.extend(new_line) direction += 1 - return + + return all_paths def _build_inner_cities(node_positions, inner_connection_points, outer_connection_points, node_radius, rail_trans, grid_map): @@ -891,14 +894,15 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, num_agents -= 1 return agent_start_targets_nodes, num_agents - def _fix_transitions(city_cells, grid_map): + def _fix_transitions(city_cells, inter_city_lines, grid_map): """ Function to fix all transition elements in environment """ # Fix all nodes with illegal transition maps rails_to_fix = np.zeros(2 * grid_map.height * grid_map.width * 2, dtype='int') rails_to_fix_cnt = 0 - for cell in city_cells: + cells_to_fix = city_cells + inter_city_lines + for cell in cells_to_fix: check = grid_map.cell_neighbours_valid(cell, True) if grid_map.grid[cell] == int('1000010000100001', 2): grid_map.fix_transitions(cell)