diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 57e99ffb89e9f6584fff9707de921c5fda101585..6396f4123adf895c381c2a21f6d8dc6e4e823b92 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -32,8 +32,8 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=50, # Number of cities in map (where train stations are) - seed=0, # Random seed + rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map (where train stations are) + seed=1, # Random seed grid_mode=False, max_inter_city_rails=2, max_tracks_in_city=4, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 6e021183a77642c8547c5efc3a8c97764fa078d0..887f97aae7c74af8d986454f0c50dcaace59f0dc 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -248,6 +248,7 @@ class RailEnv(Environment): rc_pos = (r, c) check = self.rail.cell_neighbours_valid(rc_pos, True) if not check: + print(self.rail.grid[rc_pos]) warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check)) # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172 # hacky: we must re-compute the distance map and not use the initial distance_map loaded from file by diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 0ec268dc1d3910967541ba0bcc357c8cc4638ec4..ada43ce4e25353cc8e6e639b5a3d1ce6b00c4953 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -582,7 +582,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, # Connect the cities through the connection points city_connection_time = time.time() - _connect_cities(node_positions, outer_connection_points, connection_info, city_cells, rail_trans, grid_map) + inter_city_lines = _connect_cities(node_positions, outer_connection_points, connection_info, city_cells, + rail_trans, grid_map) print("City connection time", time.time() - city_connection_time) # Build inner cities city_build_time = time.time() @@ -604,7 +605,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, # Fix all transition elements grid_fix_time = time.time() - _fix_transitions(city_cells, grid_map) + _fix_transitions(city_cells, inter_city_lines, grid_map) print("Grid fix time", time.time() - grid_fix_time) # Generate start target pairs @@ -728,6 +729,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, :param grid_map: Grid map :return: """ + all_paths = [] for current_node in np.arange(len(node_positions)): direction = 0 connected_to_city = [] @@ -758,11 +760,12 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, if tmp_dist < min_connection_dist: min_connection_dist = tmp_dist neighb_connection_point = tmp_in_connection_point - connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, - city_cells) - + new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, + city_cells) + all_paths.extend(new_line) direction += 1 - return + + return all_paths def _build_inner_cities(node_positions, inner_connection_points, outer_connection_points, node_radius, rail_trans, grid_map): @@ -891,14 +894,15 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, num_agents -= 1 return agent_start_targets_nodes, num_agents - def _fix_transitions(city_cells, grid_map): + def _fix_transitions(city_cells, inter_city_lines, grid_map): """ Function to fix all transition elements in environment """ # Fix all nodes with illegal transition maps rails_to_fix = np.zeros(2 * grid_map.height * grid_map.width * 2, dtype='int') rails_to_fix_cnt = 0 - for cell in city_cells: + cells_to_fix = city_cells + inter_city_lines + for cell in cells_to_fix: check = grid_map.cell_neighbours_valid(cell, True) if grid_map.grid[cell] == int('1000010000100001', 2): grid_map.fix_transitions(cell)