diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index be21f86cba60eb6b4d3b3ad56e32202ad002a9b7..2441e3e7652828d9c83a4ca5860e030f7921e8a8 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -673,10 +673,16 @@ def realistic_rail_generator(nr_start_goal=1, seed=0): max_n_track_seg = np.random.choice([3, 4, 5]) x_offsets = np.arange(0, height, max_n_track_seg).astype(int) - agents_positions = [] + agents_positions_forward = [] + agents_directions_forward = [] + agents_positions_backward = [] + agents_directions_backward = [] agents_targets = [] - agents_directions = [] + idx_forward = [] + idx_backward = [] + + idx_target=0 for off_set_loop in range(len(x_offsets)): off_set = x_offsets[off_set_loop] # second track @@ -733,7 +739,7 @@ def realistic_rail_generator(nr_start_goal=1, seed=0): (x_offsets[off_set_loop] - 1, 0), (x_offsets[off_set_loop] - 2, 0)) - for nbr_track_loop in range(height - 1): + for nbr_track_loop in range(max_n_track_seg-1): if len(data) < 2 * n_track_seg + 1: break x = np.sort(np.random.choice(data, 2 * n_track_seg, False)).astype(int) @@ -752,24 +758,52 @@ def realistic_rail_generator(nr_start_goal=1, seed=0): make_switch_w_e(width, height, grid_map, c) add_pos = (int((start[0] + goal[0]) / 2), int((start[1] + goal[1]) / 2)) - agents_positions.append(add_pos) - agents_directions.append(3) - add_pos = (int((start[0] + goal[0]) / 2), int((2*start[1] + goal[1]) / 3)) + if nbr_track_loop % 2 == 0: + agents_positions_forward.append(add_pos) + agents_directions_forward.append(([1, 3][off_set_loop % 2])) + idx_forward.append(idx_target) + else: + agents_positions_backward.append(add_pos) + agents_directions_backward.append(([1, 3][off_set_loop % 2])) + idx_backward.append(idx_target) + + add_pos = (int((start[0] + goal[0]) / 2), int((2*start[1] + goal[1]) / 3),idx_target) agents_targets.append(add_pos) + idx_target+=1 agents_position = [] agents_target = [] agents_direction = [] - filter_agent = np.random.choice(np.arange(len(agents_positions)),min(len(agents_positions),num_agents),False) - for f in filter_agent: - d = agents_positions[f] - agents_position.append(d) - d = agents_directions[f] - agents_direction.append(d) - filter_target = np.random.choice(np.arange(len(agents_targets)),min(len(agents_targets),num_agents),False) - for f in filter_target: - d = agents_targets[f] - agents_target.append(d) + + for a in range(min(len(agents_targets),num_agents)): + t = np.random.choice(range(len(agents_targets))) + d = agents_targets[t] + agents_targets.pop(t) + if d[2] < idx_target / 2: + if len(idx_backward) > 0: + agents_target.append((d[0], d[1])) + sel = np.random.choice(range(len(idx_backward))) + # backward + p = agents_positions_backward[sel] + d = agents_directions_backward[sel] + agents_positions_backward.pop(sel) + agents_directions_backward.pop(sel) + idx_backward.pop(sel) + agents_position.append((p[0],p[1])) + agents_direction.append(d) + else: + if len(idx_forward) > 0: + agents_target.append((d[0], d[1])) + sel = np.random.choice(range(len(idx_forward))) + # forward + p = agents_positions_forward[sel] + d = agents_directions_forward[sel] + agents_positions_forward.pop(sel) + agents_directions_forward.pop(sel) + idx_forward.pop(sel) + agents_position.append((p[0],p[1])) + agents_direction.append(d) + return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index 2a361ca99c4b83da933abe674428afc15347e4db..3a08f5b2f7cdf9b7103194348266d997be3dda04 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -9,7 +9,7 @@ from flatland.utils.rendertools import RenderTool def test_realistic_rail_generator(): - for test_loop in range(5): + for test_loop in range(20): num_agents = np.random.randint(10,30) env = RailEnv(width=np.random.randint(40,80), height=np.random.randint(10,20), @@ -23,6 +23,18 @@ def test_realistic_rail_generator(): env_renderer.close_window() def test_sparse_rail_generator(): + + env = RailEnv(width=20, + height=20, + rail_generator=sparse_rail_generator(nr_nodes=3, min_node_dist=8, + node_radius=4), + number_of_agents=15, + + env = RailEnv(width=20, + height=20, + rail_generator=sparse_rail_generator(nr_nodes=3, min_node_dist=8, + node_radius=4), + number_of_agents=15, env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map @@ -38,5 +50,7 @@ def test_sparse_rail_generator(): # reset to initialize agents_static env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer.render_env(show=True, show_observations=True, show_predictions=False) + time.sleep(2) + env_renderer.gl.save_image("flatalnd_2_0.png") time.sleep(100)