diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py index 29fafe6436de2e04a780b34bd3eddba8a4533355..79b06ed99a069a0e200dfd40c7c0d58931e4fa18 100644 --- a/examples/introduction_flatland_2_1.py +++ b/examples/introduction_flatland_2_1.py @@ -74,8 +74,13 @@ observation_builder = GlobalObsForRailEnv() # observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) # Construct the enviornment with the given observation, generataors, predictors, and stochastic data -env = RailEnv(width=width, height=height, rail_generator=rail_generator, schedule_generator=schedule_generator, - number_of_agents=nr_trains, obs_builder_object=observation_builder, malfunction_generator=malfunction_from_params(stochastic_data), +env = RailEnv(width=width, + height=height, + rail_generator=rail_generator, + schedule_generator=schedule_generator, + number_of_agents=nr_trains, + obs_builder_object=observation_builder, + malfunction_generator=malfunction_from_params(stochastic_data), remove_agents_at_target=True) env.reset() diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py index ba88930142f8344b13a3cae8de2178148e459998..b8341e6e5f63d315fc8e31e88d99dd8f57463a34 100644 --- a/examples/simple_example_1.py +++ b/examples/simple_example_1.py @@ -17,4 +17,4 @@ env_renderer = RenderTool(env) env_renderer.render_env(show=True, show_predictions=False, show_observations=False) # uncomment to keep the renderer open -#input("Press Enter to continue...") +# input("Press Enter to continue...") diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py index f9659cbdb666a2a6bd94db8aa560ded40f69079a..bffceddfb5916aef1c038eb0e7a5d0109578c64b 100644 --- a/examples/simple_example_2.py +++ b/examples/simple_example_2.py @@ -33,4 +33,4 @@ env_renderer = RenderTool(env, gl="PIL") env_renderer.render_env(show=True) # uncomment to keep the renderer open -#input("Press Enter to continue...") +# input("Press Enter to continue...") diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index ac30ca49a734fb6d9f556dc32b63c79c2837223a..053796e93c8df19cddfbe9ecd04b1679fed008fc 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -160,6 +160,7 @@ def fix_inner_nodes(grid_map: GridTransitionMap, inner_node_pos: IntVector2D, ra grid_map.grid[tmp_pos] = transition return + def align_cell_to_city(city_center, city_orientation, cell): """ Alig all cells to face the city center along the city orientation @@ -171,4 +172,4 @@ def align_cell_to_city(city_center, city_orientation, cell): if city_orientation % 2 == 0: return int(2 * np.clip(cell[0] - city_center[0], 0, 1)) else: - return int(2 * np.clip(city_center[1] - cell[1], 0, 1)) + 1 + return int(2 * np.clip(city_center[1] - cell[1], 0, 1)) + 1 diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index 0de2f4b598023884a49a47473670bb8613c71eeb..1c1df1deb5a5c49b638569ef9e93462afbce4226 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -1,6 +1,6 @@ """Malfunction generators for rail systems""" -from typing import Tuple, List, Callable +from typing import Tuple, Callable import msgpack @@ -36,6 +36,7 @@ def malfunction_from_file(filename) -> MalfunctionGenerator: return generator + def malfunction_from_params(parameters) -> MalfunctionGenerator: """ Utility to load malfunction from parameters @@ -60,6 +61,7 @@ def malfunction_from_params(parameters) -> MalfunctionGenerator: return generator + def no_malfunction_generator() -> MalfunctionGenerator: """ Utility to load malfunction from parameters diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index fe51db2f993d1760e092bfcfebf9cee43521fd60..0adbd45157c73d68f879d7ed666b0b4fe3dfab6c 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -994,4 +994,3 @@ class RailEnv(Environment): """ return agent.malfunction_data['malfunction'] < 1 - diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 3e90128c74bda860d8a4c75d71652071660482a3..231cc82547cacc9cd1f066c197204c562079d7b8 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -345,7 +345,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R def get_matching_templates(template): """ Returns a list of possible transition maps for a given template - + Parameters: ------ template:List[int] @@ -751,7 +751,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_ # Respect padding between cities padding = 2 city_size = 2 * (city_radius + 1) - max_cities_per_row =int((height - padding) // city_size) + max_cities_per_row = int((height - padding) // city_size) max_cities_per_col = int((width - padding) // city_size) # Choose number of cities per row. diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 903b58f956c69d7063bc1fe328e8dae9abf157e8..f48264d52e754ca464a0cdc83cde9972492eaf6d 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -204,7 +204,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = if len(valid_positions) < num_agents: warnings.warn("schedule_generators: len(valid_positions) < num_agents") return Schedule(agent_positions=[], agent_directions=[], - agent_targets=[], agent_speeds=[], agent_malfunction_rates=None) + agent_targets=[], agent_speeds=[], agent_malfunction_rates=None) agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)] agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)] diff --git a/flatland/evaluators/aicrowd_helpers.py b/flatland/evaluators/aicrowd_helpers.py index 1f79eac9f1b88645a107d4fba25fd7cb38c6db20..779f5ad63d8e99d3b3a0cb8e51705c54a7042bdd 100644 --- a/flatland/evaluators/aicrowd_helpers.py +++ b/flatland/evaluators/aicrowd_helpers.py @@ -30,14 +30,14 @@ def get_boto_client(): import boto3 except ImportError as e: raise Exception( - "boto3 is not installed. Please manually install by : ", - " pip install -U boto3" - ) + "boto3 is not installed. Please manually install by : ", + " pip install -U boto3" + ) return boto3.client( - 's3', - aws_access_key_id=AWS_ACCESS_KEY_ID, - aws_secret_access_key=AWS_SECRET_ACCESS_KEY + 's3', + aws_access_key_id=AWS_ACCESS_KEY_ID, + aws_secret_access_key=AWS_SECRET_ACCESS_KEY ) @@ -50,7 +50,7 @@ def is_aws_configured(): def is_grading(): return os.getenv("CROWDAI_IS_GRADING", False) or \ - os.getenv("AICROWD_IS_GRADING", False) + os.getenv("AICROWD_IS_GRADING", False) def upload_random_frame_to_s3(frames_folder): @@ -61,7 +61,7 @@ def upload_random_frame_to_s3(frames_folder): raise Exception("S3_UPLOAD_PATH_TEMPLATE not provided...") if not S3_BUCKET: raise Exception("S3_BUCKET not provided...") - + image_target_key = S3_UPLOAD_PATH_TEMPLATE.replace(".mp4", ".png").format(str(uuid.uuid4())) s3.put_object( ACL="public-read", @@ -78,7 +78,7 @@ def upload_to_s3(localpath): raise Exception("S3_UPLOAD_PATH_TEMPLATE not provided...") if not S3_BUCKET: raise Exception("S3_BUCKET not provided...") - + image_target_key = S3_UPLOAD_PATH_TEMPLATE.format(str(uuid.uuid4())) s3.put_object( ACL="public-read", @@ -91,11 +91,11 @@ def upload_to_s3(localpath): def make_subprocess_call(command, shell=False): result = subprocess.run( - command.split(), - shell=shell, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE - ) + command.split(), + shell=shell, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) stdout = result.stdout.decode('utf-8') stderr = result.stderr.decode('utf-8') return result.returncode, stdout, stderr @@ -103,7 +103,7 @@ def make_subprocess_call(command, shell=False): def generate_movie_from_frames(frames_folder): """ - Expects the frames in the frames_folder folder + Expects the frames in the frames_folder folder and then use ffmpeg to generate the video which writes the output to the frames_folder """ @@ -112,9 +112,9 @@ def generate_movie_from_frames(frames_folder): frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png") thumb_output_path = os.path.join(frames_folder, "out_thumb.mp4") return_code, output, output_err = make_subprocess_call( - "ffmpeg -r 7 -start_number 0 -i " + - frames_path + - " -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 320x320 " + + "ffmpeg -r 7 -start_number 0 -i " + + frames_path + + " -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 320x320 " + thumb_output_path ) if return_code != 0: @@ -125,13 +125,12 @@ def generate_movie_from_frames(frames_folder): frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png") output_path = os.path.join(frames_folder, "out.mp4") return_code, output, output_err = make_subprocess_call( - "ffmpeg -r 7 -start_number 0 -i " + - frames_path + - " -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 600x600 " + + "ffmpeg -r 7 -start_number 0 -i " + + frames_path + + " -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 600x600 " + output_path ) if return_code != 0: raise Exception(output_err) return output_path, thumb_output_path - diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py index 7b2e189971a364afc55cd62cd2e83bb69c9fbd89..922f7fc03700b9f2c4b8065c5271d6814a67a2eb 100644 --- a/flatland/evaluators/client.py +++ b/flatland/evaluators/client.py @@ -11,8 +11,6 @@ import numpy as np import redis import flatland -from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_file from flatland.envs.schedule_generators import schedule_from_file @@ -223,11 +221,11 @@ class FlatlandRemoteClient(object): time_start = time.time() local_observation, info = self.env.reset( - regenerate_rail=True, - regenerate_schedule=True, - activate_agents=False, - random_seed=random_seed - ) + regenerate_rail=True, + regenerate_schedule=True, + activate_agents=False, + random_seed=random_seed + ) time_diff = time.time() - time_start self.update_running_mean_stats("internal_env_reset_time", time_diff) # Use the local observation @@ -266,14 +264,14 @@ class FlatlandRemoteClient(object): ###################################################################### # Print Local Stats ###################################################################### - print("="*100) - print("="*100) + print("=" * 100) + print("=" * 100) print("## Client Performance Stats") - print("="*100) + print("=" * 100) for _key in self.stats: if _key.endswith("_mean"): print("\t - {}\t:{}".format(_key, self.stats[_key])) - print("="*100) + print("=" * 100) if os.getenv("AICROWD_BLOCKING_SUBMIT"): """ If the submission is supposed to happen as a blocking submit, @@ -288,12 +286,14 @@ class FlatlandRemoteClient(object): if __name__ == "__main__": remote_client = FlatlandRemoteClient() + def my_controller(obs, _env): _action = {} for _idx, _ in enumerate(_env.agents): _action[_idx] = np.random.randint(0, 5) return _action + my_observation_builder = DummyObservationBuilder() episode = 0 diff --git a/flatland/evaluators/messages.py b/flatland/evaluators/messages.py index dfe71efb39d9a80180c4ef3838e7f9d94c2d9db7..35c8b372582a905e87a3d7231a74a6f6fc6785ab 100644 --- a/flatland/evaluators/messages.py +++ b/flatland/evaluators/messages.py @@ -15,4 +15,3 @@ class FLATLAND_RL: ENV_SUBMIT_RESPONSE = "FLATLAND_RL.ENV_SUBMIT_RESPONSE" ERROR = "FLATLAND_RL.ERROR" - \ No newline at end of file diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py index ce4cb8cffef8c93fddb79e92d06f56fd92e144b8..8a70f1973d1e3974083be5c72e6a3263c0601992 100644 --- a/flatland/evaluators/service.py +++ b/flatland/evaluators/service.py @@ -8,7 +8,6 @@ import shutil import time import traceback -import flatland import crowdai_api import msgpack import msgpack_numpy as m @@ -16,9 +15,10 @@ import numpy as np import redis import timeout_decorator +import flatland from flatland.core.env_observation_builder import DummyObservationBuilder -from flatland.envs.rail_env import RailEnv from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import rail_from_file from flatland.envs.schedule_generators import schedule_from_file from flatland.evaluators import aicrowd_helpers @@ -353,11 +353,11 @@ class FlatlandRemoteEvaluationService: self.current_step = 0 _observation, _info = self.env.reset( - regenerate_rail=True, - regenerate_schedule=True, - activate_agents=False, - random_seed=RANDOM_SEED - ) + regenerate_rail=True, + regenerate_schedule=True, + activate_agents=False, + random_seed=RANDOM_SEED + ) if self.visualize: if self.env_renderer: @@ -477,14 +477,14 @@ class FlatlandRemoteEvaluationService: ###################################################################### # Print Local Stats ###################################################################### - print("="*100) - print("="*100) + print("=" * 100) + print("=" * 100) print("## Server Performance Stats") - print("="*100) + print("=" * 100) for _key in self.stats: if _key.endswith("_mean"): print("\t - {}\t:{}".format(_key, self.stats[_key])) - print("="*100) + print("=" * 100) # Register simulation time of the last episode self.simulation_times.append(time.time() - self.begin_simulation) @@ -615,7 +615,7 @@ class FlatlandRemoteEvaluationService: print("Self.Reward : ", self.reward) print("Current Simulation : ", self.simulation_count) if self.env_file_paths and \ - self.simulation_count < len(self.env_file_paths): + self.simulation_count < len(self.env_file_paths): print("Current Env Path : ", self.env_file_paths[self.simulation_count]) diff --git a/make_docs.py b/make_docs.py index 81fe58736c68817164dc5e2471827ad071aaca31..2b9d92b287355ef3ee8de4ef9c1d5b16192f719e 100644 --- a/make_docs.py +++ b/make_docs.py @@ -50,7 +50,6 @@ for image_file in glob.glob(r'./specifications/img/*'): subprocess.call(['python', '-msphinx', '-M', 'html', '.', '_build']) - # we do not currrently use pydeps, commented out https://gitlab.aicrowd.com/flatland/flatland/issues/149 # subprocess.call(['python', '-mpydeps', '../flatland', '-o', '_build/html/flatland.svg', '--no-config', '--noshow']) diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py index a573e55d0eef96d30189b483d6478b84653e1244..cb1ebd0c25384e9acbe585362fe79ff1f7506aa5 100644 --- a/tests/test_flaltland_rail_agent_status.py +++ b/tests/test_flaltland_rail_agent_status.py @@ -120,6 +120,7 @@ def test_initial_status(): run_replay_config(env, [test_config], activate_agents=False) + def test_status_done_remove(): """Test that agent lifecycle works correctly ready-to-depart -> active -> done.""" rail, rail_map = make_simple_rail() diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index e6550f17aad79bc6685e716249d790aa0acb8bf7..a7fd93d0e1fc84e274eae4e5628d3e0b43fadaac 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -217,6 +217,7 @@ def test_get_entry_directions(): # nowhere _assert((0, 0), [False, False, False, False]) + def test_rail_env_reset(): file_name = "test_rail_env_reset.pkl" diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index b2c1ca1162e476ff6e2f4fc3f8489428af23535e..4f0dc13d6142899e9b82c01e9937b9a90feacbc1 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -75,7 +75,8 @@ def test_malfunction_process(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=1, - obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=SingleAgentNavigationObs(), + malfunction_generator=malfunction_from_params(stochastic_data)) # reset to initialize agents_static obs, info = env.reset(False, False, True, random_seed=10) @@ -124,24 +125,25 @@ def test_malfunction_process_statistically(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=10, - obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=SingleAgentNavigationObs(), + malfunction_generator=malfunction_from_params(stochastic_data)) # reset to initialize agents_static env.reset(True, True, False, random_seed=10) env.agents[0].target = (0, 0) # Next line only for test generation - #agent_malfunction_list = [[] for i in range(10)] + # agent_malfunction_list = [[] for i in range(10)] agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1], - [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0], - [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0], - [0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1], - [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2], - [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]] + [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0], + [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0], + [0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1], + [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2], + [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]] for step in range(20): action_dict: Dict[int, RailEnvActions] = {} @@ -149,10 +151,10 @@ def test_malfunction_process_statistically(): # We randomly select an action action_dict[agent_idx] = RailEnvActions(np.random.randint(4)) # For generating tests only: - #agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction']) + # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction']) assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step] env.step(action_dict) - #print(agent_malfunction_list) + # print(agent_malfunction_list) def test_malfunction_before_entry(): @@ -185,7 +187,7 @@ def test_malfunction_before_entry(): assert env.agents[8].malfunction_data['malfunction'] == 10 assert env.agents[9].malfunction_data['malfunction'] == 10 - #for a in range(10): + # for a in range(10): # print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction'])) @@ -230,7 +232,8 @@ def test_initial_malfunction(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(seed=10), number_of_agents=1, - obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=SingleAgentNavigationObs(), + malfunction_generator=malfunction_from_params(stochastic_data)) # reset to initialize agents_static env.reset(False, False, True, random_seed=10) print(env.agents[0].malfunction_data) @@ -297,7 +300,8 @@ def test_initial_malfunction_stop_moving(): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=1, - obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=SingleAgentNavigationObs(), + malfunction_generator=malfunction_from_params(stochastic_data)) env.reset() print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status) diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py index 6ed92fefb0c81512fc6006cbf44b6d55a274caf3..18b68f2a7cdd51602dc4d21f6ff31222ca18c098 100644 --- a/tests/test_flatland_utils_rendertools.py +++ b/tests/test_flatland_utils_rendertools.py @@ -49,6 +49,7 @@ def test_render_env(save_new_images=False): oRT.render_env() checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images) + def main(): if len(sys.argv) == 2 and sys.argv[1] == "save": test_render_env(save_new_images=True) diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py index bb5cd34e1b79ef9933984fbe480761cec1fd5711..d16cb3d563ee2843d20c076c8f04e44f6bb76f51 100644 --- a/tests/test_global_observation.py +++ b/tests/test_global_observation.py @@ -29,7 +29,8 @@ def test_get_global_observation(): grid_mode=False ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents, - obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=GlobalObsForRailEnv(), + malfunction_generator=malfunction_from_params(stochastic_data)) env.reset() obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)}) diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py index fa455b75c19b666e0b05d229bb309f228aff58ca..4c6c20858a86d3e2f912e0d7883b2e840eec3cff 100644 --- a/tests/test_malfunction_generators.py +++ b/tests/test_malfunction_generators.py @@ -1,13 +1,7 @@ -import random -from typing import Dict, List - import numpy as np -from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_new_position -from flatland.envs.agent_utils import RailAgentStatus from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import rail_from_grid_transition_map @@ -40,6 +34,7 @@ def test_malfanction_from_params(): assert env.min_number_of_steps_broken == 2 assert env.max_number_of_steps_broken == 5 + def test_malfanction_to_and_from_file(): """ Test loading malfunction from @@ -65,11 +60,11 @@ def test_malfanction_to_and_from_file(): env.save("./malfunction_saving_loading_tests.pkl") env2 = RailEnv(width=25, - height=30, - rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=10), - number_of_agents=1, - malfunction_generator=malfunction_from_file("./malfunction_saving_loading_tests.pkl")) + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=10), + number_of_agents=1, + malfunction_generator=malfunction_from_file("./malfunction_saving_loading_tests.pkl")) env2.reset() diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index 75634a2299a599998dfeed0dd48e401d7795a794..b60c40ca9faae06434b0af372a2070f140521ffc 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -109,12 +109,14 @@ def test_seeding_and_malfunction(): for tests in range(1, 100): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=GlobalObsForRailEnv(), + malfunction_generator=malfunction_from_params(stochastic_data)) # Tree Observation env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data)) + obs_builder_object=GlobalObsForRailEnv(), + malfunction_generator=malfunction_from_params(stochastic_data)) env.reset(True, False, True, random_seed=tests) env2.reset(True, False, True, random_seed=tests)