From ae31a7b8ffccd1256ec0b7d450d3d1e14ff0c6ab Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Mon, 4 Nov 2019 09:37:48 -0500 Subject: [PATCH] updated tests --- examples/introduction_flatland_2_1.py | 9 +++-- examples/simple_example_1.py | 2 +- examples/simple_example_2.py | 2 +- flatland/envs/grid4_generators_utils.py | 3 +- flatland/envs/malfunction_generators.py | 4 ++- flatland/envs/rail_env.py | 1 - flatland/envs/rail_generators.py | 4 +-- flatland/envs/schedule_generators.py | 2 +- flatland/evaluators/aicrowd_helpers.py | 43 +++++++++++------------ flatland/evaluators/client.py | 22 ++++++------ flatland/evaluators/messages.py | 1 - flatland/evaluators/service.py | 24 ++++++------- make_docs.py | 1 - tests/test_flaltland_rail_agent_status.py | 1 + tests/test_flatland_envs_rail_env.py | 1 + tests/test_flatland_malfunction.py | 38 +++++++++++--------- tests/test_flatland_utils_rendertools.py | 1 + tests/test_global_observation.py | 3 +- tests/test_malfunction_generators.py | 17 ++++----- tests/test_random_seeding.py | 6 ++-- 20 files changed, 97 insertions(+), 88 deletions(-) diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py index 29fafe64..79b06ed9 100644 --- a/examples/introduction_flatland_2_1.py +++ b/examples/introduction_flatland_2_1.py @@ -74,8 +74,13 @@ observation_builder = GlobalObsForRailEnv() # observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) # Construct the enviornment with the given observation, generataors, predictors, and stochastic data -env = RailEnv(width=width, height=height, rail_generator=rail_generator, schedule_generator=schedule_generator, - number_of_agents=nr_trains, obs_builder_object=observation_builder, malfunction_generator=malfunction_from_params(stochastic_data), +env = RailEnv(width=width, + height=height, + rail_generator=rail_generator, + schedule_generator=schedule_generator, + number_of_agents=nr_trains, + obs_builder_object=observation_builder, + malfunction_generator=malfunction_from_params(stochastic_data), remove_agents_at_target=True) env.reset() diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py index ba889301..b8341e6e 100644 --- a/examples/simple_example_1.py +++ b/examples/simple_example_1.py @@ -17,4 +17,4 @@ env_renderer = RenderTool(env) env_renderer.render_env(show=True, show_predictions=False, show_observations=False) # uncomment to keep the renderer open -#input("Press Enter to continue...") +# input("Press Enter to continue...") diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py index f9659cbd..bffceddf 100644 --- a/examples/simple_example_2.py +++ b/examples/simple_example_2.py @@ -33,4 +33,4 @@ env_renderer = RenderTool(env, gl="PIL") env_renderer.render_env(show=True) # uncomment to keep the renderer open -#input("Press Enter to continue...") +# input("Press Enter to continue...") diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index ac30ca49..053796e9 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -160,6 +160,7 @@ def fix_inner_nodes(grid_map: GridTransitionMap, inner_node_pos: IntVector2D, ra grid_map.grid[tmp_pos] = transition return + def align_cell_to_city(city_center, city_orientation, cell): """ Alig all cells to face the city center along the city orientation @@ -171,4 +172,4 @@ def align_cell_to_city(city_center, city_orientation, cell): if city_orientation % 2 == 0: return int(2 * np.clip(cell[0] - city_center[0], 0, 1)) else: - return int(2 * np.clip(city_center[1] - cell[1], 0, 1)) + 1 + return int(2 * np.clip(city_center[1] - cell[1], 0, 1)) + 1 diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index 0de2f4b5..1c1df1de 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -1,6 +1,6 @@ """Malfunction generators for rail systems""" -from typing import Tuple, List, Callable +from typing import Tuple, Callable import msgpack @@ -36,6 +36,7 @@ def malfunction_from_file(filename) -> MalfunctionGenerator: return generator + def malfunction_from_params(parameters) -> MalfunctionGenerator: """ Utility to load malfunction from parameters @@ -60,6 +61,7 @@ def malfunction_from_params(parameters) -> MalfunctionGenerator: return generator + def no_malfunction_generator() -> MalfunctionGenerator: """ Utility to load malfunction from parameters diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index fe51db2f..0adbd451 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -994,4 +994,3 @@ class RailEnv(Environment): """ return agent.malfunction_data['malfunction'] < 1 - diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 3e90128c..231cc825 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -345,7 +345,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R def get_matching_templates(template): """ Returns a list of possible transition maps for a given template - + Parameters: ------ template:List[int] @@ -751,7 +751,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ # Respect padding between cities padding = 2 city_size = 2 * (city_radius + 1) - max_cities_per_row =int((height - padding) // city_size) + max_cities_per_row = int((height - padding) // city_size) max_cities_per_col = int((width - padding) // city_size) # Choose number of cities per row. diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 903b58f9..f48264d5 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -204,7 +204,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = if len(valid_positions) < num_agents: warnings.warn("schedule_generators: len(valid_positions) < num_agents") return Schedule(agent_positions=[], agent_directions=[], - agent_targets=[], agent_speeds=[], agent_malfunction_rates=None) + agent_targets=[], agent_speeds=[], agent_malfunction_rates=None) agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)] agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)] diff --git a/flatland/evaluators/aicrowd_helpers.py b/flatland/evaluators/aicrowd_helpers.py index 1f79eac9..779f5ad6 100644 --- a/flatland/evaluators/aicrowd_helpers.py +++ b/flatland/evaluators/aicrowd_helpers.py @@ -30,14 +30,14 @@ def get_boto_client(): import boto3 except ImportError as e: raise Exception( - "boto3 is not installed. Please manually install by : ", - " pip install -U boto3" - ) + "boto3 is not installed. Please manually install by : ", + " pip install -U boto3" + ) return boto3.client( - 's3', - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY + 's3', + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY ) @@ -50,7 +50,7 @@ def is_aws_configured(): def is_grading(): return os.getenv("CROWDAI_IS_GRADING", False) or \ - os.getenv("AICROWD_IS_GRADING", False) + os.getenv("AICROWD_IS_GRADING", False) def upload_random_frame_to_s3(frames_folder): @@ -61,7 +61,7 @@ def upload_random_frame_to_s3(frames_folder): raise Exception("S3_UPLOAD_PATH_TEMPLATE not provided...") if not S3_BUCKET: raise Exception("S3_BUCKET not provided...") - + image_target_key = S3_UPLOAD_PATH_TEMPLATE.replace(".mp4", ".png").format(str(uuid.uuid4())) s3.put_object( ACL="public-read", @@ -78,7 +78,7 @@ def upload_to_s3(localpath): raise Exception("S3_UPLOAD_PATH_TEMPLATE not provided...") if not S3_BUCKET: raise Exception("S3_BUCKET not provided...") - + image_target_key = S3_UPLOAD_PATH_TEMPLATE.format(str(uuid.uuid4())) s3.put_object( ACL="public-read", @@ -91,11 +91,11 @@ def upload_to_s3(localpath): def make_subprocess_call(command, shell=False): result = subprocess.run( - command.split(), - shell=shell, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE - ) + command.split(), + shell=shell, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) stdout = result.stdout.decode('utf-8') stderr = result.stderr.decode('utf-8') return result.returncode, stdout, stderr @@ -103,7 +103,7 @@ def make_subprocess_call(command, shell=False): def generate_movie_from_frames(frames_folder): """ - Expects the frames in the frames_folder folder + Expects the frames in the frames_folder folder and then use ffmpeg to generate the video which writes the output to the frames_folder """ @@ -112,9 +112,9 @@ def generate_movie_from_frames(frames_folder): frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png") thumb_output_path = os.path.join(frames_folder, "out_thumb.mp4") return_code, output, output_err = make_subprocess_call( - "ffmpeg -r 7 -start_number 0 -i " + - frames_path + - " -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 320x320 " + + "ffmpeg -r 7 -start_number 0 -i " + + frames_path + + " -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 320x320 " + thumb_output_path ) if return_code != 0: @@ -125,13 +125,12 @@ def generate_movie_from_frames(frames_folder): frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png") output_path = os.path.join(frames_folder, "out.mp4") return_code, output, output_err = make_subprocess_call( - "ffmpeg -r 7 -start_number 0 -i " + - frames_path + - " -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 600x600 " + + "ffmpeg -r 7 -start_number 0 -i " + + frames_path + + " -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 600x600 " + output_path ) if return_code != 0: raise Exception(output_err) return output_path, thumb_output_path - diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index 7b2e1899..922f7fc0 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -11,8 +11,6 @@ import numpy as np import redis import flatland -from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_file from flatland.envs.schedule_generators import schedule_from_file @@ -223,11 +221,11 @@ class FlatlandRemoteClient(object): time_start = time.time() local_observation, info = self.env.reset( - regenerate_rail=True, - regenerate_schedule=True, - activate_agents=False, - random_seed=random_seed - ) + regenerate_rail=True, + regenerate_schedule=True, + activate_agents=False, + random_seed=random_seed + ) time_diff = time.time() - time_start self.update_running_mean_stats("internal_env_reset_time", time_diff) # Use the local observation @@ -266,14 +264,14 @@ class FlatlandRemoteClient(object): ###################################################################### # Print Local Stats ###################################################################### - print("="*100) - print("="*100) + print("=" * 100) + print("=" * 100) print("## Client Performance Stats") - print("="*100) + print("=" * 100) for _key in self.stats: if _key.endswith("_mean"): print("\t - {}\t:{}".format(_key, self.stats[_key])) - print("="*100) + print("=" * 100) if os.getenv("AICROWD_BLOCKING_SUBMIT"): """ If the submission is supposed to happen as a blocking submit, @@ -288,12 +286,14 @@ class FlatlandRemoteClient(object): if __name__ == "__main__": remote_client = FlatlandRemoteClient() + def my_controller(obs, _env): _action = {} for _idx, _ in enumerate(_env.agents): _action[_idx] = np.random.randint(0, 5) return _action + my_observation_builder = DummyObservationBuilder() episode = 0 diff --git a/flatland/evaluators/messages.py b/flatland/evaluators/messages.py index dfe71efb..35c8b372 100644 --- a/flatland/evaluators/messages.py +++ b/flatland/evaluators/messages.py @@ -15,4 +15,3 @@ class FLATLAND_RL: ENV_SUBMIT_RESPONSE = "FLATLAND_RL.ENV_SUBMIT_RESPONSE" ERROR = "FLATLAND_RL.ERROR" - \ No newline at end of file diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index ce4cb8cf..8a70f197 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -8,7 +8,6 @@ import shutil import time import traceback -import flatland import crowdai_api import msgpack import msgpack_numpy as m @@ -16,9 +15,10 @@ import numpy as np import redis import timeout_decorator +import flatland from flatland.core.env_observation_builder import DummyObservationBuilder -from flatland.envs.rail_env import RailEnv from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_file from flatland.envs.schedule_generators import schedule_from_file from flatland.evaluators import aicrowd_helpers @@ -353,11 +353,11 @@ class FlatlandRemoteEvaluationService: self.current_step = 0 _observation, _info = self.env.reset( - regenerate_rail=True, - regenerate_schedule=True, - activate_agents=False, - random_seed=RANDOM_SEED - ) + regenerate_rail=True, + regenerate_schedule=True, + activate_agents=False, + random_seed=RANDOM_SEED + ) if self.visualize: if self.env_renderer: @@ -477,14 +477,14 @@ class FlatlandRemoteEvaluationService: ###################################################################### # Print Local Stats ###################################################################### - print("="*100) - print("="*100) + print("=" * 100) + print("=" * 100) print("## Server Performance Stats") - print("="*100) + print("=" * 100) for _key in self.stats: if _key.endswith("_mean"): print("\t - {}\t:{}".format(_key, self.stats[_key])) - print("="*100) + print("=" * 100) # Register simulation time of the last episode self.simulation_times.append(time.time() - self.begin_simulation) @@ -615,7 +615,7 @@ class FlatlandRemoteEvaluationService: print("Self.Reward : ", self.reward) print("Current Simulation : ", self.simulation_count) if self.env_file_paths and \ - self.simulation_count < len(self.env_file_paths): + self.simulation_count < len(self.env_file_paths): print("Current Env Path : ", self.env_file_paths[self.simulation_count]) diff --git a/make_docs.py b/make_docs.py index 81fe5873..2b9d92b2 100644 --- a/make_docs.py +++ b/make_docs.py @@ -50,7 +50,6 @@ for image_file in glob.glob(r'./specifications/img/*'): subprocess.call(['python', '-msphinx', '-M', 'html', '.', '_build']) - # we do not currrently use pydeps, commented out https://gitlab.aicrowd.com/flatland/flatland/issues/149 # subprocess.call(['python', '-mpydeps', '../flatland', '-o', '_build/html/flatland.svg', '--no-config', '--noshow']) diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py index a573e55d..cb1ebd0c 100644 --- a/tests/test_flaltland_rail_agent_status.py +++ b/tests/test_flaltland_rail_agent_status.py @@ -120,6 +120,7 @@ def test_initial_status(): run_replay_config(env, [test_config], activate_agents=False) + def test_status_done_remove(): """Test that agent lifecycle works correctly ready-to-depart -> active -> done.""" rail, rail_map = make_simple_rail() diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index e6550f17..a7fd93d0 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -217,6 +217,7 @@ def test_get_entry_directions(): # nowhere _assert((0, 0), [False, False, False, False]) + def test_rail_env_reset(): file_name = "test_rail_env_reset.pkl" diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index b2c1ca11..4f0dc13d 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -75,7 +75,8 @@ def test_malfunction_process(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=1, - obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=SingleAgentNavigationObs(), + malfunction_generator=malfunction_from_params(stochastic_data)) # reset to initialize agents_static obs, info = env.reset(False, False, True, random_seed=10) @@ -124,24 +125,25 @@ def test_malfunction_process_statistically(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=10, - obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=SingleAgentNavigationObs(), + malfunction_generator=malfunction_from_params(stochastic_data)) # reset to initialize agents_static env.reset(True, True, False, random_seed=10) env.agents[0].target = (0, 0) # Next line only for test generation - #agent_malfunction_list = [[] for i in range(10)] + # agent_malfunction_list = [[] for i in range(10)] agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1], - [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0], - [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0], - [0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1], - [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2], - [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]] + [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0], + [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0], + [0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1], + [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2], + [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -149,10 +151,10 @@ def test_malfunction_process_statistically(): # We randomly select an action action_dict[agent_idx] = RailEnvActions(np.random.randint(4)) # For generating tests only: - #agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction']) + # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction']) assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step] env.step(action_dict) - #print(agent_malfunction_list) + # print(agent_malfunction_list) def test_malfunction_before_entry(): @@ -185,7 +187,7 @@ def test_malfunction_before_entry(): assert env.agents[8].malfunction_data['malfunction'] == 10 assert env.agents[9].malfunction_data['malfunction'] == 10 - #for a in range(10): + # for a in range(10): # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction'])) @@ -230,7 +232,8 @@ def test_initial_malfunction(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(seed=10), number_of_agents=1, - obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=SingleAgentNavigationObs(), + malfunction_generator=malfunction_from_params(stochastic_data)) # reset to initialize agents_static env.reset(False, False, True, random_seed=10) print(env.agents[0].malfunction_data) @@ -297,7 +300,8 @@ def test_initial_malfunction_stop_moving(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=1, - obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=SingleAgentNavigationObs(), + malfunction_generator=malfunction_from_params(stochastic_data)) env.reset() print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status) diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py index 6ed92fef..18b68f2a 100644 --- a/tests/test_flatland_utils_rendertools.py +++ b/tests/test_flatland_utils_rendertools.py @@ -49,6 +49,7 @@ def test_render_env(save_new_images=False): oRT.render_env() checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images) + def main(): if len(sys.argv) == 2 and sys.argv[1] == "save": test_render_env(save_new_images=True) diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py index bb5cd34e..d16cb3d5 100644 --- a/tests/test_global_observation.py +++ b/tests/test_global_observation.py @@ -29,7 +29,8 @@ def test_get_global_observation(): grid_mode=False ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents, - obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=GlobalObsForRailEnv(), + malfunction_generator=malfunction_from_params(stochastic_data)) env.reset() obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)}) diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py index fa455b75..4c6c2085 100644 --- a/tests/test_malfunction_generators.py +++ b/tests/test_malfunction_generators.py @@ -1,13 +1,7 @@ -import random -from typing import Dict, List - import numpy as np -from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position -from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map @@ -40,6 +34,7 @@ def test_malfanction_from_params(): assert env.min_number_of_steps_broken == 2 assert env.max_number_of_steps_broken == 5 + def test_malfanction_to_and_from_file(): """ Test loading malfunction from @@ -65,11 +60,11 @@ def test_malfanction_to_and_from_file(): env.save("./malfunction_saving_loading_tests.pkl") env2 = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=10), - number_of_agents=1, - malfunction_generator=malfunction_from_file("./malfunction_saving_loading_tests.pkl")) + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=10), + number_of_agents=1, + malfunction_generator=malfunction_from_file("./malfunction_saving_loading_tests.pkl")) env2.reset() diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index 75634a22..b60c40ca 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -109,12 +109,14 @@ def test_seeding_and_malfunction(): for tests in range(1, 100): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=GlobalObsForRailEnv(), + malfunction_generator=malfunction_from_params(stochastic_data)) # Tree Observation env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=GlobalObsForRailEnv(), + malfunction_generator=malfunction_from_params(stochastic_data)) env.reset(True, False, True, random_seed=tests) env2.reset(True, False, True, random_seed=tests) -- GitLab