From 79a1e07677451b9839113129f92a28b00ee2bf7c Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 20 Aug 2019 14:57:11 +0200 Subject: [PATCH] realistic generator --- flatland/envs/generators.py | 16 ++++++++++------ tests/test_flatland_env_sparse_rail_generator.py | 8 +++++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 4589633e..d21c6eb1 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -543,7 +543,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): return generator -def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_track_back_bone=True): +def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=4, two_track_back_bone=True): """ Parameters ------- @@ -670,7 +670,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t np.random.seed(seed + num_resets) - max_n_track_seg = np.random.choice([3, 4, 5]) + int(two_track_back_bone) + max_n_track_seg = np.random.choice(np.arange(3, int(height / 2))) + int(two_track_back_bone) x_offsets = np.arange(0, height, max_n_track_seg).astype(int) agents_positions = [] @@ -680,8 +680,8 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t for off_set_loop in range(len(x_offsets)): off_set = x_offsets[off_set_loop] # second track - data = np.arange(3, width - 4, 3) - n_track_seg = np.random.choice(max_n_track_seg) + 1 + data = np.arange(2, width - 2) + n_track_seg = np.random.choice([1,2]) track_2 = False if two_track_back_bone: @@ -817,6 +817,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t 0 + int(two_track_back_bone_loop))) for nbr_track_loop in range(max_n_track_seg - 1): + n_track_seg = 1 if len(data) < 2 * n_track_seg + 1: break x = np.sort(np.random.choice(data, 2 * n_track_seg, False)).astype(int) @@ -827,7 +828,8 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t goal = ( max(0, min(off_set + nbr_track_loop + 1, height - 1)), max(0, min(x[2 * x_loop + 1], width - 1))) - d = np.arange(x[2 * x_loop] + 1, x[2 * x_loop + 1] - 1, 2) + + d = np.arange(x[2 * x_loop] + 1, x[2 * x_loop + 1] - 1) data.extend(d) new_path = connect_rail(rail_trans, rail_array, start, goal) @@ -843,6 +845,8 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t add_pos = (int((start[0] + goal[0]) / 2), int((2 * start[1] + goal[1]) / 3)) agents_targets.append(add_pos) + for off_set_loop in range(len(x_offsets)): + off_set = x_offsets[off_set_loop] pos_ys = np.random.choice(np.arange(width - 7) + 3, min(width - 7, add_max_dead_end), False) for pos_y in pos_ys: pos_x = off_set + 1 + int(two_track_back_bone) @@ -855,7 +859,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t c = (pos_x, pos_y - k + 1) ok &= grid_map.grid[c[0]][c[1]] == 0 if ok: - if np.random.random() < 0.95: + if np.random.random() < 0.5: start_track = (pos_x, pos_y) goal_track = (pos_x, pos_y - 2) new_path = connect_rail(rail_trans, rail_array, start_track, goal_track) diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index b3ee348c..710ae6d5 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -7,14 +7,16 @@ from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool, AgentRenderVariant + def test_realistic_rail_generator(vizualization_folder_name=None): for test_loop in range(50): print("test_loop", test_loop) num_agents = np.random.randint(10, 30) env = RailEnv(width=np.random.randint(40, 80), height=np.random.randint(10, 20), - rail_generator=realistic_rail_generator(nr_start_goal=num_agents + 1, seed=test_loop, - add_max_dead_end=20, + rail_generator=realistic_rail_generator(nr_start_goal=num_agents + 1, + seed=test_loop, + add_max_dead_end=4, two_track_back_bone=test_loop % 2 == 0), number_of_agents=num_agents, obs_builder_object=GlobalObsForRailEnv()) @@ -51,4 +53,4 @@ def test_sparse_rail_generator(): env_renderer.render_env(show=True, show_observations=True, show_predictions=False) env_renderer.close_window() -#test_realistic_rail_generator("./../rendering/") +test_realistic_rail_generator("./../rendering/") -- GitLab