diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 6049e6db3f0aa1680934314a95c2e1533a634003..80c9f13b9ed3e69937ee1090bd5e93d90b6a17c2 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -790,6 +790,12 @@ def realistic_rail_generator(nr_start_goal=1, seed=0): if len(new_path) > 0: c = (pos_x - 1, pos_y - 1) make_switch_e_w(width, height, grid_map, c) + agents_positions_backward.append(add_pos) + agents_directions_backward.append(1) + idx_backward.append(idx_target) + add_pos = (goal_track[0], goal_track[1], idx_target) + agents_targets.append(add_pos) + idx_target += 1 else: start_track = (pos_x, pos_y) goal_track = (pos_x, pos_y - 2) @@ -797,6 +803,12 @@ def realistic_rail_generator(nr_start_goal=1, seed=0): if len(new_path) > 0: c = (pos_x - 1, pos_y + 1) make_switch_w_e(width, height, grid_map, c) + agents_positions_forward.append(add_pos) + agents_directions_forward.append(3) + idx_forward.append(idx_target) + add_pos = (goal_track[0], goal_track[1], idx_target) + agents_targets.append(add_pos) + idx_target += 1 agents_position = [] agents_target = [] diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index 7d1b818cc47f6e717a9da430fecfcdd48b14be4a..59628274006c23decfb8bd24101e545e821f1c51 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -1,5 +1,5 @@ import numpy as np - +import os from flatland.envs.generators import sparse_rail_generator, realistic_rail_generator from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv @@ -19,8 +19,15 @@ def test_realistic_rail_generator(vizualization_folder_name=None): screen_height=1200, screen_width=1600) env_renderer.render_env(show=True, show_observations=True, show_predictions=False) + if vizualization_folder_name is not None: + env_renderer.gl.save_image( + os.path.join( + vizualization_folder_name, + "flatland_frame_{:04d}.png".format(test_loop) + )) env_renderer.close_window() + def test_sparse_rail_generator(): env = RailEnv(width=50, height=50, @@ -40,4 +47,4 @@ def test_sparse_rail_generator(): env_renderer.render_env(show=True, show_observations=True, show_predictions=False) -test_realistic_rail_generator() +test_realistic_rail_generator("rendering/")