From 3e902906f2d1bf5928cf8dd0d395ee2510b10d4b Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Wed, 14 Aug 2019 18:32:40 -0400 Subject: [PATCH] checking that targets don't end up on a connecting rail --- flatland/envs/generators.py | 44 ++++++++----------- ...test_flatland_env_sparse_rail_generator.py | 6 +-- 2 files changed, 22 insertions(+), 28 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 6d1205b3..81ee9a92 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -693,7 +693,7 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis return generator -def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, min_node_dist=20, node_radius=2, +def sparse_rail_generator(nr_nodes=100, max_neigbours=2, min_node_dist=20, node_radius=2, seed=0): ''' @@ -706,10 +706,6 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, mi ''' def generator(width, height, num_agents, num_resets=0): - - if num_agents > nr_train_stations: - num_agents = nr_train_stations - warnings.warn("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) rail_array = grid_map.grid @@ -780,26 +776,24 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, mi height - 1) target_y = np.clip(node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0, width - 1) - if agent_idx == 0: - agents_position.append((start_x, start_y)) - agents_target.append((target_x, target_y)) - else: - # Make sure we don't put to starts or targets on same cell - while (start_x, start_y) in agents_position or (target_x, target_y) in agents_target: - start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius), - 0, - height - 1) - start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), - 0, - width - 1) - target_x = np.clip( - node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0, - height - 1) - target_y = np.clip( - node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0, - width - 1) - agents_position.append((start_x, start_y)) - agents_target.append((target_x, target_y)) + # Make sure we don't put to starts or targets on same cell + while (start_x, start_y) in agents_position or (start_x, start_y) == node_positions[start_target[0]]: + start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius), + 0, + height - 1) + start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), + 0, + width - 1) + while (target_x, target_y) in agents_target or (target_x, target_y) == node_positions[start_target[1]] or \ + rail_array[(target_x, target_y)] != 0: + target_x = np.clip( + node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0, + height - 1) + target_y = np.clip( + node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0, + width - 1) + agents_position.append((start_x, start_y)) + agents_target.append((target_x, target_y)) new_path = connect_to_nodes(rail_trans, rail_array, agents_position[agent_idx], node_positions[start_target[0]]) new_path = connect_from_nodes(rail_trans, rail_array, node_positions[start_target[1]], diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index b64aaa64..54bee175 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -10,9 +10,9 @@ def test_sparse_rail_generator(): env = RailEnv(width=20, height=20, - rail_generator=sparse_rail_generator(nr_train_stations=3, nr_nodes=2, min_node_dist=5, - node_radius=4), - number_of_agents=3, + rail_generator=sparse_rail_generator(nr_nodes=5, min_node_dist=8, + node_radius=3), + number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", ) -- GitLab