From 297b65c5f7b1208aa4da7005df1fccd02b07445f Mon Sep 17 00:00:00 2001 From: Dipam Chakraborty <dipam@aicrowd.com> Date: Thu, 5 Aug 2021 02:15:31 +0530 Subject: [PATCH] some tests working --- flatland/envs/agent_utils.py | 6 ++- flatland/envs/persistence.py | 4 +- tests/test_action_plan.py | 42 ++++++++++--------- tests/test_flaltland_rail_agent_status.py | 9 ++++ tests/test_flatland_envs_rail_env.py | 9 ++-- ...t_flatland_envs_rail_env_shortest_paths.py | 6 ++- ...est_flatland_envs_sparse_rail_generator.py | 2 +- ...ile.py => test_flatland_line_from_file.py} | 26 +++++++----- tests/test_flatland_malfunction.py | 2 + tests/test_global_observation.py | 4 ++ tests/test_multi_speed.py | 17 +++++++- tests/test_utils.py | 2 +- 12 files changed, 89 insertions(+), 40 deletions(-) rename tests/{test_flatland_schedule_from_file.py => test_flatland_line_from_file.py} (79%) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 00dabd31..b90d38a4 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -109,7 +109,11 @@ class EnvAgent: return get_shortest_paths(distance_map=distance_map, agent_handle=self.handle)[self.handle] def get_travel_time_on_shortest_path(self, distance_map) -> int: - distance = len(self.get_shortest_path(distance_map)) + shortest_path = self.get_shortest_path(distance_map) + if shortest_path is not None: + distance = len(shortest_path) + else: + distance = 0 speed = self.speed_data['speed'] return int(np.ceil(distance / speed)) diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py index bc4b169b..b8354078 100644 --- a/flatland/envs/persistence.py +++ b/flatland/envs/persistence.py @@ -21,7 +21,7 @@ from flatland.envs.distance_map import DistanceMap # cannot import objects / classes directly because of circular import from flatland.envs import malfunction_generators as mal_gen from flatland.envs import rail_generators as rail_gen -from flatland.envs import schedule_generators as sched_gen +from flatland.envs import line_generators as line_gen msgpack_numpy.patch() @@ -122,7 +122,7 @@ class RailEnvPersister(object): width=width, height=height, rail_generator=rail_gen.rail_from_file(filename, load_from_package=load_from_package), - schedule_generator=sched_gen.schedule_from_file(filename, + line_generator=line_gen.line_from_file(filename, load_from_package=load_from_package), #malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename, # load_from_package=load_from_package), diff --git a/tests/test_action_plan.py b/tests/test_action_plan.py index 815ecbcd..d5c95408 100644 --- a/tests/test_action_plan.py +++ b/tests/test_action_plan.py @@ -34,25 +34,29 @@ def test_action_plan(rendering: bool = False): for handle, agent in enumerate(env.agents): print("[{}] {} -> {}".format(handle, agent.initial_position, agent.target)) - chosen_path_dict = {0: [TrainrunWaypoint(lined_at=0, waypoint=Waypoint(position=(3, 0), direction=3)), - TrainrunWaypoint(lined_at=2, waypoint=Waypoint(position=(3, 1), direction=1)), - TrainrunWaypoint(lined_at=3, waypoint=Waypoint(position=(3, 2), direction=1)), - TrainrunWaypoint(lined_at=14, waypoint=Waypoint(position=(3, 3), direction=1)), - TrainrunWaypoint(lined_at=15, waypoint=Waypoint(position=(3, 4), direction=1)), - TrainrunWaypoint(lined_at=16, waypoint=Waypoint(position=(3, 5), direction=1)), - TrainrunWaypoint(lined_at=17, waypoint=Waypoint(position=(3, 6), direction=1)), - TrainrunWaypoint(lined_at=18, waypoint=Waypoint(position=(3, 7), direction=1)), - TrainrunWaypoint(lined_at=19, waypoint=Waypoint(position=(3, 8), direction=1)), - TrainrunWaypoint(lined_at=20, waypoint=Waypoint(position=(3, 8), direction=5))], - 1: [TrainrunWaypoint(lined_at=0, waypoint=Waypoint(position=(3, 8), direction=3)), - TrainrunWaypoint(lined_at=3, waypoint=Waypoint(position=(3, 7), direction=3)), - TrainrunWaypoint(lined_at=5, waypoint=Waypoint(position=(3, 6), direction=3)), - TrainrunWaypoint(lined_at=7, waypoint=Waypoint(position=(3, 5), direction=3)), - TrainrunWaypoint(lined_at=9, waypoint=Waypoint(position=(3, 4), direction=3)), - TrainrunWaypoint(lined_at=11, waypoint=Waypoint(position=(3, 3), direction=3)), - TrainrunWaypoint(lined_at=13, waypoint=Waypoint(position=(2, 3), direction=0)), - TrainrunWaypoint(lined_at=15, waypoint=Waypoint(position=(1, 3), direction=0)), - TrainrunWaypoint(lined_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]} + # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART + for _ in range(max([agent.earliest_departure for agent in env.agents])): + env.step({}) # DO_NOTHING for all agents + + chosen_path_dict = {0: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 0), direction=3)), + TrainrunWaypoint(scheduled_at=2, waypoint=Waypoint(position=(3, 1), direction=1)), + TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 2), direction=1)), + TrainrunWaypoint(scheduled_at=14, waypoint=Waypoint(position=(3, 3), direction=1)), + TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(3, 4), direction=1)), + TrainrunWaypoint(scheduled_at=16, waypoint=Waypoint(position=(3, 5), direction=1)), + TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(3, 6), direction=1)), + TrainrunWaypoint(scheduled_at=18, waypoint=Waypoint(position=(3, 7), direction=1)), + TrainrunWaypoint(scheduled_at=19, waypoint=Waypoint(position=(3, 8), direction=1)), + TrainrunWaypoint(scheduled_at=20, waypoint=Waypoint(position=(3, 8), direction=5))], + 1: [TrainrunWaypoint(scheduled_at=0, waypoint=Waypoint(position=(3, 8), direction=3)), + TrainrunWaypoint(scheduled_at=3, waypoint=Waypoint(position=(3, 7), direction=3)), + TrainrunWaypoint(scheduled_at=5, waypoint=Waypoint(position=(3, 6), direction=3)), + TrainrunWaypoint(scheduled_at=7, waypoint=Waypoint(position=(3, 5), direction=3)), + TrainrunWaypoint(scheduled_at=9, waypoint=Waypoint(position=(3, 4), direction=3)), + TrainrunWaypoint(scheduled_at=11, waypoint=Waypoint(position=(3, 3), direction=3)), + TrainrunWaypoint(scheduled_at=13, waypoint=Waypoint(position=(2, 3), direction=0)), + TrainrunWaypoint(scheduled_at=15, waypoint=Waypoint(position=(1, 3), direction=0)), + TrainrunWaypoint(scheduled_at=17, waypoint=Waypoint(position=(0, 3), direction=0))]} expected_action_plan = [[ # take action to enter the grid ActionPlanElement(0, RailEnvActions.MOVE_FORWARD), diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py index 4f507415..87b37305 100644 --- a/tests/test_flaltland_rail_agent_status.py +++ b/tests/test_flaltland_rail_agent_status.py @@ -17,6 +17,11 @@ def test_initial_status(): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), remove_agents_at_target=False) env.reset() + + # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART + for _ in range(max([agent.earliest_departure for agent in env.agents])): + env.step({}) # DO_NOTHING for all agents + set_penalties_for_replay(env) test_config = ReplayConfig( replay=[ @@ -126,6 +131,10 @@ def test_status_done_remove(): remove_agents_at_target=True) env.reset() + # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART + for _ in range(max([agent.earliest_departure for agent in env.agents])): + env.step({}) # DO_NOTHING for all agents + set_penalties_for_replay(env) test_config = ReplayConfig( replay=[ diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index c531359a..53a61c8f 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -9,9 +9,9 @@ from flatland.envs.agent_utils import EnvAgent from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions -from flatland.envs.rail_generators import complex_rail_generator, rail_from_file +from flatland.envs.rail_generators import sparse_rail_generator, rail_from_file from flatland.envs.rail_generators import rail_from_grid_transition_map -from flatland.envs.line_generators import random_line_generator, complex_line_generator, line_from_file +from flatland.envs.line_generators import random_line_generator, sparse_line_generator, line_from_file from flatland.utils.simple_rail import make_simple_rail from flatland.envs.persistence import RailEnvPersister from flatland.utils.rendertools import RenderTool @@ -37,9 +37,10 @@ def test_load_env(): def test_save_load(): env = RailEnv(width=10, height=10, - rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1), - line_generator=complex_line_generator(), number_of_agents=2) + rail_generator=sparse_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1), + line_generator=sparse_line_generator(), number_of_agents=2) env.reset() + agent_1_pos = env.agents[0].position agent_1_dir = env.agents[0].direction agent_1_tar = env.agents[0].target diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index b2ee9b01..ce88aeb8 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -23,6 +23,10 @@ def test_get_shortest_paths_unreachable(): obs_builder_object=GlobalObsForRailEnv()) env.reset() + # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART + for _ in range(max([agent.earliest_departure for agent in env.agents])): + env.step({}) # DO_NOTHING for all agents + # set the initial position agent = env.agents[0] agent.position = (3, 1) # west dead-end @@ -36,7 +40,7 @@ def test_get_shortest_paths_unreachable(): actual = get_shortest_paths(env.distance_map) expected = {0: None} - assert actual == expected, "actual={},expected={}".format(actual, expected) + assert actual[0] == expected[0], "actual={},expected={}".format(actual[0], expected[0]) # todo file test_002.pkl has to be generated automatically diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index 3e74d720..74e71dac 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -1512,7 +1512,7 @@ def test_sparse_generator_changes_to_grid_mode(): rail_env = RailEnv(width=10, height=20, rail_generator=sparse_rail_generator( max_num_cities=100, max_rails_between_cities=2, - max_rails_in_city=2, + max_rail_pairs_in_city=1, seed=15, grid_mode=False ), line_generator=sparse_line_generator(), number_of_agents=10, diff --git a/tests/test_flatland_schedule_from_file.py b/tests/test_flatland_line_from_file.py similarity index 79% rename from tests/test_flatland_schedule_from_file.py rename to tests/test_flatland_line_from_file.py index 0b903eae..b324af98 100644 --- a/tests/test_flatland_schedule_from_file.py +++ b/tests/test_flatland_line_from_file.py @@ -25,12 +25,14 @@ def test_line_from_file_sparse(): seed=1, grid_mode=False, max_rails_between_cities=3, - max_rails_in_city=6, + max_rail_pairs_in_city=3, ) line_generator = sparse_line_generator(speed_ration_map) - create_and_save_env(file_name="./sparse_env_test.pkl", rail_generator=rail_generator, + env = create_and_save_env(file_name="./sparse_env_test.pkl", rail_generator=rail_generator, line_generator=line_generator) + old_num_steps = env._max_episode_steps + old_num_agents = len(env.agents) # Sparse generator @@ -41,10 +43,10 @@ def test_line_from_file_sparse(): sparse_env_from_file.reset(True, True) # Assert loaded agent number is correct - assert sparse_env_from_file.get_num_agents() == 10 + assert sparse_env_from_file.get_num_agents() == old_num_agents # Assert max steps is correct - assert sparse_env_from_file._max_episode_steps == 500 + assert sparse_env_from_file._max_episode_steps == old_num_steps @@ -65,8 +67,10 @@ def test_line_from_file_random(): rail_generator = random_rail_generator() line_generator = random_line_generator(speed_ration_map) - create_and_save_env(file_name="./random_env_test.pkl", rail_generator=rail_generator, + env = create_and_save_env(file_name="./random_env_test.pkl", rail_generator=rail_generator, line_generator=line_generator) + old_num_steps = env._max_episode_steps + old_num_agents = len(env.agents) # Random generator @@ -77,10 +81,10 @@ def test_line_from_file_random(): random_env_from_file.reset(True, True) # Assert loaded agent number is correct - assert random_env_from_file.get_num_agents() == 10 + assert random_env_from_file.get_num_agents() == old_num_agents # Assert max steps is correct - assert random_env_from_file._max_episode_steps == 1350 + assert random_env_from_file._max_episode_steps == old_num_steps @@ -105,8 +109,10 @@ def test_line_from_file_complex(): max_dist=99999) line_generator = complex_line_generator(speed_ration_map) - create_and_save_env(file_name="./complex_env_test.pkl", rail_generator=rail_generator, + env = create_and_save_env(file_name="./complex_env_test.pkl", rail_generator=rail_generator, line_generator=line_generator) + old_num_steps = env._max_episode_steps + old_num_agents = len(env.agents) # Load the different envs and check the parameters @@ -119,7 +125,7 @@ def test_line_from_file_complex(): complex_env_from_file.reset(True, True) # Assert loaded agent number is correct - assert complex_env_from_file.get_num_agents() == 10 + assert complex_env_from_file.get_num_agents() == old_num_agents # Assert max steps is correct - assert complex_env_from_file._max_episode_steps == 1350 + assert complex_env_from_file._max_episode_steps == old_num_steps diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 53915102..8675f54f 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -97,6 +97,8 @@ def test_malfunction_process(): actions[i] = np.argmax(obs[i]) + 1 obs, all_rewards, done, _ = env.step(actions) + if done["__all__"]: + break if env.agents[0].malfunction_data['malfunction'] > 0: agent_malfunctioning = True diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py index 5b090681..851d849d 100644 --- a/tests/test_global_observation.py +++ b/tests/test_global_observation.py @@ -30,6 +30,10 @@ def test_get_global_observation(): obs_builder_object=GlobalObsForRailEnv()) env.reset() + # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART + for _ in range(max([agent.earliest_departure for agent in env.agents])): + env.step({}) # DO_NOTHING for all agents + obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)}) for i in range(len(env.agents)): agent: EnvAgent = env.agents[i] diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index 08b46d00..2664c5b4 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -51,6 +51,7 @@ def test_multi_speed_init(): rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=1), line_generator=complex_line_generator(), number_of_agents=5) + # Initialize the agent with the parameters corresponding to the environment and observation_builder agent = RandomAgent(218, 4) @@ -197,6 +198,12 @@ def test_multispeed_actions_no_malfunction_blocking(): line_generator=random_line_generator(), number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() + + # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART + for _ in range(max([agent.earliest_departure for agent in env.agents])): + env.step({}) # DO_NOTHING for all agents + + set_penalties_for_replay(env) test_configs = [ ReplayConfig( @@ -381,7 +388,11 @@ def test_multispeed_actions_malfunction_no_blocking(): line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() - + + # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART + for _ in range(max([agent.earliest_departure for agent in env.agents])): + env.step({}) # DO_NOTHING for all agents + set_penalties_for_replay(env) test_config = ReplayConfig( replay=[ @@ -515,6 +526,10 @@ def test_multispeed_actions_no_malfunction_invalid_actions(): line_generator=random_line_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())) env.reset() + + # Perform DO_NOTHING actions until all trains get to READY_TO_DEPART + for _ in range(max([agent.earliest_departure for agent in env.agents])): + env.step({}) # DO_NOTHING for all agents set_penalties_for_replay(env) test_config = ReplayConfig( diff --git a/tests/test_utils.py b/tests/test_utils.py index 4e7c30ca..062d56f0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -152,4 +152,4 @@ def create_and_save_env(file_name: str, line_generator: LineGenerator, rail_gene env.reset(True, True) #env.save(file_name) RailEnvPersister.save(env, file_name) - + return env -- GitLab