From ae31a7b8ffccd1256ec0b7d450d3d1e14ff0c6ab Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Mon, 4 Nov 2019 09:37:48 -0500
Subject: [PATCH] updated tests

---
 examples/introduction_flatland_2_1.py     |  9 +++--
 examples/simple_example_1.py              |  2 +-
 examples/simple_example_2.py              |  2 +-
 flatland/envs/grid4_generators_utils.py   |  3 +-
 flatland/envs/malfunction_generators.py   |  4 ++-
 flatland/envs/rail_env.py                 |  1 -
 flatland/envs/rail_generators.py          |  4 +--
 flatland/envs/schedule_generators.py      |  2 +-
 flatland/evaluators/aicrowd_helpers.py    | 43 +++++++++++------------
 flatland/evaluators/client.py             | 22 ++++++------
 flatland/evaluators/messages.py           |  1 -
 flatland/evaluators/service.py            | 24 ++++++-------
 make_docs.py                              |  1 -
 tests/test_flaltland_rail_agent_status.py |  1 +
 tests/test_flatland_envs_rail_env.py      |  1 +
 tests/test_flatland_malfunction.py        | 38 +++++++++++---------
 tests/test_flatland_utils_rendertools.py  |  1 +
 tests/test_global_observation.py          |  3 +-
 tests/test_malfunction_generators.py      | 17 ++++-----
 tests/test_random_seeding.py              |  6 ++--
 20 files changed, 97 insertions(+), 88 deletions(-)

diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py
index 29fafe64..79b06ed9 100644
--- a/examples/introduction_flatland_2_1.py
+++ b/examples/introduction_flatland_2_1.py
@@ -74,8 +74,13 @@ observation_builder = GlobalObsForRailEnv()
 # observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 
 # Construct the enviornment with the given observation, generataors, predictors, and stochastic data
-env = RailEnv(width=width, height=height, rail_generator=rail_generator, schedule_generator=schedule_generator,
-              number_of_agents=nr_trains, obs_builder_object=observation_builder, malfunction_generator=malfunction_from_params(stochastic_data),
+env = RailEnv(width=width,
+              height=height,
+              rail_generator=rail_generator,
+              schedule_generator=schedule_generator,
+              number_of_agents=nr_trains,
+              obs_builder_object=observation_builder,
+              malfunction_generator=malfunction_from_params(stochastic_data),
               remove_agents_at_target=True)
 env.reset()
 
diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py
index ba889301..b8341e6e 100644
--- a/examples/simple_example_1.py
+++ b/examples/simple_example_1.py
@@ -17,4 +17,4 @@ env_renderer = RenderTool(env)
 env_renderer.render_env(show=True, show_predictions=False, show_observations=False)
 
 # uncomment to keep the renderer open
-#input("Press Enter to continue...")
+# input("Press Enter to continue...")
diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py
index f9659cbd..bffceddf 100644
--- a/examples/simple_example_2.py
+++ b/examples/simple_example_2.py
@@ -33,4 +33,4 @@ env_renderer = RenderTool(env, gl="PIL")
 env_renderer.render_env(show=True)
 
 # uncomment to keep the renderer open
-#input("Press Enter to continue...")
+# input("Press Enter to continue...")
diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py
index ac30ca49..053796e9 100644
--- a/flatland/envs/grid4_generators_utils.py
+++ b/flatland/envs/grid4_generators_utils.py
@@ -160,6 +160,7 @@ def fix_inner_nodes(grid_map: GridTransitionMap, inner_node_pos: IntVector2D, ra
         grid_map.grid[tmp_pos] = transition
     return
 
+
 def align_cell_to_city(city_center, city_orientation, cell):
     """
     Alig all cells to face the city center along the city orientation
@@ -171,4 +172,4 @@ def align_cell_to_city(city_center, city_orientation, cell):
     if city_orientation % 2 == 0:
         return int(2 * np.clip(cell[0] - city_center[0], 0, 1))
     else:
-       return int(2 * np.clip(city_center[1] - cell[1], 0, 1)) + 1
+        return int(2 * np.clip(city_center[1] - cell[1], 0, 1)) + 1
diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py
index 0de2f4b5..1c1df1de 100644
--- a/flatland/envs/malfunction_generators.py
+++ b/flatland/envs/malfunction_generators.py
@@ -1,6 +1,6 @@
 """Malfunction generators for rail systems"""
 
-from typing import Tuple, List, Callable
+from typing import Tuple, Callable
 
 import msgpack
 
@@ -36,6 +36,7 @@ def malfunction_from_file(filename) -> MalfunctionGenerator:
 
     return generator
 
+
 def malfunction_from_params(parameters) -> MalfunctionGenerator:
     """
     Utility to load malfunction from parameters
@@ -60,6 +61,7 @@ def malfunction_from_params(parameters) -> MalfunctionGenerator:
 
     return generator
 
+
 def no_malfunction_generator() -> MalfunctionGenerator:
     """
     Utility to load malfunction from parameters
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index fe51db2f..0adbd451 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -994,4 +994,3 @@ class RailEnv(Environment):
 
         """
         return agent.malfunction_data['malfunction'] < 1
-
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 3e90128c..231cc825 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -345,7 +345,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R
         def get_matching_templates(template):
             """
             Returns a list of possible transition maps for a given template
-            
+
             Parameters:
             ------
             template:List[int]
@@ -751,7 +751,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         # Respect padding between cities
         padding = 2
         city_size = 2 * (city_radius + 1)
-        max_cities_per_row =int((height - padding) // city_size)
+        max_cities_per_row = int((height - padding) // city_size)
         max_cities_per_col = int((width - padding) // city_size)
 
         # Choose number of cities per row.
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index 903b58f9..f48264d5 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -204,7 +204,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] =
         if len(valid_positions) < num_agents:
             warnings.warn("schedule_generators: len(valid_positions) < num_agents")
             return Schedule(agent_positions=[], agent_directions=[],
-                            agent_targets=[], agent_speeds=[],  agent_malfunction_rates=None)
+                            agent_targets=[], agent_speeds=[], agent_malfunction_rates=None)
 
         agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)]
         agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)]
diff --git a/flatland/evaluators/aicrowd_helpers.py b/flatland/evaluators/aicrowd_helpers.py
index 1f79eac9..779f5ad6 100644
--- a/flatland/evaluators/aicrowd_helpers.py
+++ b/flatland/evaluators/aicrowd_helpers.py
@@ -30,14 +30,14 @@ def get_boto_client():
         import boto3
     except ImportError as e:
         raise Exception(
-                        "boto3 is not installed. Please manually install by : ",
-                        " pip install -U boto3"
-                        )
+            "boto3 is not installed. Please manually install by : ",
+            " pip install -U boto3"
+        )
 
     return boto3.client(
-            's3',
-            aws_access_key_id=AWS_ACCESS_KEY_ID,
-            aws_secret_access_key=AWS_SECRET_ACCESS_KEY
+        's3',
+        aws_access_key_id=AWS_ACCESS_KEY_ID,
+        aws_secret_access_key=AWS_SECRET_ACCESS_KEY
     )
 
 
@@ -50,7 +50,7 @@ def is_aws_configured():
 
 def is_grading():
     return os.getenv("CROWDAI_IS_GRADING", False) or \
-        os.getenv("AICROWD_IS_GRADING", False)
+           os.getenv("AICROWD_IS_GRADING", False)
 
 
 def upload_random_frame_to_s3(frames_folder):
@@ -61,7 +61,7 @@ def upload_random_frame_to_s3(frames_folder):
         raise Exception("S3_UPLOAD_PATH_TEMPLATE not provided...")
     if not S3_BUCKET:
         raise Exception("S3_BUCKET not provided...")
-    
+
     image_target_key = S3_UPLOAD_PATH_TEMPLATE.replace(".mp4", ".png").format(str(uuid.uuid4()))
     s3.put_object(
         ACL="public-read",
@@ -78,7 +78,7 @@ def upload_to_s3(localpath):
         raise Exception("S3_UPLOAD_PATH_TEMPLATE not provided...")
     if not S3_BUCKET:
         raise Exception("S3_BUCKET not provided...")
-    
+
     image_target_key = S3_UPLOAD_PATH_TEMPLATE.format(str(uuid.uuid4()))
     s3.put_object(
         ACL="public-read",
@@ -91,11 +91,11 @@ def upload_to_s3(localpath):
 
 def make_subprocess_call(command, shell=False):
     result = subprocess.run(
-                command.split(),
-                shell=shell,
-                stdout=subprocess.PIPE,
-                stderr=subprocess.PIPE
-            )
+        command.split(),
+        shell=shell,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE
+    )
     stdout = result.stdout.decode('utf-8')
     stderr = result.stderr.decode('utf-8')
     return result.returncode, stdout, stderr
@@ -103,7 +103,7 @@ def make_subprocess_call(command, shell=False):
 
 def generate_movie_from_frames(frames_folder):
     """
-        Expects the frames in the  frames_folder folder 
+        Expects the frames in the  frames_folder folder
         and then use ffmpeg to generate the video
         which writes the output to the frames_folder
     """
@@ -112,9 +112,9 @@ def generate_movie_from_frames(frames_folder):
     frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png")
     thumb_output_path = os.path.join(frames_folder, "out_thumb.mp4")
     return_code, output, output_err = make_subprocess_call(
-        "ffmpeg -r 7 -start_number 0 -i " + 
-        frames_path + 
-        " -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 320x320 " + 
+        "ffmpeg -r 7 -start_number 0 -i " +
+        frames_path +
+        " -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 320x320 " +
         thumb_output_path
     )
     if return_code != 0:
@@ -125,13 +125,12 @@ def generate_movie_from_frames(frames_folder):
     frames_path = os.path.join(frames_folder, "flatland_frame_%04d.png")
     output_path = os.path.join(frames_folder, "out.mp4")
     return_code, output, output_err = make_subprocess_call(
-        "ffmpeg -r 7 -start_number 0 -i " + 
-        frames_path + 
-        " -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 600x600 " + 
+        "ffmpeg -r 7 -start_number 0 -i " +
+        frames_path +
+        " -c:v libx264 -vf fps=7 -pix_fmt yuv420p -s 600x600 " +
         output_path
     )
     if return_code != 0:
         raise Exception(output_err)
 
     return output_path, thumb_output_path
-
diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py
index 7b2e1899..922f7fc0 100644
--- a/flatland/evaluators/client.py
+++ b/flatland/evaluators/client.py
@@ -11,8 +11,6 @@ import numpy as np
 import redis
 
 import flatland
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import rail_from_file
 from flatland.envs.schedule_generators import schedule_from_file
@@ -223,11 +221,11 @@ class FlatlandRemoteClient(object):
 
         time_start = time.time()
         local_observation, info = self.env.reset(
-                                regenerate_rail=True,
-                                regenerate_schedule=True,
-                                activate_agents=False,
-                                random_seed=random_seed
-                            )
+            regenerate_rail=True,
+            regenerate_schedule=True,
+            activate_agents=False,
+            random_seed=random_seed
+        )
         time_diff = time.time() - time_start
         self.update_running_mean_stats("internal_env_reset_time", time_diff)
         # Use the local observation
@@ -266,14 +264,14 @@ class FlatlandRemoteClient(object):
         ######################################################################
         # Print Local Stats
         ######################################################################
-        print("="*100)
-        print("="*100)
+        print("=" * 100)
+        print("=" * 100)
         print("## Client Performance Stats")
-        print("="*100)
+        print("=" * 100)
         for _key in self.stats:
             if _key.endswith("_mean"):
                 print("\t - {}\t:{}".format(_key, self.stats[_key]))
-        print("="*100)
+        print("=" * 100)
         if os.getenv("AICROWD_BLOCKING_SUBMIT"):
             """
             If the submission is supposed to happen as a blocking submit,
@@ -288,12 +286,14 @@ class FlatlandRemoteClient(object):
 if __name__ == "__main__":
     remote_client = FlatlandRemoteClient()
 
+
     def my_controller(obs, _env):
         _action = {}
         for _idx, _ in enumerate(_env.agents):
             _action[_idx] = np.random.randint(0, 5)
         return _action
 
+
     my_observation_builder = DummyObservationBuilder()
 
     episode = 0
diff --git a/flatland/evaluators/messages.py b/flatland/evaluators/messages.py
index dfe71efb..35c8b372 100644
--- a/flatland/evaluators/messages.py
+++ b/flatland/evaluators/messages.py
@@ -15,4 +15,3 @@ class FLATLAND_RL:
     ENV_SUBMIT_RESPONSE = "FLATLAND_RL.ENV_SUBMIT_RESPONSE"
 
     ERROR = "FLATLAND_RL.ERROR"
-    
\ No newline at end of file
diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py
index ce4cb8cf..8a70f197 100644
--- a/flatland/evaluators/service.py
+++ b/flatland/evaluators/service.py
@@ -8,7 +8,6 @@ import shutil
 import time
 import traceback
 
-import flatland
 import crowdai_api
 import msgpack
 import msgpack_numpy as m
@@ -16,9 +15,10 @@ import numpy as np
 import redis
 import timeout_decorator
 
+import flatland
 from flatland.core.env_observation_builder import DummyObservationBuilder
-from flatland.envs.rail_env import RailEnv
 from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import rail_from_file
 from flatland.envs.schedule_generators import schedule_from_file
 from flatland.evaluators import aicrowd_helpers
@@ -353,11 +353,11 @@ class FlatlandRemoteEvaluationService:
             self.current_step = 0
 
             _observation, _info = self.env.reset(
-                                regenerate_rail=True,
-                                regenerate_schedule=True,
-                                activate_agents=False,
-                                random_seed=RANDOM_SEED
-                                )
+                regenerate_rail=True,
+                regenerate_schedule=True,
+                activate_agents=False,
+                random_seed=RANDOM_SEED
+            )
 
             if self.visualize:
                 if self.env_renderer:
@@ -477,14 +477,14 @@ class FlatlandRemoteEvaluationService:
         ######################################################################
         # Print Local Stats
         ######################################################################
-        print("="*100)
-        print("="*100)
+        print("=" * 100)
+        print("=" * 100)
         print("## Server Performance Stats")
-        print("="*100)
+        print("=" * 100)
         for _key in self.stats:
             if _key.endswith("_mean"):
                 print("\t - {}\t:{}".format(_key, self.stats[_key]))
-        print("="*100)
+        print("=" * 100)
 
         # Register simulation time of the last episode
         self.simulation_times.append(time.time() - self.begin_simulation)
@@ -615,7 +615,7 @@ class FlatlandRemoteEvaluationService:
                 print("Self.Reward : ", self.reward)
                 print("Current Simulation : ", self.simulation_count)
                 if self.env_file_paths and \
-                        self.simulation_count < len(self.env_file_paths):
+                    self.simulation_count < len(self.env_file_paths):
                     print("Current Env Path : ",
                           self.env_file_paths[self.simulation_count])
 
diff --git a/make_docs.py b/make_docs.py
index 81fe5873..2b9d92b2 100644
--- a/make_docs.py
+++ b/make_docs.py
@@ -50,7 +50,6 @@ for image_file in glob.glob(r'./specifications/img/*'):
 
 subprocess.call(['python', '-msphinx', '-M', 'html', '.', '_build'])
 
-
 # we do not currrently use pydeps, commented out https://gitlab.aicrowd.com/flatland/flatland/issues/149
 # subprocess.call(['python', '-mpydeps', '../flatland', '-o', '_build/html/flatland.svg', '--no-config', '--noshow'])
 
diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py
index a573e55d..cb1ebd0c 100644
--- a/tests/test_flaltland_rail_agent_status.py
+++ b/tests/test_flaltland_rail_agent_status.py
@@ -120,6 +120,7 @@ def test_initial_status():
 
     run_replay_config(env, [test_config], activate_agents=False)
 
+
 def test_status_done_remove():
     """Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
     rail, rail_map = make_simple_rail()
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index e6550f17..a7fd93d0 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -217,6 +217,7 @@ def test_get_entry_directions():
     # nowhere
     _assert((0, 0), [False, False, False, False])
 
+
 def test_rail_env_reset():
     file_name = "test_rail_env_reset.pkl"
 
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index b2c1ca11..4f0dc13d 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -75,7 +75,8 @@ def test_malfunction_process():
 
     env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                   schedule_generator=random_schedule_generator(), number_of_agents=1,
-                  obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
+                  obs_builder_object=SingleAgentNavigationObs(),
+                  malfunction_generator=malfunction_from_params(stochastic_data))
     # reset to initialize agents_static
     obs, info = env.reset(False, False, True, random_seed=10)
 
@@ -124,24 +125,25 @@ def test_malfunction_process_statistically():
 
     env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                   schedule_generator=random_schedule_generator(), number_of_agents=10,
-                  obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
+                  obs_builder_object=SingleAgentNavigationObs(),
+                  malfunction_generator=malfunction_from_params(stochastic_data))
 
     # reset to initialize agents_static
     env.reset(True, True, False, random_seed=10)
 
     env.agents[0].target = (0, 0)
     # Next line only for test generation
-    #agent_malfunction_list = [[] for i in range(10)]
+    # agent_malfunction_list = [[] for i in range(10)]
     agent_malfunction_list = [[0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
-     [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0],
-     [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0],
-     [0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
-     [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
-     [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-     [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
-     [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2],
-     [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-     [5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]]
+                              [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 0],
+                              [5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0],
+                              [0, 5, 4, 3, 2, 1, 0, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2, 1],
+                              [0, 0, 0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0],
+                              [0, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                              [0, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 5, 4, 3, 2, 1],
+                              [0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 5, 4, 3, 2, 1, 0, 5, 4, 3, 2],
+                              [5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                              [5, 4, 3, 2, 1, 0, 0, 0, 0, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0]]
 
     for step in range(20):
         action_dict: Dict[int, RailEnvActions] = {}
@@ -149,10 +151,10 @@ def test_malfunction_process_statistically():
             # We randomly select an action
             action_dict[agent_idx] = RailEnvActions(np.random.randint(4))
             # For generating tests only:
-            #agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
+            # agent_malfunction_list[agent_idx].append(env.agents[agent_idx].malfunction_data['malfunction'])
             assert env.agents[agent_idx].malfunction_data['malfunction'] == agent_malfunction_list[agent_idx][step]
         env.step(action_dict)
-    #print(agent_malfunction_list)
+    # print(agent_malfunction_list)
 
 
 def test_malfunction_before_entry():
@@ -185,7 +187,7 @@ def test_malfunction_before_entry():
     assert env.agents[8].malfunction_data['malfunction'] == 10
     assert env.agents[9].malfunction_data['malfunction'] == 10
 
-    #for a in range(10):
+    # for a in range(10):
     #  print("assert env.agents[{}].malfunction_data['malfunction'] == {}".format(a,env.agents[a].malfunction_data['malfunction']))
 
 
@@ -230,7 +232,8 @@ def test_initial_malfunction():
 
     env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                   schedule_generator=random_schedule_generator(seed=10), number_of_agents=1,
-                  obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
+                  obs_builder_object=SingleAgentNavigationObs(),
+                  malfunction_generator=malfunction_from_params(stochastic_data))
     # reset to initialize agents_static
     env.reset(False, False, True, random_seed=10)
     print(env.agents[0].malfunction_data)
@@ -297,7 +300,8 @@ def test_initial_malfunction_stop_moving():
 
     env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                   schedule_generator=random_schedule_generator(), number_of_agents=1,
-                  obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
+                  obs_builder_object=SingleAgentNavigationObs(),
+                  malfunction_generator=malfunction_from_params(stochastic_data))
     env.reset()
 
     print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)
diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py
index 6ed92fef..18b68f2a 100644
--- a/tests/test_flatland_utils_rendertools.py
+++ b/tests/test_flatland_utils_rendertools.py
@@ -49,6 +49,7 @@ def test_render_env(save_new_images=False):
     oRT.render_env()
     checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images)
 
+
 def main():
     if len(sys.argv) == 2 and sys.argv[1] == "save":
         test_render_env(save_new_images=True)
diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py
index bb5cd34e..d16cb3d5 100644
--- a/tests/test_global_observation.py
+++ b/tests/test_global_observation.py
@@ -29,7 +29,8 @@ def test_get_global_observation():
                                                                             grid_mode=False
                                                                             ),
                   schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents,
-                  obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data))
+                  obs_builder_object=GlobalObsForRailEnv(),
+                  malfunction_generator=malfunction_from_params(stochastic_data))
     env.reset()
 
     obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py
index fa455b75..4c6c2085 100644
--- a/tests/test_malfunction_generators.py
+++ b/tests/test_malfunction_generators.py
@@ -1,13 +1,7 @@
-import random
-from typing import Dict, List
-
 import numpy as np
-from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
 from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import rail_from_grid_transition_map
@@ -40,6 +34,7 @@ def test_malfanction_from_params():
     assert env.min_number_of_steps_broken == 2
     assert env.max_number_of_steps_broken == 5
 
+
 def test_malfanction_to_and_from_file():
     """
     Test loading malfunction from
@@ -65,11 +60,11 @@ def test_malfanction_to_and_from_file():
     env.save("./malfunction_saving_loading_tests.pkl")
 
     env2 = RailEnv(width=25,
-                  height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(seed=10),
-                  number_of_agents=1,
-                  malfunction_generator=malfunction_from_file("./malfunction_saving_loading_tests.pkl"))
+                   height=30,
+                   rail_generator=rail_from_grid_transition_map(rail),
+                   schedule_generator=random_schedule_generator(seed=10),
+                   number_of_agents=1,
+                   malfunction_generator=malfunction_from_file("./malfunction_saving_loading_tests.pkl"))
 
     env2.reset()
 
diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py
index 75634a22..b60c40ca 100644
--- a/tests/test_random_seeding.py
+++ b/tests/test_random_seeding.py
@@ -109,12 +109,14 @@ def test_seeding_and_malfunction():
     for tests in range(1, 100):
         env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                       schedule_generator=random_schedule_generator(), number_of_agents=10,
-                      obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data))
+                      obs_builder_object=GlobalObsForRailEnv(),
+                      malfunction_generator=malfunction_from_params(stochastic_data))
 
         # Tree Observation
         env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
                        schedule_generator=random_schedule_generator(), number_of_agents=10,
-                       obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data))
+                       obs_builder_object=GlobalObsForRailEnv(),
+                       malfunction_generator=malfunction_from_params(stochastic_data))
 
         env.reset(True, False, True, random_seed=tests)
         env2.reset(True, False, True, random_seed=tests)
-- 
GitLab