diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index d576641632bfe1d0ae20baf485afdb2aca1d640b..eb8a4e8cdb2640a68d5d748830d6732f70dd6cc8 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -217,9 +217,6 @@ class RailEnv(Environment): self.max_number_of_steps_broken = malfunction_max_duration # Reset environment - self.reset() - self.num_resets = 0 # yes, set it to zero again! - self.valid_positions = None def _seed(self, seed=None): diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py index 14a3e48af08aa50c2b628e8b0c8749937745e32b..987be3967746c7859414e24a452baab7758ee46b 100644 --- a/tests/test_flaltland_rail_agent_status.py +++ b/tests/test_flaltland_rail_agent_status.py @@ -23,6 +23,7 @@ def test_initial_status(): number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) + env.reset() set_penalties_for_replay(env) test_config = ReplayConfig( replay=[ @@ -133,6 +134,7 @@ def test_status_done_remove(): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), remove_agents_at_target=True ) + env.reset() set_penalties_for_replay(env) test_config = ReplayConfig( diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 7c8a65fd6501689a0911872a00150a87a082805d..066de27f39639b3438ccb0ab191ffbbc6bf41218 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -19,6 +19,7 @@ from flatland.utils.simple_rail import make_simple_rail def test_load_env(): env = RailEnv(10, 10) + env.reset() env.load_resource('env_data.tests', 'test-10x10.mpk') agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False) @@ -83,6 +84,7 @@ def test_rail_environment_single_agent(): schedule_generator=random_schedule_generator(), number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) + rail_env.reset() for _ in range(200): _ = rail_env.reset(False, False, True) @@ -204,6 +206,7 @@ def test_get_entry_directions(): number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) + env.reset() def _assert(position, expected): actual = env.get_valid_directions_on_grid(*position) diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index 65d2d68c45155efda24536ecfd776bef5ebaab0c..ef18cf72d04a5e6ccfd89ad379bd30df77fc4075 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -21,6 +21,7 @@ def test_get_shortest_paths_unreachable(): number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), ) + env.reset() # set the initial position agent = env.agents_static[0] @@ -41,6 +42,7 @@ def test_get_shortest_paths_unreachable(): def test_get_shortest_paths(): env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') + env.reset() actual = get_shortest_paths(env.distance_map) expected = { @@ -169,6 +171,7 @@ def test_get_shortest_paths(): def test_get_shortest_paths_max_depth(): env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') + env.reset() actual = get_shortest_paths(env.distance_map, max_depth=2) expected = { diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py index caf702516a22b75186d775c26564f5e21cfb6175..379c90a44c635378e690cebf4f72aae0acbd0598 100644 --- a/tests/test_flatland_envs_sparse_rail_generator.py +++ b/tests/test_flatland_envs_sparse_rail_generator.py @@ -24,6 +24,7 @@ def test_sparse_rail_generator(): number_of_agents=10, obs_builder_object=GlobalObsForRailEnv() ) + env.reset() env.reset(False, False, True) # for r in range(env.height): # for c in range (env.width): @@ -535,7 +536,8 @@ def test_sparse_rail_generator_deterministic(): number_of_agents=1, stochastic_data=stochastic_data, # Malfunction data generator ) - # for r in range(env.height): + env.reset() + # for r in range(env.height): # for c in range(env.width): # print("assert env.rail.get_full_transitions({}, {}) == {}, \"[{}][{}]\"".format(r, c, # env.rail.get_full_transitions( @@ -1311,6 +1313,7 @@ def test_rail_env_action_required_info(): schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) + env_always_action.reset() np.random.seed(0) random.seed(0) env_only_if_action_required = RailEnv(width=50, @@ -1326,6 +1329,7 @@ def test_rail_env_action_required_info(): number_of_agents=10, obs_builder_object=GlobalObsForRailEnv()) env_renderer = RenderTool(env_always_action, gl="PILSVG", ) + env_only_if_action_required.reset() env_always_action.reset(False, False, True) env_only_if_action_required.reset(False, False, True) @@ -1395,6 +1399,7 @@ def test_rail_env_malfunction_speed_info(): number_of_agents=10, obs_builder_object=GlobalObsForRailEnv(), stochastic_data=stochastic_data) + env.reset() env.reset(False, False, True) env_renderer = RenderTool(env, gl="PILSVG", ) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 8fdd907cae94c6e8f45b5538bfabad7bdb4d60f5..e694275bc353e620f0357101bc5fd11a5c273212 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -81,6 +81,7 @@ def test_malfunction_process(): stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=SingleAgentNavigationObs() ) + env.reset() # reset to initialize agents_static obs, info = env.reset(False, False, True, random_seed=10) @@ -150,21 +151,21 @@ def test_malfunction_process_statistically(): stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=SingleAgentNavigationObs() ) + env.reset() # reset to initialize agents_static env.reset(True, True, False, random_seed=10) env.agents[0].target = (0, 0) - nb_malfunction = 0 - agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0], + agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2], + [0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1], [0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4], + [0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0], + [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4], - [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0], - [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3], - [0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5], - [0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0], - [0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1], - [6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0]] + [0, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0], + [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -175,7 +176,6 @@ def test_malfunction_process_statistically(): # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction']) assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step] env.step(action_dict) - # print(agent_malfunction_list) def test_malfunction_before_entry(): @@ -196,6 +196,7 @@ def test_malfunction_before_entry(): random_seed=1, stochastic_data=stochastic_data, # Malfunction data generator ) + env.reset() # reset to initialize agents_static env.reset(False, False, False, random_seed=10) env.agents[0].target = (0, 0) @@ -254,6 +255,7 @@ def test_initial_malfunction(): stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=SingleAgentNavigationObs() ) + env.reset() # reset to initialize agents_static env.reset(False, False, True, random_seed=10) @@ -327,7 +329,7 @@ def test_initial_malfunction_stop_moving(): stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=SingleAgentNavigationObs() ) - # reset to initialize agents_static + env.reset() print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status) @@ -532,6 +534,7 @@ def tests_random_interference_from_outside(): random_seed=1, stochastic_data=stochastic_data, # Malfunction data generator ) + env.reset() # reset to initialize agents_static env.agents[0].speed_data['speed'] = 0.33 env.agents[0].initial_position = (3, 0) @@ -564,6 +567,7 @@ def tests_random_interference_from_outside(): random_seed=1, stochastic_data=stochastic_data, # Malfunction data generator ) + env.reset() # reset to initialize agents_static env.agents[0].speed_data['speed'] = 0.33 env.agents[0].initial_position = (3, 0) diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py index 853b025f2ebd39949453f35e0d053e519163237c..6e1fb2441d4428b24da3c37d764a7676f3929a2a 100644 --- a/tests/test_flatland_utils_rendertools.py +++ b/tests/test_flatland_utils_rendertools.py @@ -42,6 +42,7 @@ def test_render_env(save_new_images=False): number_of_agents=0, obs_builder_object=TreeObsForRailEnv(max_depth=2) ) + oEnv.reset() oEnv.rail.load_transition_map('env_data.tests', "test1.npy") oRT = rt.RenderTool(oEnv, gl="PILSVG") oRT.render_env(show=False) @@ -50,7 +51,7 @@ def test_render_env(save_new_images=False): oRT = rt.RenderTool(oEnv, gl="PIL") oRT.render_env() checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images) - + def main(): if len(sys.argv) == 2 and sys.argv[1] == "save": test_render_env(save_new_images=True) diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py index 401992e790b96df297c10f71bd152cbdef0edb9c..d3bcf7779dbc4c8dbafe6e726aeff33c757c25fb 100644 --- a/tests/test_global_observation.py +++ b/tests/test_global_observation.py @@ -32,6 +32,7 @@ def test_get_global_observation(): schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=GlobalObsForRailEnv()) + env.reset() obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)}) for i in range(len(env.agents)): diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py index d190a5769415c65948b2d0e12c7f9b86b044a494..676cace09b398bfdc849e4e6678e0f155465b388 100644 --- a/tests/test_multi_speed.py +++ b/tests/test_multi_speed.py @@ -55,6 +55,7 @@ def test_multi_speed_init(): seed=1), schedule_generator=complex_schedule_generator(), number_of_agents=5) + env.reset() # Initialize the agent with the parameters corresponding to the environment and observation_builder agent = RandomAgent(218, 4) @@ -104,6 +105,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(): number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) + env.reset() set_penalties_for_replay(env) test_config = ReplayConfig( @@ -207,6 +209,7 @@ def test_multispeed_actions_no_malfunction_blocking(): number_of_agents=2, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) + env.reset() set_penalties_for_replay(env) test_configs = [ ReplayConfig( @@ -394,6 +397,7 @@ def test_multispeed_actions_malfunction_no_blocking(): number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) + env.reset() set_penalties_for_replay(env) test_config = ReplayConfig( @@ -531,6 +535,7 @@ def test_multispeed_actions_no_malfunction_invalid_actions(): number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) + env.reset() set_penalties_for_replay(env) test_config = ReplayConfig( diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index bd3c5d08a5ee00cac9c56be539497368bf19e323..ed9631763b95a92f13499526eec7ae1fb193aa11 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -21,6 +21,7 @@ def test_random_seeding(): schedule_generator=random_schedule_generator(seed=12), number_of_agents=10 ) + env.reset() env.reset(True, True, False, random_seed=1) env.agents[0].target = (0, 0) @@ -60,6 +61,7 @@ def test_seeding_and_observations(): number_of_agents=10, obs_builder_object=GlobalObsForRailEnv() ) + env.reset() # Tree Observation env2 = RailEnv(width=25, height=30, @@ -68,6 +70,7 @@ def test_seeding_and_observations(): number_of_agents=10, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) ) + env2.reset() env.reset(False, False, False, random_seed=12) env2.reset(False, False, False, random_seed=12) @@ -127,6 +130,7 @@ def test_seeding_and_malfunction(): obs_builder_object=GlobalObsForRailEnv(), stochastic_data=stochastic_data, # Malfunction data generator ) + env.reset() # Tree Observation env2 = RailEnv(width=25, @@ -137,6 +141,7 @@ def test_seeding_and_malfunction(): obs_builder_object=GlobalObsForRailEnv(), stochastic_data=stochastic_data, # Malfunction data generator ) + env2.reset() env.reset(True, False, True, random_seed=tests) env2.reset(True, False, True, random_seed=tests)