diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 652671ae48b5cd269a20e79a31657c2e7434b8d4..ab2d02912bf61b796b5c7ced312918b7cf56615c 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -690,3 +690,93 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator + + +def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, mean_node_neighbours=2, min_node_dist=20, node_radius=2, + seed=0): + ''' + + :param nr_train_stations: + :param nr_nodes: + :param mean_node_neighbours: + :param min_node_dist: + :param seed: + :return: + ''' + + def generator(width, height, num_agents, num_resets=0): + + if num_agents > nr_train_stations: + num_agents = nr_train_stations + print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") + rail_trans = RailEnvTransitions() + grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) + rail_array = grid_map.grid + rail_array.fill(0) + np.random.seed(seed + num_resets) + + # Generate a set of nodes for the sparse network + node_positions = [] + for node_idx in range(nr_nodes): + to_close = True + while to_close: + x_tmp = np.random.randint(width) + y_tmp = np.random.randint(height) + to_close = False + for node_pos in node_positions: + if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: + to_close = True + if not to_close: + node_positions.append((x_tmp, y_tmp)) + + # Generate start and target node directory for all agents + agent_start_targets_nodes = [] + for agent_idx in range(num_agents): + start_target_tuple = np.random.choice(nr_nodes, 2, replace=False) + agent_start_targets_nodes.append(start_target_tuple) + # Generate actual start and target locations from around nodes + agents_position = [] + agents_target = [] + agents_direction = [] + agent_idx = 0 + for start_target in agent_start_targets_nodes: + start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius), 0, + width - 1) + start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), 0, + height - 1) + target_x = np.clip(node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0, + width - 1) + target_y = np.clip(node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0, + height - 1) + if agent_idx == 0: + agents_position.append((start_y, start_x)) + agents_target.append((target_y, target_x)) + else: + while ((start_x, start_y) in agents_position or (target_x, target_y) in agents_target): + start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius), + 0, + width - 1) + start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), + 0, + height - 1) + target_x = np.clip( + node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0, + width - 1) + target_y = np.clip( + node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0, + height - 1) + agents_position.append((start_y, start_x)) + agents_target.append((target_y, target_x)) + + agents_direction.append(0) + agent_idx += 1 + + print(agents_position) + print(agents_target) + print(node_positions) + for n in node_positions: + for m in node_positions: + print(distance_on_rail(n, m)) + return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + + return generator diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..8934a4022611364f0adc4df806376fd43e5bc8f4 --- /dev/null +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -0,0 +1,14 @@ +from flatland.envs.generators import sparse_rail_generator +from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.rail_env import RailEnv + + +def test_sparse_rail_generator(): + env = RailEnv(width=20, + height=20, + rail_generator=sparse_rail_generator(nr_train_stations=10, nr_nodes=5, min_node_dist=10, + node_radius=4), + number_of_agents=10, + obs_builder_object=GlobalObsForRailEnv()) + # reset to initialize agents_static + env.reset()