diff --git a/flatland/envs/line_generators.py b/flatland/envs/line_generators.py index 8b412783ca999c5383e102804928888d43aee32a..79241f2489b4a7b3ab3008f269d2c03fbafd27c8 100644 --- a/flatland/envs/line_generators.py +++ b/flatland/envs/line_generators.py @@ -84,11 +84,6 @@ class SparseLineGen(BaseLineGen): train_stations = hints['train_stations'] city_positions = hints['city_positions'] city_orientation = hints['city_orientations'] - max_num_agents = hints['num_agents'] - city_orientations = hints['city_orientations'] - if num_agents > max_num_agents: - num_agents = max_num_agents - warnings.warn("Too many agents! Changes number of agents.") # Place agents and targets within available train stations agents_position = [] agents_target = [] diff --git a/flatland/utils/simple_rail.py b/flatland/utils/simple_rail.py index 2ee46d02053cdcb179c68d376f3c47c9aab6922a..445b856d83847813f86ac4dca80a02cf33d27e29 100644 --- a/flatland/utils/simple_rail.py +++ b/flatland/utils/simple_rail.py @@ -48,11 +48,10 @@ def make_simple_rail() -> Tuple[GridTransitionMap, np.array]: [( (6, 6), 0 ) ], ] city_orientations = [0, 2] - agents_hints = {'num_agents': 2, - 'city_positions': city_positions, - 'train_stations': train_stations, - 'city_orientations': city_orientations - } + agents_hints = {'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } optionals = {'agents_hints': agents_hints} return rail, rail_map, optionals @@ -100,11 +99,10 @@ def make_disconnected_simple_rail() -> Tuple[GridTransitionMap, np.array]: [( (6, 6), 0 ) ], ] city_orientations = [0, 2] - agents_hints = {'num_agents': 2, - 'city_positions': city_positions, - 'train_stations': train_stations, - 'city_orientations': city_orientations - } + agents_hints = {'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } optionals = {'agents_hints': agents_hints} return rail, rail_map, optionals @@ -149,11 +147,10 @@ def make_simple_rail2() -> Tuple[GridTransitionMap, np.array]: [( (6, 6), 0 ) ], ] city_orientations = [0, 2] - agents_hints = {'num_agents': 2, - 'city_positions': city_positions, - 'train_stations': train_stations, - 'city_orientations': city_orientations - } + agents_hints = {'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } optionals = {'agents_hints': agents_hints} return rail, rail_map, optionals @@ -199,11 +196,10 @@ def make_simple_rail_unconnected() -> Tuple[GridTransitionMap, np.array]: [( (6, 6), 0 ) ], ] city_orientations = [0, 2] - agents_hints = {'num_agents': 2, - 'city_positions': city_positions, - 'train_stations': train_stations, - 'city_orientations': city_orientations - } + agents_hints = {'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } optionals = {'agents_hints': agents_hints} return rail, rail_map, optionals @@ -255,11 +251,10 @@ def make_simple_rail_with_alternatives() -> Tuple[GridTransitionMap, np.array]: [( (6, 6), 0 ) ], ] city_orientations = [0, 2] - agents_hints = {'num_agents': 2, - 'city_positions': city_positions, - 'train_stations': train_stations, - 'city_orientations': city_orientations - } + agents_hints = {'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } optionals = {'agents_hints': agents_hints} return rail, rail_map, optionals @@ -306,10 +301,45 @@ def make_invalid_simple_rail() -> Tuple[GridTransitionMap, np.array]: [( (6, 6), 0 ) ], ] city_orientations = [0, 2] - agents_hints = {'num_agents': 2, - 'city_positions': city_positions, - 'train_stations': train_stations, - 'city_orientations': city_orientations - } + agents_hints = {'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } optionals = {'agents_hints': agents_hints} return rail, rail_map, optionals + +def make_oval_rail() -> Tuple[GridTransitionMap, np.array]: + transitions = RailEnvTransitions() + cells = transitions.transition_list + + empty = cells[0] + vertical_straight = cells[1] + horizontal_straight = transitions.rotate_transition(vertical_straight, 90) + right_turn_from_south = cells[8] + right_turn_from_west = transitions.rotate_transition(right_turn_from_south, 90) + right_turn_from_north = transitions.rotate_transition(right_turn_from_south, 180) + right_turn_from_east = transitions.rotate_transition(right_turn_from_south, 270) + + rail_map = np.array( + [[empty] * 9] + + [[empty] + [right_turn_from_south] + [horizontal_straight] * 5 + [right_turn_from_west] + [empty]] + + [[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]]+ + [[empty] + [vertical_straight] + [empty] * 5 + [vertical_straight] + [empty]] + + [[empty] + [right_turn_from_east] + [horizontal_straight] * 5 + [right_turn_from_north] + [empty]] + + [[empty] * 9], dtype=np.uint16) + + rail = GridTransitionMap(width=rail_map.shape[1], + height=rail_map.shape[0], transitions=transitions) + rail.grid = rail_map + city_positions = [(1, 4), (4, 4)] + train_stations = [ + [((1, 4), 0)], + [((4, 4), 0)], + ] + city_orientations = [1, 3] + agents_hints = {'city_positions': city_positions, + 'train_stations': train_stations, + 'city_orientations': city_orientations + } + optionals = {'agents_hints': agents_hints} + return rail, rail_map, optionals \ No newline at end of file diff --git a/tests/test_flatland_envs_agent_utils.py b/tests/test_flatland_envs_agent_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1824797c19ce82f39a8095441cbed0e3bd48a38a --- /dev/null +++ b/tests/test_flatland_envs_agent_utils.py @@ -0,0 +1,102 @@ +import pytest + +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.line_generators import sparse_line_generator +from flatland.utils.simple_rail import make_oval_rail + + +def test_shortest_paths(): + rail, rail_map, optionals = make_oval_rail() + + speed_ratio_map = {1.: 1.0} + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(speed_ratio_map), + number_of_agents=2) + env.reset() + + agent0_shortest_path = env.agents[0].get_shortest_path(env.distance_map) + agent1_shortest_path = env.agents[1].get_shortest_path(env.distance_map) + + assert len(agent0_shortest_path) == 10 + assert len(agent1_shortest_path) == 10 + + +def test_travel_time_on_shortest_paths(): + rail, rail_map, optionals = make_oval_rail() + + speed_ratio_map = {1.: 1.0} + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(speed_ratio_map), + number_of_agents=2) + env.reset() + + agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map) + agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map) + + assert agent0_travel_time == 10 + assert agent1_travel_time == 10 + + + speed_ratio_map = {1/2: 1.0} + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(speed_ratio_map), + number_of_agents=2) + env.reset() + + agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map) + agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map) + + assert agent0_travel_time == 20 + assert agent1_travel_time == 20 + + + speed_ratio_map = {1/3: 1.0} + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(speed_ratio_map), + number_of_agents=2) + env.reset() + + agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map) + agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map) + + + assert agent0_travel_time == 30 + assert agent1_travel_time == 30 + + + speed_ratio_map = {1/4: 1.0} + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(speed_ratio_map), + number_of_agents=2) + env.reset() + + agent0_travel_time = env.agents[0].get_travel_time_on_shortest_path(env.distance_map) + agent1_travel_time = env.agents[1].get_travel_time_on_shortest_path(env.distance_map) + + assert agent0_travel_time == 40 + assert agent1_travel_time == 40 + + +# def test_latest_arrival_validity(): +# pass + + +# def test_time_remaining_until_latest_arrival(): +# pass + +def main(): + pass + +if __name__ == "__main__": + main() diff --git a/tests/test_flatland_envs_persistence.py b/tests/test_flatland_envs_persistence.py new file mode 100644 index 0000000000000000000000000000000000000000..7e26389f58dd87ab2fee6099f691c2b6ce9c5266 --- /dev/null +++ b/tests/test_flatland_envs_persistence.py @@ -0,0 +1,36 @@ +import numpy as np + +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.line_generators import sparse_line_generator +from flatland.utils.simple_rail import make_simple_rail +from flatland.envs.persistence import RailEnvPersister + +def test_load_new(): + + filename = "test_load_new.pkl" + + rail, rail_map, optionals = make_simple_rail() + n_agents = 2 + env_initial = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail, optionals), + line_generator=sparse_line_generator(), number_of_agents=n_agents) + env_initial.reset(False, False) + + rails_initial = env_initial.rail.grid + agents_initial = env_initial.agents + + RailEnvPersister.save(env_initial, filename) + + env_loaded, _ = RailEnvPersister.load_new(filename) + + rails_loaded = env_loaded.rail.grid + agents_loaded = env_loaded.agents + + assert np.all(np.array_equal(rails_initial, rails_loaded)) + assert agents_initial == agents_loaded + +def main(): + pass + +if __name__ == "__main__": + main() diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 942c71b171edf3aa679d41c88330c0fe97097bd7..1e6fb82079911e5a25170514d4d859b2b5b6a1cf 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -373,9 +373,13 @@ def test_rail_env_reset(): env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), line_generator=line_from_file(file_name), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) - env3.reset(False, True, False) + env3.reset(False, True) rails_loaded = env3.rail.grid agents_loaded = env3.agents + # override `earliest_departure` & `latest_arrival` since they aren't expected to be the same + for agent_initial, agent_loaded in zip(agents_initial, agents_loaded): + agent_loaded.earliest_departure = agent_initial.earliest_departure + agent_loaded.latest_arrival = agent_initial.latest_arrival assert np.all(np.array_equal(rails_initial, rails_loaded)) assert agents_initial == agents_loaded @@ -383,16 +387,21 @@ def test_rail_env_reset(): env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), line_generator=line_from_file(file_name), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) - env4.reset(True, False, False) + env4.reset(True, False) rails_loaded = env4.rail.grid agents_loaded = env4.agents + # override `earliest_departure` & `latest_arrival` since they aren't expected to be the same + for agent_initial, agent_loaded in zip(agents_initial, agents_loaded): + agent_loaded.earliest_departure = agent_initial.earliest_departure + agent_loaded.latest_arrival = agent_initial.latest_arrival assert np.all(np.array_equal(rails_initial, rails_loaded)) assert agents_initial == agents_loaded def main(): - test_rail_environment_single_agent(show=True) + # test_rail_environment_single_agent(show=True) + test_rail_env_reset() if __name__=="__main__": - main() \ No newline at end of file + main() diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flatland_rail_agent_status.py similarity index 100% rename from tests/test_flaltland_rail_agent_status.py rename to tests/test_flatland_rail_agent_status.py diff --git a/tests/test_generators.py b/tests/test_generators.py index 7d91bce89bd2d840f433de9f895b29e5a822cf3d..16e40bc00fac37b51c8c9d37051828cf05ac3803 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -72,6 +72,10 @@ def tests_rail_from_file(): env.reset() rails_loaded = env.rail.grid agents_loaded = env.agents + # override `earliest_departure` & `latest_arrival` since they aren't expected to be the same + for agent_initial, agent_loaded in zip(agents_initial, agents_loaded): + agent_loaded.earliest_departure = agent_initial.earliest_departure + agent_loaded.latest_arrival = agent_initial.latest_arrival assert np.all(np.array_equal(rails_initial, rails_loaded)) assert agents_initial == agents_loaded @@ -85,7 +89,7 @@ def tests_rail_from_file(): file_name_2 = "test_without_distance_map.pkl" env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], - rail_generator=rail_from_grid_transition_map(rail), line_generator=sparse_line_generator(), + rail_generator=rail_from_grid_transition_map(rail, optionals), line_generator=sparse_line_generator(), number_of_agents=3, obs_builder_object=GlobalObsForRailEnv()) env2.reset() #env2.save(file_name_2) @@ -100,6 +104,10 @@ def tests_rail_from_file(): env2.reset() rails_loaded_2 = env2.rail.grid agents_loaded_2 = env2.agents + # override `earliest_departure` & `latest_arrival` since they aren't expected to be the same + for agent_initial, agent_loaded in zip(agents_initial_2, agents_loaded_2): + agent_loaded.earliest_departure = agent_initial.earliest_departure + agent_loaded.latest_arrival = agent_initial.latest_arrival assert np.all(np.array_equal(rails_initial_2, rails_loaded_2)) assert agents_initial_2 == agents_loaded_2 @@ -113,6 +121,10 @@ def tests_rail_from_file(): env3.reset() rails_loaded_3 = env3.rail.grid agents_loaded_3 = env3.agents + # override `earliest_departure` & `latest_arrival` since they aren't expected to be the same + for agent_initial, agent_loaded in zip(agents_initial, agents_loaded_3): + agent_loaded.earliest_departure = agent_initial.earliest_departure + agent_loaded.latest_arrival = agent_initial.latest_arrival assert np.all(np.array_equal(rails_initial, rails_loaded_3)) assert agents_initial == agents_loaded_3 @@ -130,7 +142,11 @@ def tests_rail_from_file(): env4.reset() rails_loaded_4 = env4.rail.grid agents_loaded_4 = env4.agents - + # override `earliest_departure` & `latest_arrival` since they aren't expected to be the same + for agent_initial, agent_loaded in zip(agents_initial_2, agents_loaded_4): + agent_loaded.earliest_departure = agent_initial.earliest_departure + agent_loaded.latest_arrival = agent_initial.latest_arrival + # Check that no distance map was saved assert not hasattr(env2.obs_builder, "distance_map") assert np.all(np.array_equal(rails_initial_2, rails_loaded_4)) @@ -139,3 +155,10 @@ def tests_rail_from_file(): # Check that distance map was generated with correct shape assert env4.distance_map.get() is not None assert np.shape(env4.distance_map.get()) == dist_map_shape + + +def main(): + tests_rail_from_file() + +if __name__ == "__main__": + main()