From 853dd42c29e9ed6e646a4f14536dda85d9b58940 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 20 Aug 2019 09:43:30 +0200 Subject: [PATCH] realistic generator supports single or two track backbone --- flatland/envs/generators.py | 18 ++++++++++++++++-- .../test_flatland_env_sparse_rail_generator.py | 2 ++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 1dedec19..24671d79 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -705,6 +705,20 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t goal_track = (off_set, width - 1 - int((off_set_loop) % 2) * int(two_track_back_bone) * int(track_2)) new_path = connect_rail(rail_trans, rail_array, start_track, goal_track) + if track_2: + if np.random.random() < 0.75: + c = (off_set, 3) + if np.random.random() < 0.5: + make_switch_e_w(width, height, grid_map, c) + else: + make_switch_w_e(width, height, grid_map, c) + if np.random.random() < 0.5: + c = (off_set, width - 3) + if np.random.random() < 0.5: + make_switch_e_w(width, height, grid_map, c) + else: + make_switch_w_e(width, height, grid_map, c) + # track one (full track : left right) for two_track_back_bone_loop in range(1 + int(track_2) * int(two_track_back_bone)): if off_set_loop > 0: @@ -832,7 +846,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0, add_max_dead_end=20, two_t agents_targets.append(add_pos) idx_target += 1 - for pos_y in np.random.choice(np.arange(width - 7) + 3, add_max_dead_end, False): + for pos_y in np.random.choice(np.arange(width - 7) + 3, min(width - 7, add_max_dead_end), False): pos_x = off_set + 1 + int(two_track_back_bone) if pos_x < height - 1: ok = True @@ -1054,7 +1068,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation width - 1) while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[ trainstation_node] or \ - rail_array[(station_x, station_y)] != 0: + rail_array[(station_x, station_y)] != 0: station_x = np.clip( node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), 0, diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index eac83381..3ac0636e 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -50,3 +50,5 @@ def test_sparse_rail_generator(): # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer.render_env(show=True, show_observations=True, show_predictions=False) + +test_realistic_rail_generator() -- GitLab