diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 6d1205b3b505f77dd57cd670e669c93401200b5d..81ee9a92ea7da799a809f90e6f954c3820b193b0 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -693,7 +693,7 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis return generator -def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, min_node_dist=20, node_radius=2, +def sparse_rail_generator(nr_nodes=100, max_neigbours=2, min_node_dist=20, node_radius=2, seed=0): ''' @@ -706,10 +706,6 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, mi ''' def generator(width, height, num_agents, num_resets=0): - - if num_agents > nr_train_stations: - num_agents = nr_train_stations - warnings.warn("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) rail_array = grid_map.grid @@ -780,26 +776,24 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, mi height - 1) target_y = np.clip(node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0, width - 1) - if agent_idx == 0: - agents_position.append((start_x, start_y)) - agents_target.append((target_x, target_y)) - else: - # Make sure we don't put to starts or targets on same cell - while (start_x, start_y) in agents_position or (target_x, target_y) in agents_target: - start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius), - 0, - height - 1) - start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), - 0, - width - 1) - target_x = np.clip( - node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0, - height - 1) - target_y = np.clip( - node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0, - width - 1) - agents_position.append((start_x, start_y)) - agents_target.append((target_x, target_y)) + # Make sure we don't put to starts or targets on same cell + while (start_x, start_y) in agents_position or (start_x, start_y) == node_positions[start_target[0]]: + start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius), + 0, + height - 1) + start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), + 0, + width - 1) + while (target_x, target_y) in agents_target or (target_x, target_y) == node_positions[start_target[1]] or \ + rail_array[(target_x, target_y)] != 0: + target_x = np.clip( + node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0, + height - 1) + target_y = np.clip( + node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0, + width - 1) + agents_position.append((start_x, start_y)) + agents_target.append((target_x, target_y)) new_path = connect_to_nodes(rail_trans, rail_array, agents_position[agent_idx], node_positions[start_target[0]]) new_path = connect_from_nodes(rail_trans, rail_array, node_positions[start_target[1]], diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index b64aaa640d27f24b1ab2bd87e30a29784f7787e7..54bee175808de144bd9ac9cbb7a2c0462e1db286 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -10,9 +10,9 @@ def test_sparse_rail_generator(): env = RailEnv(width=20, height=20, - rail_generator=sparse_rail_generator(nr_train_stations=3, nr_nodes=2, min_node_dist=5, - node_radius=4), - number_of_agents=3, + rail_generator=sparse_rail_generator(nr_nodes=5, min_node_dist=8, + node_radius=3), + number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", )