diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py index ffc4667364f71a474a7d73af089d258f567b65b4..2ee46d02053cdcb179c68d376f3c47c9aab6922a 100644 --- a/flatland/utils/simple_rail.py +++ b/flatland/utils/simple_rail.py @@ -44,11 +44,11 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]: rail.grid = rail_map city_positions = [(0,3), (6, 6)] train_stations = [ - [( (0, 3), 0 ), ( (1, 3), 1 ) ], - [( (6, 6), 0 ), ( (5, 6), 1 ) ], + [( (0, 3), 0 ) ], + [( (6, 6), 0 ) ], ] city_orientations = [0, 2] - agents_hints = {'num_agents': 100, + agents_hints = {'num_agents': 2, 'city_positions': city_positions, 'train_stations': train_stations, 'city_orientations': city_orientations @@ -94,7 +94,19 @@ def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array]: rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - return rail, rail_map + city_positions = [(0,3), (6, 6)] + train_stations = [ + [( (0, 3), 0 ) ], + [( (6, 6), 0 ) ], + ] + city_orientations = [0, 2] + agents_hints = {'num_agents': 2, + 'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } + optionals = {'agents_hints': agents_hints} + return rail, rail_map, optionals def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]: @@ -131,7 +143,19 @@ def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]: rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - return rail, rail_map + city_positions = [(0,3), (6, 6)] + train_stations = [ + [( (0, 3), 0 ) ], + [( (6, 6), 0 ) ], + ] + city_orientations = [0, 2] + agents_hints = {'num_agents': 2, + 'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } + optionals = {'agents_hints': agents_hints} + return rail, rail_map, optionals def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]: @@ -169,7 +193,19 @@ def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]: rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - return rail, rail_map + city_positions = [(0,3), (6, 6)] + train_stations = [ + [( (0, 3), 0 ) ], + [( (6, 6), 0 ) ], + ] + city_orientations = [0, 2] + agents_hints = {'num_agents': 2, + 'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } + optionals = {'agents_hints': agents_hints} + return rail, rail_map, optionals def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array]: @@ -213,7 +249,20 @@ def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array]: rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - return rail, rail_map + city_positions = [(0,3), (6, 6)] + train_stations = [ + [( (0, 3), 0 ) ], + [( (6, 6), 0 ) ], + ] + city_orientations = [0, 2] + agents_hints = {'num_agents': 2, + 'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } + optionals = {'agents_hints': agents_hints} + return rail, rail_map, optionals + def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]: @@ -251,4 +300,16 @@ def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]: rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map - return rail, rail_map + city_positions = [(0,3), (6, 6)] + train_stations = [ + [( (0, 3), 0 ) ], + [( (6, 6), 0 ) ], + ] + city_orientations = [0, 2] + agents_hints = {'num_agents': 2, + 'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } + optionals = {'agents_hints': agents_hints} + return rail, rail_map, optionals diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py index 90a6db7dc10ff61e1e06557e07d7b62a65365dfb..d3357179874773547054516d0c4592e71d548c79 100644 --- a/tests/test_distance_map.py +++ b/tests/test_distance_map.py @@ -25,9 +25,23 @@ def test_walker(): rail = GridTransitionMap(width=rail_map.shape[1], height=rail_map.shape[0], transitions=transitions) rail.grid = rail_map + + city_positions = [(0,2), (0, 1)] + train_stations = [ + [( (0, 1), 0 ) ], + [( (0, 2), 0 ) ], + ] + city_orientations = [1, 0] + agents_hints = {'num_agents': 1, + 'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } + optionals = {'agents_hints': agents_hints} + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py index 9b09899d08a58d7f0434295d144baffa448fb1e5..72fc1a85853ee6dcbb3793be43118101fd2d394f 100644 --- a/tests/test_flaltland_rail_agent_status.py +++ b/tests/test_flaltland_rail_agent_status.py @@ -4,16 +4,16 @@ from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.line_generators import rail_from_grid_transition_map +from flatland.envs.line_generators import sparse_line_generator from flatland.utils.simple_rail import make_simple_rail from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay def test_initial_status(): """Test that agent lifecycle works correctly ready-to-depart -> active -> done.""" - rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - line_generator=rail_from_grid_transition_map(), number_of_agents=1, + rail, rail_map, optionals = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), remove_agents_at_target=False) env.reset() @@ -124,9 +124,9 @@ def test_initial_status(): def test_status_done_remove(): """Test that agent lifecycle works correctly ready-to-depart -> active -> done.""" - rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - line_generator=rail_from_grid_transition_map(), number_of_agents=1, + rail, rail_map, optionals = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), remove_agents_at_target=True) env.reset() diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index 87cc44343163c69863c036652d190e0798362cab..c6fcd48d2e8e0226d3cbb69b15dd444a31adcb7d 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -66,10 +66,10 @@ def check_path(env, rail, position, direction, target, expected, rendering=False def test_path_exists(rendering=False): - rail, rail_map = make_simple_rail() + rail, rail_map, optiionals = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optiionals), line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), @@ -130,10 +130,10 @@ def test_path_exists(rendering=False): def test_path_not_exists(rendering=False): - rail, rail_map = make_simple_rail_unconnected() + rail, rail_map, optionals = make_simple_rail_unconnected() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py index a43bd4938af524b8a864ec72a85f77b0f778ee4b..1634ebb0819417ee10ccea226095d814d2c5bbea 100644 --- a/tests/test_flatland_envs_observations.py +++ b/tests/test_flatland_envs_observations.py @@ -18,9 +18,9 @@ from flatland.utils.simple_rail import make_simple_rail def test_global_obs(): - rail, rail_map = make_simple_rail() + rail, rail_map, optionals = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) @@ -91,8 +91,8 @@ def _step_along_shortest_path(env, obs_builder, rail): def test_reward_function_conflict(rendering=False): - rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + rail, rail_map, optionals = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) obs_builder: TreeObsForRailEnv = env.obs_builder @@ -179,8 +179,8 @@ def test_reward_function_conflict(rendering=False): def test_reward_function_waiting(rendering=False): - rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + rail, rail_map, optionals = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), remove_agents_at_target=False) diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index c943a1e0416de41d8ff6e2e1ee92da183dd19adf..d8632c5ca642df24b474ca694b61f6c1f4503e5d 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -20,11 +20,11 @@ from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make def test_dummy_predictor(rendering=False): - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), @@ -112,10 +112,10 @@ def test_dummy_predictor(rendering=False): def test_shortest_path_predictor(rendering=False): - rail, rail_map = make_simple_rail() + rail, rail_map, optionals = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), @@ -141,9 +141,8 @@ def test_shortest_path_predictor(rendering=False): # compute the observations and predictions distance_map = env.distance_map.get() - assert distance_map[0, agent.initial_position[0], agent.initial_position[1], agent.direction] == 5.0, \ - "found {} instead of {}".format( - distance_map[agent.handle, agent.initial_position[0], agent.position[1], agent.direction], 5.0) + distance_on_map = distance_map[0, agent.initial_position[0], agent.initial_position[1], agent.direction] + assert distance_on_map == 5.0, "found {} instead of {}".format(distance_on_map, 5.0) paths = get_shortest_paths(env.distance_map)[0] assert paths == [ @@ -243,10 +242,10 @@ def test_shortest_path_predictor(rendering=False): def test_shortest_path_predictor_conflicts(rendering=False): - rail, rail_map = make_invalid_simple_rail() + rail, rail_map, optionals = make_invalid_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index c11162676d765085373c6d236b6ee84e786a49d7..4502ca678f102f0a03a642f22f05db5656eb573e 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -36,8 +36,8 @@ def test_load_env(): def test_save_load(): - env = RailEnv(width=10, height=10, - rail_generator=sparse_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1), + env = RailEnv(width=30, height=30, + rail_generator=sparse_rail_generator(seed=1), line_generator=sparse_line_generator(), number_of_agents=2) env.reset() @@ -55,8 +55,8 @@ def test_save_load(): #env.load("test_save.dat") env, env_dict = RailEnvPersister.load_new("tmp/test_save.pkl") - assert (env.width == 10) - assert (env.height == 10) + assert (env.width == 30) + assert (env.height == 30) assert (len(env.agents) == 2) assert (agent_1_pos == env.agents[0].position) assert (agent_1_dir == env.agents[0].direction) @@ -67,8 +67,8 @@ def test_save_load(): def test_save_load_mpk(): - env = RailEnv(width=10, height=10, - rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1), + env = RailEnv(width=30, height=30, + rail_generator=sparse_rail_generator(seed=1), line_generator=sparse_line_generator(), number_of_agents=2) env.reset() @@ -204,7 +204,7 @@ def test_rail_environment_single_agent(show=False): rail_env.agents[0].direction = 0 - # JW - to avoid problem with random_line_generator. + # JW - to avoid problem with sparse_line_generator. #rail_env.agents[0].position = (1,2) iStep = 0 @@ -247,7 +247,7 @@ def test_dead_end(): rail.grid = rail_map rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - line_generator=random_line_generator(), number_of_agents=1, + line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) # We try the configuration in the 4 directions: @@ -270,7 +270,7 @@ def test_dead_end(): rail.grid = rail_map rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - line_generator=random_line_generator(), number_of_agents=1, + line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) rail_env.reset() @@ -283,9 +283,9 @@ def test_dead_end(): def test_get_entry_directions(): - rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - line_generator=random_line_generator(), number_of_agents=1, + rail, rail_map, optionals = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() @@ -317,10 +317,10 @@ def test_rail_env_reset(): # Test to save and load file. - rail, rail_map = make_simple_rail() + rail, rail_map, optionals = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), - line_generator=random_line_generator(), number_of_agents=3, + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(), number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index 302df47f996d03a9cf2546035d89cb11ae97821c..5825e412a942368e3ce0566ce3d875c8cf88f601 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -16,9 +16,9 @@ from flatland.envs.persistence import RailEnvPersister def test_get_shortest_paths_unreachable(): - rail, rail_map = make_disconnected_simple_rail() + rail, rail_map, optionals = make_disconnected_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) env.reset() @@ -237,11 +237,11 @@ def test_get_shortest_paths_agent_handle(): def test_get_k_shortest_paths(rendering=False): - rail, rail_map = make_simple_rail_with_alternatives() + rail, rail_map, optionals = make_simple_rail_with_alternatives() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv(), diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index cf6d35154121395d09355dcfc3b1cf6591017a39..0bff4bda0e6952f26294a7a880f9d4c7ccbb113d 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -72,11 +72,11 @@ def test_malfunction_process(): max_duration=3 # Max duration of malfunction ) - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() env = RailEnv(width=25, height=30, - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), @@ -128,11 +128,11 @@ def test_malfunction_process_statistically(): max_duration=5 # Max duration of malfunction ) - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() env = RailEnv(width=25, height=30, - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=10, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), @@ -175,11 +175,11 @@ def test_malfunction_before_entry(): max_duration=10 # Max duration of malfunction ) - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() env = RailEnv(width=25, height=30, - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=10, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), @@ -215,7 +215,7 @@ def test_malfunction_values_and_behavior(): """ # Set fixed malfunction duration for this test - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() action_dict: Dict[int, RailEnvActions] = {} stochastic_data = MalfunctionParameters(malfunction_rate=1/0.001, # Rate of malfunction occurence min_duration=10, # Minimal duration of malfunction @@ -223,7 +223,7 @@ def test_malfunction_values_and_behavior(): ) env = RailEnv(width=25, height=30, - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), @@ -248,11 +248,11 @@ def test_initial_malfunction(): max_duration=5 # Max duration of malfunction ) - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() env = RailEnv(width=25, height=30, - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(seed=10), number_of_agents=1, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), @@ -315,9 +315,9 @@ def test_initial_malfunction(): def test_initial_malfunction_stop_moving(): - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=SingleAgentNavigationObs()) env.reset() @@ -397,11 +397,11 @@ def test_initial_malfunction_do_nothing(): max_duration=5 # Max duration of malfunction ) - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() env = RailEnv(width=25, height=30, - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), @@ -477,8 +477,8 @@ def test_initial_malfunction_do_nothing(): def tests_random_interference_from_outside(): """Tests that malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test - rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + rail, rail_map, optionals = make_simple_rail2() + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() env.agents[0].speed_data['speed'] = 0.33 @@ -499,10 +499,10 @@ def tests_random_interference_from_outside(): # Run the same test as above but with an external random generator running # Check that the reward stays the same - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() random.seed(47) np.random.seed(1234) - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() env.agents[0].speed_data['speed'] = 0.33 @@ -532,9 +532,9 @@ def test_last_malfunction_step(): # Set fixed malfunction duration for this test - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(seed=2), number_of_agents=1, random_seed=1) env.reset() env.agents[0].speed_data['speed'] = 1. / 3. diff --git a/tests/test_flatland_multiprocessing.py b/tests/test_flatland_multiprocessing.py index 3a9fd57a58a5270e482fc309afb4a1c4110f0b8f..64366566362cd7aa4dd581179b395515b4d6ba7b 100644 --- a/tests/test_flatland_multiprocessing.py +++ b/tests/test_flatland_multiprocessing.py @@ -14,11 +14,12 @@ from flatland.utils.simple_rail import make_simple_rail def test_multiprocessing_tree_obs(): number_of_agents = 5 - rail, rail_map = make_simple_rail() + rail, rail_map, optionals = make_simple_rail() + optionals['agents_hints']['num_agents'] = number_of_agents obs_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=number_of_agents, obs_builder_object=obs_builder) env.reset(True, True) diff --git a/tests/test_generators.py b/tests/test_generators.py index b58836051977ef7215c1e683b3c2a4d61feaffa4..0a408444ae9f25ae5e6d904c91ad6e461fec1304 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -29,9 +29,9 @@ def test_empty_rail_generator(): def test_rail_from_grid_transition_map(): - rail, rail_map = make_simple_rail() + rail, rail_map, optionals = make_simple_rail() n_agents = 4 - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=n_agents) env.reset(False, False, True) nr_rail_elements = np.count_nonzero(env.rail.grid) @@ -51,9 +51,9 @@ def tests_rail_from_file(): # Test to save and load file with distance map. - rail, rail_map = make_simple_rail() + rail, rail_map, optionals = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py index 47ac3025e1869e6e57c3ee5ea86776f744548dd8..af5ffeb505b831c58dd15743ba71ea25510666a7 100644 --- a/tests/test_malfunction_generators.py +++ b/tests/test_malfunction_generators.py @@ -17,11 +17,11 @@ def test_malfanction_from_params(): min_duration=2, # Minimal duration of malfunction max_duration=5 # Max duration of malfunction ) - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() env = RailEnv(width=25, height=30, - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=10, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data) @@ -44,11 +44,11 @@ def test_malfanction_to_and_from_file(): max_duration=5 # Max duration of malfunction ) - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() env = RailEnv(width=25, height=30, - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=10, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data) @@ -61,7 +61,7 @@ def test_malfanction_to_and_from_file(): malfunction_generator, malfunction_process_data = malfunction_from_file("./malfunction_saving_loading_tests.pkl") env2 = RailEnv(width=25, height=30, - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=10, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data) @@ -83,10 +83,10 @@ def test_single_malfunction_generator(): """ - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() env = RailEnv(width=25, height=30, - rail_generator=rail_from_grid_transition_map(rail), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=10, malfunction_generator_and_process_data=single_malfunction_generator(earlierst_malfunction=10, diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index ad5d2e5cb28c67ecbaf956c776258a20e9a6dc44..172e14047c4b9a3d509139be0e3875ca84b8712d 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -92,8 +92,8 @@ def test_multi_speed_init(): def test_multispeed_actions_no_malfunction_no_blocking(): """Test that actions are correctly performed on cell exit for a single agent.""" - rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + rail, rail_map, optionals = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() @@ -192,8 +192,8 @@ def test_multispeed_actions_no_malfunction_no_blocking(): def test_multispeed_actions_no_malfunction_blocking(): """The second agent blocks the first because it is slower.""" - rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + rail, rail_map, optionals = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() @@ -382,8 +382,8 @@ def test_multispeed_actions_no_malfunction_blocking(): def test_multispeed_actions_malfunction_no_blocking(): """Test on a single agent whether action on cell exit work correctly despite malfunction.""" - rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + rail, rail_map, optionals = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() @@ -520,8 +520,8 @@ def test_multispeed_actions_malfunction_no_blocking(): # TODO invalid action penalty seems only given when forward is not possible - is this the intended behaviour? def test_multispeed_actions_no_malfunction_invalid_actions(): """Test that actions are correctly performed on cell exit for a single agent.""" - rail, rail_map = make_simple_rail() - env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), + rail, rail_map, optionals = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index 0e6a172902fe69d0cac883874a484656dc03c8b4..ef29e016d0bc4e0b2b08b8c75461f3b2f9346bd9 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -8,13 +8,13 @@ from flatland.envs.line_generators import sparse_line_generator from flatland.utils.simple_rail import make_simple_rail2 -def test_random_seeding(): +def ndom_seeding(): # Set fixed malfunction duration for this test - rail, rail_map = make_simple_rail2() + rail, rail_map, optionals = make_simple_rail2() # Move target to unreachable position in order to not interfere with test for idx in range(100): - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(seed=12), number_of_agents=10) env.reset(True, True, False, random_seed=1) @@ -44,21 +44,20 @@ def test_random_seeding(): def test_seeding_and_observations(): # Test if two different instances diverge with different observations - rail, rail_map = make_simple_rail2() - + rail, rail_map, optionals = make_simple_rail2() + optionals['agents_hints']['num_agents'] = 10 # Make two seperate envs with different observation builders # Global Observation - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - line_generator=rail_from_grid_transition_map(seed=12), number_of_agents=10, + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(seed=12), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) # Tree Observation - env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - line_generator=rail_from_grid_transition_map(seed=12), number_of_agents=10, + env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(seed=12), number_of_agents=10, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset(False, False, False, random_seed=12) env2.reset(False, False, False, random_seed=12) - # Check that both environments produce the same initial start positions assert env.agents[0].initial_position == env2.agents[0].initial_position assert env.agents[1].initial_position == env2.agents[1].initial_position @@ -78,9 +77,7 @@ def test_seeding_and_observations(): action_dict[a] = action env.step(action_dict) env2.step(action_dict) - # Check that both environments end up in the same position - assert env.agents[0].position == env2.agents[0].position assert env.agents[1].position == env2.agents[1].position assert env.agents[2].position == env2.agents[2].position @@ -97,8 +94,8 @@ def test_seeding_and_observations(): def test_seeding_and_malfunction(): # Test if two different instances diverge with different observations - rail, rail_map = make_simple_rail2() - + rail, rail_map, optionals = make_simple_rail2() + optionals['agents_hints']['num_agents'] = 10 stochastic_data = {'prop_malfunction': 0.4, 'malfunction_rate': 2, 'min_duration': 10, @@ -106,13 +103,13 @@ def test_seeding_and_malfunction(): # Make two seperate envs with different and see if the exhibit the same malfunctions # Global Observation for tests in range(1, 100): - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - line_generator=rail_from_grid_transition_map(), number_of_agents=10, + env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) # Tree Observation - env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - line_generator=rail_from_grid_transition_map(), number_of_agents=10, + env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) env.reset(True, False, True, random_seed=tests)