diff --git a/README.md b/README.md
index 02209ab4308f9dc65c84fc744de48f328ad5dfe4..f81370b950d1c3f9d40cbdde5d64f982df8de6a6 100644
--- a/README.md
+++ b/README.md
@@ -149,7 +149,7 @@ env = RailEnv(width=width,
               rail_generator=rail_generator,
               schedule_generator=schedule_generator,
               number_of_agents=nr_trains,
-              stochastic_data=stochastic_data,  # Malfunction data generator
+              malfunction_generator=stochastic_data,  # Malfunction data generator
               obs_builder_object=observation_builder,
               remove_agents_at_target=True  # Removes agents at the end of their journey to make space for others
               )
diff --git a/examples/complex_rail_benchmark.py b/examples/complex_rail_benchmark.py
index 87f248ee223b061fca8375d56072ed81d5ad338c..7a203baf87ebc07a3e5a1afad8606bdd98a8cc83 100644
--- a/examples/complex_rail_benchmark.py
+++ b/examples/complex_rail_benchmark.py
@@ -14,10 +14,8 @@ def run_benchmark():
     np.random.seed(1)
 
     # Example generate a random rail
-    env = RailEnv(width=15, height=15,
-                  rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
-                  schedule_generator=complex_schedule_generator(),
-                  number_of_agents=5)
+    env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12),
+                  schedule_generator=complex_schedule_generator(), number_of_agents=5)
     env.reset()
 
     n_trials = 20
diff --git a/examples/custom_observation_example_01_SimpleObs.py b/examples/custom_observation_example_01_SimpleObs.py
index 600b8f0968fb2226ec92c224c32ffd7617138e5b..2cee8e867aaec93e1a42dba562be787ab03fa5cd 100644
--- a/examples/custom_observation_example_01_SimpleObs.py
+++ b/examples/custom_observation_example_01_SimpleObs.py
@@ -28,10 +28,7 @@ class SimpleObs(ObservationBuilder):
 
 
 def main():
-    env = RailEnv(width=7,
-                  height=7,
-                  rail_generator=random_rail_generator(),
-                  number_of_agents=3,
+    env = RailEnv(width=7, height=7, rail_generator=random_rail_generator(), number_of_agents=3,
                   obs_builder_object=SimpleObs())
     env.reset()
 
diff --git a/examples/custom_observation_example_02_SingleAgentNavigationObs.py b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
index b1729296a199fded38270a63a527de06d9e7b329..52a56b06dcdeb9bebeb9dbb70d5acc9ae0c350bb 100644
--- a/examples/custom_observation_example_02_SingleAgentNavigationObs.py
+++ b/examples/custom_observation_example_02_SingleAgentNavigationObs.py
@@ -76,13 +76,10 @@ def main(args):
         else:
             assert False, "unhandled option"
 
-    env = RailEnv(width=7,
-                  height=7,
+    env = RailEnv(width=7, height=7,
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999,
-                                                        seed=1),
-                  schedule_generator=complex_schedule_generator(),
-                  number_of_agents=1,
-                  obs_builder_object=SingleAgentNavigationObs())
+                                                        seed=1), schedule_generator=complex_schedule_generator(),
+                  number_of_agents=1, obs_builder_object=SingleAgentNavigationObs())
 
     obs, info = env.reset()
     env_renderer = RenderTool(env, gl="PILSVG")
diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index 7af7499af8ffc5900a87d7e543310aab2a6df7f9..ac99835368fc4a5a709e72234396a286f1d622a7 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -122,13 +122,10 @@ def main(args):
     custom_obs_builder = ObservePredictions(custom_predictor)
 
     # Initiate Environment
-    env = RailEnv(width=10,
-                  height=10,
+    env = RailEnv(width=10, height=10,
                   rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=8, max_dist=99999,
-                                                        seed=1),
-                  schedule_generator=complex_schedule_generator(),
-                  number_of_agents=3,
-                  obs_builder_object=custom_obs_builder)
+                                                        seed=1), schedule_generator=complex_schedule_generator(),
+                  number_of_agents=3, obs_builder_object=custom_obs_builder)
 
     obs, info = env.reset()
     env_renderer = RenderTool(env, gl="PILSVG")
diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py
index ed263ef93e80f4d4a04db240d5e21c6e855806f8..ceea22a94fd1c4803a73fe57230393116766c8bd 100644
--- a/examples/custom_railmap_example.py
+++ b/examples/custom_railmap_example.py
@@ -43,10 +43,7 @@ def custom_schedule_generator() -> ScheduleGenerator:
     return generator
 
 
-env = RailEnv(width=6,
-              height=4,
-              rail_generator=custom_rail_generator(),
-              schedule_generator=custom_schedule_generator(),
+env = RailEnv(width=6, height=4, rail_generator=custom_rail_generator(), schedule_generator=custom_schedule_generator(),
               number_of_agents=1)
 
 env.reset()
diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 5ece03e9c56d672b76a453e0036f6b89c3a6ee77..5556a2a0ed9c5b67e2708c8bf222304603a131ad 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -30,21 +30,16 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
                     1. / 3.: 0.25,  # Slow commuter train
                     1. / 4.: 0.25}  # Slow freight train
 
-env = RailEnv(width=100,
-              height=100,
-              rail_generator=sparse_rail_generator(max_num_cities=30,
-                                                   # Number of cities in map (where train stations are)
-                                                   seed=14,  # Random seed
-                                                   grid_mode=False,
-                                                   max_rails_between_cities=2,
-                                                   max_rails_in_city=8,
-                                                   ),
-              schedule_generator=sparse_schedule_generator(speed_ration_map),
-              number_of_agents=100,
-              stochastic_data=stochastic_data,  # Malfunction data generator
-              obs_builder_object=GlobalObsForRailEnv(),
-              remove_agents_at_target=True
-              )
+env = RailEnv(width=100, height=100, rail_generator=sparse_rail_generator(max_num_cities=30,
+                                                                          # Number of cities in map (where train stations are)
+                                                                          seed=14,  # Random seed
+                                                                          grid_mode=False,
+                                                                          max_rails_between_cities=2,
+                                                                          max_rails_in_city=8,
+                                                                          ),
+              schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=100,
+              obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data,
+              remove_agents_at_target=True)
 
 # RailEnv.DEPOT_POSITION = lambda agent, agent_handle : (agent_handle % env.height,0)
 
diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py
index 4cdf63b0b32d431efabc05ed4d593aa746a44b40..de7c77faebc1f9d8fb8bb9b18fec142601babcb7 100644
--- a/examples/introduction_flatland_2_1.py
+++ b/examples/introduction_flatland_2_1.py
@@ -72,15 +72,9 @@ observation_builder = GlobalObsForRailEnv()
 # observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 
 # Construct the enviornment with the given observation, generataors, predictors, and stochastic data
-env = RailEnv(width=width,
-              height=height,
-              rail_generator=rail_generator,
-              schedule_generator=schedule_generator,
-              number_of_agents=nr_trains,
-              stochastic_data=stochastic_data,  # Malfunction data generator
-              obs_builder_object=observation_builder,
-              remove_agents_at_target=True  # Removes agents at the end of their journey to make space for others
-              )
+env = RailEnv(width=width, height=height, rail_generator=rail_generator, schedule_generator=schedule_generator,
+              number_of_agents=nr_trains, obs_builder_object=observation_builder, malfunction_generator=stochastic_data,
+              remove_agents_at_target=True)
 env.reset()
 
 # Initiate the renderer
diff --git a/examples/simple_example_1.py b/examples/simple_example_1.py
index 388128d0d246d73f0236b054a3228ec20c46864e..ba88930142f8344b13a3cae8de2178148e459998 100644
--- a/examples/simple_example_1.py
+++ b/examples/simple_example_1.py
@@ -9,10 +9,7 @@ specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
          [(7, 270), (1, 90), (1, 90), (1, 90), (2, 90), (7, 90)],
          [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]]
 
-env = RailEnv(width=6,
-              height=4,
-              rail_generator=rail_from_manual_specifications_generator(specs),
-              number_of_agents=1)
+env = RailEnv(width=6, height=4, rail_generator=rail_from_manual_specifications_generator(specs), number_of_agents=1)
 
 env.reset()
 
diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py
index 34abee096a043b73f53de8eed42a2e2b73ec1cc5..f9659cbdb666a2a6bd94db8aa560ded40f69079a 100644
--- a/examples/simple_example_2.py
+++ b/examples/simple_example_2.py
@@ -23,8 +23,7 @@ transition_probability = [1.0,  # empty cell - Case 0
                           1.0]  # Case 10 - mirrored switch
 
 # Example generate a random rail
-env = RailEnv(width=10,
-              height=10,
+env = RailEnv(width=10, height=10,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
               number_of_agents=3)
 
diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
index ccbe8682fe5c8744737a452c77257dd4570b6f75..82fca31943a6f60607dfbc8e6befdac185c3acc6 100644
--- a/examples/simple_example_3.py
+++ b/examples/simple_example_3.py
@@ -11,11 +11,9 @@ from flatland.utils.rendertools import RenderTool
 random.seed(1)
 np.random.seed(1)
 
-env = RailEnv(width=7,
-              height=7,
+env = RailEnv(width=7, height=7,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=5, max_dist=99999, seed=1),
-              schedule_generator=complex_schedule_generator(),
-              number_of_agents=2,
+              schedule_generator=complex_schedule_generator(), number_of_agents=2,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
 env.reset()
diff --git a/examples/training_example.py b/examples/training_example.py
index 2ce2ad1a86dd85acf00926f413c941df84d973c7..5f8cbe4088b1358e13a323a7c665ac8ccf60f740 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -14,12 +14,9 @@ np.random.seed(1)
 
 TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2)
-env = RailEnv(width=20,
-              height=20,
+env = RailEnv(width=20, height=20,
               rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=2, min_dist=8, max_dist=99999, seed=1),
-              schedule_generator=complex_schedule_generator(),
-              obs_builder_object=TreeObservation,
-              number_of_agents=3)
+              schedule_generator=complex_schedule_generator(), number_of_agents=3, obs_builder_object=TreeObservation)
 env.reset()
 
 env_renderer = RenderTool(env, gl="PILSVG", )
diff --git a/flatland/cli.py b/flatland/cli.py
index f544aabcb6d9e81ddf8703c59b8bf07324b3ce2c..cc7576d16a02b0d0268ecaab201921b5034d7ee0 100644
--- a/flatland/cli.py
+++ b/flatland/cli.py
@@ -18,16 +18,11 @@ from flatland.utils.rendertools import RenderTool
 @click.command()
 def demo(args=None):
     """Demo script to check installation"""
-    env = RailEnv(
-        width=15,
-        height=15,
-        rail_generator=complex_rail_generator(
-            nr_start_goal=10,
-            nr_extra=1,
-            min_dist=8,
-            max_dist=99999),
-        schedule_generator=complex_schedule_generator(),
-        number_of_agents=5)
+    env = RailEnv(width=15, height=15, rail_generator=complex_rail_generator(
+        nr_start_goal=10,
+        nr_extra=1,
+        min_dist=8,
+        max_dist=99999), schedule_generator=complex_schedule_generator(), number_of_agents=5)
 
     env._max_episode_steps = int(15 * (env.width + env.height))
     env_renderer = RenderTool(env)
diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py
new file mode 100644
index 0000000000000000000000000000000000000000..0de2f4b598023884a49a47473670bb8613c71eeb
--- /dev/null
+++ b/flatland/envs/malfunction_generators.py
@@ -0,0 +1,79 @@
+"""Malfunction generators for rail systems"""
+
+from typing import Tuple, List, Callable
+
+import msgpack
+
+MalfunctionGenerator = Callable[[], Tuple[float, int, int]]
+
+
+def malfunction_from_file(filename) -> MalfunctionGenerator:
+    """
+    Utility to load pickle file
+
+    Parameters
+    ----------
+    input_file : Pickle file generated by env.save() or editor
+
+    Returns
+    -------
+    Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
+    """
+
+    def generator():
+        with open(filename, "rb") as file_in:
+            load_data = file_in.read()
+        data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
+
+        if "malfunction" in data:
+            # Mean malfunction in number of time steps
+            mean_malfunction_rate = data["malfunction"]["malfunction_rate"]
+            # Uniform distribution parameters for malfunction duration
+            min_number_of_steps_broken = data["malfunction"]["min_duration"]
+            max_number_of_steps_broken = data["malfunction"]["max_duration"]
+            agents_speed = None
+        return mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
+
+    return generator
+
+def malfunction_from_params(parameters) -> MalfunctionGenerator:
+    """
+    Utility to load malfunction from parameters
+
+    Parameters
+    ----------
+    parameters containing
+    malfunction_rate : float how many time steps it takes for a sinlge agent befor it breaks
+    min_duration : int minimal duration of a failure
+    max_number_of_steps_broken : int maximal duration of a failure
+
+    Returns
+    -------
+    Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
+    """
+
+    def generator():
+        mean_malfunction_rate = parameters['malfunction_rate']
+        min_number_of_steps_broken = parameters['min_duration']
+        max_number_of_steps_broken = parameters['max_duration']
+        return mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
+
+    return generator
+
+def no_malfunction_generator() -> MalfunctionGenerator:
+    """
+    Utility to load malfunction from parameters
+
+    Parameters
+    ----------
+    input_file : Pickle file generated by env.save() or editor
+
+    Returns
+    -------
+    Tuple[float, int, int] with mean_malfunction_rate, min_number_of_steps_broken, max_number_of_steps_broken
+    """
+
+    def generator():
+        return 0, 0, 0
+
+    return generator
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index f284f3ac38b480d8feb6c0b4944cf8831d2a70d2..6da778ed40e917833c9460e9e08a8e4a516e5611 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -19,6 +19,7 @@ from flatland.core.grid.grid_utils import IntVector2D
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
+from flatland.envs.malfunction_generators import MalfunctionGenerator, no_malfunction_generator
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_generators import random_rail_generator, RailGenerator
 from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
@@ -111,17 +112,15 @@ class RailEnv(Environment):
     stop_penalty = 0  # penalty for stopping a moving agent
     start_penalty = 0  # penalty for starting a stopped agent
 
-    def __init__(self,
-                 width,
+    def __init__(self, width,
                  height,
                  rail_generator: RailGenerator = random_rail_generator(),
                  schedule_generator: ScheduleGenerator = random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
-                 stochastic_data=None,
+                 malfunction_generator: MalfunctionGenerator = no_malfunction_generator(),
                  remove_agents_at_target=True,
-                 random_seed=1
-                 ):
+                 random_seed=1):
         """
         Environment init.
 
@@ -161,6 +160,7 @@ class RailEnv(Environment):
 
         self.rail_generator: RailGenerator = rail_generator
         self.schedule_generator: ScheduleGenerator = schedule_generator
+        self.malfunction_generator: MalfunctionGenerator = malfunction_generator
         self.rail: Optional[GridTransitionMap] = None
         self.width = width
         self.height = height
@@ -196,19 +196,8 @@ class RailEnv(Environment):
             self._seed(seed=random_seed)
 
         # Stochastic train malfunctioning parameters
-        if stochastic_data is not None:
-            mean_malfunction_rate = stochastic_data['malfunction_rate']
-            malfunction_min_duration = stochastic_data['min_duration']
-            malfunction_max_duration = stochastic_data['max_duration']
-        else:
-            mean_malfunction_rate = 0.
-            malfunction_min_duration = 0.
-            malfunction_max_duration = 0.
-
-        # Mean malfunction in number of time steps
+        mean_malfunction_rate, malfunction_min_duration, malfunction_max_duration = self.malfunction_generator()
         self.mean_malfunction_rate = mean_malfunction_rate
-
-        # Uniform distribution parameters for malfunction duration
         self.min_number_of_steps_broken = malfunction_min_duration
         self.max_number_of_steps_broken = malfunction_max_duration
 
@@ -359,6 +348,12 @@ class RailEnv(Environment):
             else:
                 self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
 
+        # Stochastic train malfunctioning parameters
+        mean_malfunction_rate, malfunction_min_duration, malfunction_max_duration = self.malfunction_generator()
+        self.mean_malfunction_rate = mean_malfunction_rate
+        self.min_number_of_steps_broken = malfunction_min_duration
+        self.max_number_of_steps_broken = malfunction_max_duration
+
         self.agent_positions = np.full((self.height, self.width), False)
 
         self.restart_agents()
@@ -837,22 +832,41 @@ class RailEnv(Environment):
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
         agent_data = [agent.to_list() for agent in self.agents]
+        malfunction_data = {"malfunction_rate": self.mean_malfunction_rate,
+                            "min_duration": self.min_number_of_steps_broken,
+                            "max_duration": self.max_number_of_steps_broken}
+
         msgpack.packb(grid_data, use_bin_type=True)
         msgpack.packb(agent_data, use_bin_type=True)
         msgpack.packb(agent_static_data, use_bin_type=True)
         msg_data = {
             "grid": grid_data,
             "agents_static": agent_static_data,
-            "agents": agent_data}
+            "agents": agent_data,
+            "malfunction": malfunction_data}
         return msgpack.packb(msg_data, use_bin_type=True)
 
-    def get_agent_state_msg(self):
+    def get_full_state_dist_msg(self):
         """
-        Returns agents information in msgpack object
+        Returns environment information with distance map information as msgpack object
         """
+        grid_data = self.rail.grid.tolist()
+        agent_static_data = [agent.to_list() for agent in self.agents_static]
         agent_data = [agent.to_list() for agent in self.agents]
+        msgpack.packb(grid_data, use_bin_type=True)
+        msgpack.packb(agent_data, use_bin_type=True)
+        msgpack.packb(agent_static_data, use_bin_type=True)
+        distance_map_data = self.distance_map.get()
+        malfunction_data = {"malfunction_rate": self.mean_malfunction_rate,
+                            "min_duration": self.min_number_of_steps_broken,
+                            "max_duration": self.max_number_of_steps_broken}
+        msgpack.packb(distance_map_data, use_bin_type=True)
         msg_data = {
-            "agents": agent_data}
+            "grid": grid_data,
+            "agents_static": agent_static_data,
+            "agents": agent_data,
+            "distance_map": distance_map_data,
+            "malfunction": malfunction_data}
         return msgpack.packb(msg_data, use_bin_type=True)
 
     def set_full_state_msg(self, msg_data):
@@ -873,6 +887,12 @@ class RailEnv(Environment):
         self.rail.height = self.height
         self.rail.width = self.width
         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
+        if "malfunction" in data:
+            # Mean malfunction in number of time steps
+            self.mean_malfunction_rate = data["malfunction"]["malfunction_rate"]
+            # Uniform distribution parameters for malfunction duration
+            self.min_number_of_steps_broken = data["malfunction"]["min_duration"]
+            self.max_number_of_steps_broken = data["malfunction"]["max_duration"]
 
     def set_full_state_dist_msg(self, msg_data):
         """
@@ -894,26 +914,12 @@ class RailEnv(Environment):
         self.rail.height = self.height
         self.rail.width = self.width
         self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
-
-    def get_full_state_dist_msg(self):
-        """
-        Returns environment information with distance map information as msgpack object
-        """
-        grid_data = self.rail.grid.tolist()
-        agent_static_data = [agent.to_list() for agent in self.agents_static]
-        agent_data = [agent.to_list() for agent in self.agents]
-        msgpack.packb(grid_data, use_bin_type=True)
-        msgpack.packb(agent_data, use_bin_type=True)
-        msgpack.packb(agent_static_data, use_bin_type=True)
-        distance_map_data = self.distance_map.get()
-        msgpack.packb(distance_map_data, use_bin_type=True)
-        msg_data = {
-            "grid": grid_data,
-            "agents_static": agent_static_data,
-            "agents": agent_data,
-            "distance_map": distance_map_data}
-
-        return msgpack.packb(msg_data, use_bin_type=True)
+        if "malfunction" in data:
+            # Mean malfunction in number of time steps
+            self.mean_malfunction_rate = data["malfunction"]["malfunction_rate"]
+            # Uniform distribution parameters for malfunction duration
+            self.min_number_of_steps_broken = data["malfunction"]["min_duration"]
+            self.max_number_of_steps_broken = data["malfunction"]["max_duration"]
 
     def save(self, filename, save_distance_maps=False):
         """
diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py
index dc1cff12c0c8b1860859208a13d6403734a2d2ad..525fd6564b2a76e63641d38b7c93823ec2c83153 100644
--- a/flatland/envs/rail_env_utils.py
+++ b/flatland/envs/rail_env_utils.py
@@ -10,10 +10,7 @@ def load_flatland_environment_from_file(file_name, load_from_package=None, obs_b
         obs_builder_object = TreeObsForRailEnv(
             max_depth=2,
             predictor=ShortestPathPredictorForRailEnv(max_depth=10))
-    environment = RailEnv(width=1,
-                          height=1,
-                          rail_generator=rail_from_file(file_name, load_from_package),
-                          number_of_agents=1,
-                          schedule_generator=schedule_from_file(file_name, load_from_package),
+    environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package),
+                          schedule_generator=schedule_from_file(file_name, load_from_package), number_of_agents=1,
                           obs_builder_object=obs_builder_object)
     return environment
diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py
index d41224f36f5ffc974e976bb362e8ba8050df4e7d..7b2e189971a364afc55cd62cd2e83bb69c9fbd89 100644
--- a/flatland/evaluators/client.py
+++ b/flatland/evaluators/client.py
@@ -217,13 +217,9 @@ class FlatlandRemoteClient(object):
         if self.verbose:
             print("Current env path : ", test_env_file_path)
         self.current_env_path = test_env_file_path
-        self.env = RailEnv(
-            width=1,
-            height=1,
-            rail_generator=rail_from_file(test_env_file_path),
-            schedule_generator=schedule_from_file(test_env_file_path),
-            obs_builder_object=obs_builder_object
-        )
+        self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path),
+                           schedule_generator=schedule_from_file(test_env_file_path),
+                           obs_builder_object=obs_builder_object)
 
         time_start = time.time()
         local_observation, info = self.env.reset(
@@ -246,8 +242,8 @@ class FlatlandRemoteClient(object):
         _request['type'] = messages.FLATLAND_RL.ENV_STEP
         _request['payload'] = {}
         _request['payload']['action'] = action
-        
-        # Relay the action in a non-blocking way to the server 
+
+        # Relay the action in a non-blocking way to the server
         # so that it can start doing an env.step on it in ~ parallel
         self._remote_request(_request, blocking=False)
 
diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py
index 05601455549a2e1bf44f7bd9041229ba8d2cb80e..ce4cb8cffef8c93fddb79e92d06f56fd92e144b8 100644
--- a/flatland/evaluators/service.py
+++ b/flatland/evaluators/service.py
@@ -273,7 +273,7 @@ class FlatlandRemoteEvaluationService:
         )
         if self.verbose:
             print("Received Request : ", command)
-        
+
         message_queue_latency = time.time() - command["timestamp"]
         self.update_running_mean_stats("message_queue_latency", message_queue_latency)
         return command
@@ -335,13 +335,9 @@ class FlatlandRemoteEvaluationService:
                 test_env_file_path
             )
             del self.env
-            self.env = RailEnv(
-                width=1,
-                height=1,
-                rail_generator=rail_from_file(test_env_file_path),
-                schedule_generator=schedule_from_file(test_env_file_path),
-                obs_builder_object=DummyObservationBuilder()
-            )
+            self.env = RailEnv(width=1, height=1, rail_generator=rail_from_file(test_env_file_path),
+                               schedule_generator=schedule_from_file(test_env_file_path),
+                               obs_builder_object=DummyObservationBuilder())
 
             if self.begin_simulation:
                 # If begin simulation has already been initialized
diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py
index f8c9afd0358d42c2829dc9b7c1fd7f3ad5198a3e..c309f9eb3b56c82b872b1842f30eace25a70026a 100644
--- a/flatland/utils/editor.py
+++ b/flatland/utils/editor.py
@@ -24,10 +24,7 @@ class EditorMVC(object):
         """ Create an Editor MVC assembly around a railenv, or create one if None.
         """
         if env is None:
-            env = RailEnv(width=10,
-                          height=10,
-                          rail_generator=empty_rail_generator(),
-                          number_of_agents=0,
+            env = RailEnv(width=10, height=10, rail_generator=empty_rail_generator(), number_of_agents=0,
                           obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
         env.reset()
@@ -669,11 +666,8 @@ class EditorModel(object):
             fnMethod = complex_rail_generator(nr_start_goal=nAgents, nr_extra=20, min_dist=12, seed=int(time.time()))
 
         if env is None:
-            self.env = RailEnv(width=self.regen_size_width,
-                               height=self.regen_size_height,
-                               rail_generator=fnMethod,
-                               number_of_agents=nAgents,
-                               obs_builder_object=TreeObsForRailEnv(max_depth=2))
+            self.env = RailEnv(width=self.regen_size_width, height=self.regen_size_height, rail_generator=fnMethod,
+                               number_of_agents=nAgents, obs_builder_object=TreeObsForRailEnv(max_depth=2))
         else:
             self.env = env
         self.env.reset(regenerate_rail=True)
diff --git a/tests/test_distance_map.py b/tests/test_distance_map.py
index 3bed89b8ce0947c86593e2f1680ef6082f321d84..22cea8280d377b7b7b8a118a4ba1fe3d35e972a7 100644
--- a/tests/test_distance_map.py
+++ b/tests/test_distance_map.py
@@ -25,14 +25,10 @@ def test_walker():
     rail = GridTransitionMap(width=rail_map.shape[1],
                              height=rail_map.shape[0], transitions=transitions)
     rail.grid = rail_map
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2,
-                                                       predictor=ShortestPathPredictorForRailEnv(max_depth=10)),
-                  )
+                                                       predictor=ShortestPathPredictorForRailEnv(max_depth=10)))
     # reset to initialize agents_static
     env.reset()
 
diff --git a/tests/test_flaltland_rail_agent_status.py b/tests/test_flaltland_rail_agent_status.py
index e70012f8e4e77017c7dde4c3f1287e4d3bf72278..a573e55d0eef96d30189b483d6478b84653e1244 100644
--- a/tests/test_flaltland_rail_agent_status.py
+++ b/tests/test_flaltland_rail_agent_status.py
@@ -16,14 +16,10 @@ np.random.seed(1)
 def test_initial_status():
     """Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
     rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  remove_agents_at_target=False
-                  )
+                  remove_agents_at_target=False)
     env.reset()
     set_penalties_for_replay(env)
     test_config = ReplayConfig(
@@ -127,14 +123,10 @@ def test_initial_status():
 def test_status_done_remove():
     """Test that agent lifecycle works correctly ready-to-depart -> active -> done."""
     rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  remove_agents_at_target=True
-                  )
+                  remove_agents_at_target=True)
     env.reset()
 
     set_penalties_for_replay(env)
diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py
index 0913e45959d08230a815c33d98fb6de8eb99d956..8bc7235edbbed51d65818b1c4de5197b7455ddbe 100644
--- a/tests/test_flatland_core_transition_map.py
+++ b/tests/test_flatland_core_transition_map.py
@@ -69,13 +69,9 @@ def check_path(env, rail, position, direction, target, expected, rendering=False
 
 def test_path_exists(rendering=False):
     rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
 
     # reset to initialize agents_static
     env.reset()
@@ -135,13 +131,9 @@ def test_path_exists(rendering=False):
 
 def test_path_not_exists(rendering=False):
     rail, rail_map = make_simple_rail_unconnected()
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
 
     # reset to initialize agents_static
     env.reset()
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index f425636467ec7cefa0169db006122999b862308a..5543f3912aa4aec750ad63024c7371b02de30309 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -20,11 +20,8 @@ from flatland.utils.simple_rail import make_simple_rail
 def test_global_obs():
     rail, rail_map = make_simple_rail()
 
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
                   obs_builder_object=GlobalObsForRailEnv())
 
     global_obs, info = env.reset()
@@ -95,13 +92,9 @@ def _step_along_shortest_path(env, obs_builder, rail):
 
 def test_reward_function_conflict(rendering=False):
     rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=2,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=2,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     obs_builder: TreeObsForRailEnv = env.obs_builder
     # initialize agents_static
     env.reset()
@@ -176,14 +169,10 @@ def test_reward_function_conflict(rendering=False):
 
 def test_reward_function_waiting(rendering=False):
     rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=2,
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  remove_agents_at_target=False
-                  )
+                  remove_agents_at_target=False)
     obs_builder: TreeObsForRailEnv = env.obs_builder
     # initialize agents_static
     env.reset()
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index 280d1d1143d06b9b00832b9c3eb6cbf4add0ffb2..45e0bdda32fdbf4135c9400d655e49f146bce08c 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -21,13 +21,9 @@ from flatland.utils.simple_rail import make_simple_rail, make_simple_rail2, make
 def test_dummy_predictor(rendering=False):
     rail, rail_map = make_simple_rail2()
 
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)))
     # reset to initialize agents_static
     env.reset()
 
@@ -113,13 +109,9 @@ def test_dummy_predictor(rendering=False):
 
 def test_shortest_path_predictor(rendering=False):
     rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
 
     # reset to initialize agents_static
     env.reset()
@@ -251,13 +243,9 @@ def test_shortest_path_predictor(rendering=False):
 
 def test_shortest_path_predictor_conflicts(rendering=False):
     rail, rail_map = make_invalid_simple_rail()
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=2,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=2,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     # initialize agents_static
     env.reset()
 
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index dc4c78f9a6796d8eef3cfbeb4c54409f14406415..e6550f17aad79bc6685e716249d790aa0acb8bf7 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -30,8 +30,7 @@ def test_load_env():
 def test_save_load():
     env = RailEnv(width=10, height=10,
                   rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
-                  schedule_generator=complex_schedule_generator(),
-                  number_of_agents=2)
+                  schedule_generator=complex_schedule_generator(), number_of_agents=2)
     env.reset()
     agent_1_pos = env.agents_static[0].position
     agent_1_dir = env.agents_static[0].direction
@@ -78,11 +77,8 @@ def test_rail_environment_single_agent():
 
     rail = GridTransitionMap(width=3, height=3, transitions=transitions)
     rail.grid = rail_map
-    rail_env = RailEnv(width=3,
-                       height=3,
-                       rail_generator=rail_from_grid_transition_map(rail),
-                       schedule_generator=random_schedule_generator(),
-                       number_of_agents=1,
+    rail_env = RailEnv(width=3, height=3, rail_generator=rail_from_grid_transition_map(rail),
+                       schedule_generator=random_schedule_generator(), number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
     for _ in range(200):
@@ -155,11 +151,9 @@ def test_dead_end():
                              transitions=transitions)
 
     rail.grid = rail_map
-    rail_env = RailEnv(width=rail_map.shape[1],
-                       height=rail_map.shape[0],
+    rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                        rail_generator=rail_from_grid_transition_map(rail),
-                       schedule_generator=random_schedule_generator(),
-                       number_of_agents=1,
+                       schedule_generator=random_schedule_generator(), number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
     # We try the configuration in the 4 directions:
@@ -180,11 +174,9 @@ def test_dead_end():
                              transitions=transitions)
 
     rail.grid = rail_map
-    rail_env = RailEnv(width=rail_map.shape[1],
-                       height=rail_map.shape[0],
+    rail_env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                        rail_generator=rail_from_grid_transition_map(rail),
-                       schedule_generator=random_schedule_generator(),
-                       number_of_agents=1,
+                       schedule_generator=random_schedule_generator(), number_of_agents=1,
                        obs_builder_object=GlobalObsForRailEnv())
 
     rail_env.reset()
@@ -198,13 +190,9 @@ def test_dead_end():
 
 def test_get_entry_directions():
     rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
 
     def _assert(position, expected):
@@ -236,13 +224,9 @@ def test_rail_env_reset():
 
     rail, rail_map = make_simple_rail()
 
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=3,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=3,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
     env.save(file_name)
     dist_map_shape = np.shape(env.distance_map.get())
@@ -250,13 +234,9 @@ def test_rail_env_reset():
     rails_initial = env.rail.grid
     agents_initial = env.agents
 
-    env2 = RailEnv(width=1,
-                  height=1,
-                  rail_generator=rail_from_file(file_name),
-                  schedule_generator=schedule_from_file(file_name),
-                  number_of_agents=1,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
+                   schedule_generator=schedule_from_file(file_name), number_of_agents=1,
+                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env2.reset(False, False, False)
     rails_loaded = env2.rail.grid
     agents_loaded = env2.agents
@@ -264,13 +244,9 @@ def test_rail_env_reset():
     assert np.all(np.array_equal(rails_initial, rails_loaded))
     assert agents_initial == agents_loaded
 
-    env3 = RailEnv(width=1,
-                  height=1,
-                  rail_generator=rail_from_file(file_name),
-                  schedule_generator=schedule_from_file(file_name),
-                  number_of_agents=1,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
+                   schedule_generator=schedule_from_file(file_name), number_of_agents=1,
+                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env3.reset(False, True, False)
     rails_loaded = env3.rail.grid
     agents_loaded = env3.agents
@@ -278,13 +254,9 @@ def test_rail_env_reset():
     assert np.all(np.array_equal(rails_initial, rails_loaded))
     assert agents_initial == agents_loaded
 
-    env4 = RailEnv(width=1,
-                  height=1,
-                  rail_generator=rail_from_file(file_name),
-                  schedule_generator=schedule_from_file(file_name),
-                  number_of_agents=1,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
+                   schedule_generator=schedule_from_file(file_name), number_of_agents=1,
+                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env4.reset(True, False, False)
     rails_loaded = env4.rail.grid
     agents_loaded = env4.agents
diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py
index dd64d370077ab12950f0189065c15652e6ad1c6d..344739b798dd2d36136ff0c35698ce0025fc781d 100644
--- a/tests/test_flatland_envs_rail_env_shortest_paths.py
+++ b/tests/test_flatland_envs_rail_env_shortest_paths.py
@@ -16,13 +16,9 @@ from flatland.utils.simple_rail import make_disconnected_simple_rail
 def test_get_shortest_paths_unreachable():
     rail, rail_map = make_disconnected_simple_rail()
 
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)))
     env.reset()
 
     # set the initial position
diff --git a/tests/test_flatland_envs_sparse_rail_generator.py b/tests/test_flatland_envs_sparse_rail_generator.py
index a1d0fb17ffbf48f77dbad5d7a01acc20e56e30f6..1502ab34c02647d7dda3aed9c3037597ac3bfcc0 100644
--- a/tests/test_flatland_envs_sparse_rail_generator.py
+++ b/tests/test_flatland_envs_sparse_rail_generator.py
@@ -1,9 +1,11 @@
 import random
-
-import numpy as np
 import unittest
 import warnings
+
+import numpy as np
+
 from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
+from flatland.envs.malfunction_generators import malfunction_from_params
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
@@ -14,17 +16,13 @@ from flatland.utils.rendertools import RenderTool
 def test_sparse_rail_generator():
     np.random.seed(0)
     random.seed(0)
-    env = RailEnv(width=50,
-                  height=50,
-                  rail_generator=sparse_rail_generator(max_num_cities=10,
-                                                       max_rails_between_cities=3,
-                                                       seed=5,
-                                                       grid_mode=False
-                                                       ),
-                  schedule_generator=sparse_schedule_generator(),
-                  number_of_agents=10,
-                  obs_builder_object=GlobalObsForRailEnv()
-                  )
+    env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=10,
+                                                                            max_rails_between_cities=3,
+                                                                            seed=5,
+                                                                            grid_mode=False
+                                                                            ),
+                  schedule_generator=sparse_schedule_generator(), number_of_agents=10,
+                  obs_builder_object=GlobalObsForRailEnv())
     env.reset(False, False, True)
     # for r in range(env.height):
     #    for c in range (env.width):
@@ -554,17 +552,13 @@ def test_sparse_rail_generator_deterministic():
                         1. / 3.: 0.,  # Slow commuter train
                         1. / 4.: 0.}  # Slow freight train
 
-    env = RailEnv(width=25,
-                  height=30,
-                  rail_generator=sparse_rail_generator(max_num_cities=5,
-                                                       max_rails_between_cities=3,
-                                                       seed=215545,  # Random seed
-                                                       grid_mode=True
-                                                       ),
-                  schedule_generator=sparse_schedule_generator(speed_ration_map),
-                  number_of_agents=1,
-                  stochastic_data=stochastic_data,  # Malfunction data generator
-                  )
+    env = RailEnv(width=25, height=30, rail_generator=sparse_rail_generator(max_num_cities=5,
+                                                                            max_rails_between_cities=3,
+                                                                            seed=215545,  # Random seed
+                                                                            grid_mode=True
+                                                                            ),
+                  schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=1,
+                  malfunction_generator=malfunction_from_params(stochastic_data))
     env.reset()
     # for r in range(env.height):
     #   for c in range(env.width):
@@ -1323,42 +1317,30 @@ def test_sparse_rail_generator_deterministic():
     assert env.rail.get_full_transitions(29, 24) == 0, "[29][24]"
 
 
-
 def test_rail_env_action_required_info():
-
     np.random.seed(0)
     random.seed(0)
     speed_ration_map = {1.: 0.25,  # Fast passenger train
                         1. / 2.: 0.25,  # Fast freight train
                         1. / 3.: 0.25,  # Slow commuter train
                         1. / 4.: 0.25}  # Slow freight train
-    env_always_action = RailEnv(width=50,
-                                height=50,
-                                rail_generator=sparse_rail_generator(
-                                    max_num_cities=10,
-                                    max_rails_between_cities=3,
-                                    seed=5,  # Random seed
-                                    grid_mode=False  # Ordered distribution of nodes
-                                ),
-                                schedule_generator=sparse_schedule_generator(speed_ration_map),
-                                number_of_agents=10,
-                                obs_builder_object=GlobalObsForRailEnv(),
-                                remove_agents_at_target=False)
+    env_always_action = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(
+        max_num_cities=10,
+        max_rails_between_cities=3,
+        seed=5,  # Random seed
+        grid_mode=False  # Ordered distribution of nodes
+    ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10,
+                                obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False)
     np.random.seed(0)
     random.seed(0)
-    env_only_if_action_required = RailEnv(width=50,
-                                          height=50,
-                                          rail_generator=sparse_rail_generator(
-                                              max_num_cities=10,
-                                              max_rails_between_cities=3,
-                                              seed=5,  # Random seed
-                                              grid_mode=False
-                                              # Ordered distribution of nodes
-                                          ),
-                                          schedule_generator=sparse_schedule_generator(speed_ration_map),
-                                          number_of_agents=10,
-                                          obs_builder_object=GlobalObsForRailEnv(),
-                                          remove_agents_at_target=False)
+    env_only_if_action_required = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(
+        max_num_cities=10,
+        max_rails_between_cities=3,
+        seed=5,  # Random seed
+        grid_mode=False
+        # Ordered distribution of nodes
+    ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=10,
+                                          obs_builder_object=GlobalObsForRailEnv(), remove_agents_at_target=False)
     env_renderer = RenderTool(env_always_action, gl="PILSVG", )
 
     env_always_action.reset(False, False, True)
@@ -1418,17 +1400,14 @@ def test_rail_env_malfunction_speed_info():
                        'min_duration': 3,  # Minimal duration of malfunction
                        'max_duration': 10  # Max duration of malfunction
                        }
-    env = RailEnv(width=50,
-                  height=50,
-                  rail_generator=sparse_rail_generator(max_num_cities=10,
-                                                       max_rails_between_cities=3,
-                                                       seed=5,
-                                                       grid_mode=False
-                                                       ),
-                  schedule_generator=sparse_schedule_generator(),
-                  number_of_agents=10,
+    env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=10,
+                                                                            max_rails_between_cities=3,
+                                                                            seed=5,
+                                                                            grid_mode=False
+                                                                            ),
+                  schedule_generator=sparse_schedule_generator(), number_of_agents=10,
                   obs_builder_object=GlobalObsForRailEnv(),
-                  stochastic_data=stochastic_data)
+                  malfunction_generator=malfunction_from_params(stochastic_data))
     env.reset(False, False, True)
 
     env_renderer = RenderTool(env, gl="PILSVG", )
@@ -1458,17 +1437,12 @@ def test_rail_env_malfunction_speed_info():
 def test_sparse_generator_with_too_man_cities_does_not_break_down():
     np.random.seed(0)
 
-    RailEnv(width=50,
-            height=50,
-            rail_generator=sparse_rail_generator(
-                max_num_cities=100,
-                max_rails_between_cities=3,
-                seed=5,
-                grid_mode=False
-            ),
-            schedule_generator=sparse_schedule_generator(),
-            number_of_agents=10,
-            obs_builder_object=GlobalObsForRailEnv())
+    RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(
+        max_num_cities=100,
+        max_rails_between_cities=3,
+        seed=5,
+        grid_mode=False
+    ), schedule_generator=sparse_schedule_generator(), number_of_agents=10, obs_builder_object=GlobalObsForRailEnv())
 
 
 def test_sparse_generator_with_illegal_params_aborts():
@@ -1477,29 +1451,21 @@ def test_sparse_generator_with_illegal_params_aborts():
     """
     np.random.seed(0)
     with unittest.TestCase.assertRaises(test_sparse_generator_with_illegal_params_aborts, SystemExit):
-        RailEnv(width=6,
-                height=6,
-                rail_generator=sparse_rail_generator(
-                    max_num_cities=100,
-                    max_rails_between_cities=3,
-                    seed=5,
-                    grid_mode=False
-                ),
-                schedule_generator=sparse_schedule_generator(),
-                number_of_agents=10,
+        RailEnv(width=6, height=6, rail_generator=sparse_rail_generator(
+            max_num_cities=100,
+            max_rails_between_cities=3,
+            seed=5,
+            grid_mode=False
+        ), schedule_generator=sparse_schedule_generator(), number_of_agents=10,
                 obs_builder_object=GlobalObsForRailEnv()).reset()
 
     with unittest.TestCase.assertRaises(test_sparse_generator_with_illegal_params_aborts, SystemExit):
-        RailEnv(width=60,
-                height=60,
-                rail_generator=sparse_rail_generator(
-                    max_num_cities=1,
-                    max_rails_between_cities=3,
-                    seed=5,
-                    grid_mode=False
-                ),
-                schedule_generator=sparse_schedule_generator(),
-                number_of_agents=10,
+        RailEnv(width=60, height=60, rail_generator=sparse_rail_generator(
+            max_num_cities=1,
+            max_rails_between_cities=3,
+            seed=5,
+            grid_mode=False
+        ), schedule_generator=sparse_schedule_generator(), number_of_agents=10,
                 obs_builder_object=GlobalObsForRailEnv()).reset()
 
 
@@ -1514,17 +1480,12 @@ def test_sparse_generator_changes_to_grid_mode():
 
     for test_run in range(10):
         with warnings.catch_warnings(record=True) as w:
-            RailEnv(width=10,
-                    height=20,
-                    rail_generator=sparse_rail_generator(
-                        max_num_cities=100,
-                        max_rails_between_cities=2,
-                        max_rails_in_city=2,
-                        seed=5,
-                        grid_mode=False
-                    ),
-                    schedule_generator=sparse_schedule_generator(),
-                    number_of_agents=10,
+            RailEnv(width=10, height=20, rail_generator=sparse_rail_generator(
+                max_num_cities=100,
+                max_rails_between_cities=2,
+                max_rails_in_city=2,
+                seed=5,
+                grid_mode=False
+            ), schedule_generator=sparse_schedule_generator(), number_of_agents=10,
                     obs_builder_object=GlobalObsForRailEnv()).reset()
             assert "[WARNING]" in str(w[-1].message)
-
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index e18685f6c68f757f50075ebf206423b562514dd6..14f9b6c0295306448d27a83444e3dd6496485cf1 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -8,6 +8,7 @@ from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.malfunction_generators import malfunction_from_params
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import rail_from_grid_transition_map
 from flatland.envs.schedule_generators import random_schedule_generator
@@ -72,14 +73,9 @@ def test_malfunction_process():
 
     rail, rail_map = make_simple_rail2()
 
-    env = RailEnv(width=25,
-                  height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
-                  stochastic_data=stochastic_data,  # Malfunction data generator
-                  obs_builder_object=SingleAgentNavigationObs()
-                  )
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
+                  obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
     # reset to initialize agents_static
     obs, info = env.reset(False, False, True, random_seed=10)
 
@@ -126,14 +122,9 @@ def test_malfunction_process_statistically():
 
     rail, rail_map = make_simple_rail2()
 
-    env = RailEnv(width=25,
-                  height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=10,
-                  stochastic_data=stochastic_data,  # Malfunction data generator
-                  obs_builder_object=SingleAgentNavigationObs()
-                  )
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=10,
+                  obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
 
     # reset to initialize agents_static
     env.reset(True, True, False, random_seed=10)
@@ -173,14 +164,9 @@ def test_malfunction_before_entry():
 
     rail, rail_map = make_simple_rail2()
 
-    env = RailEnv(width=25,
-                  height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(seed=1),  # seed 12
-                  number_of_agents=10,
-                  random_seed=1,
-                  stochastic_data=stochastic_data,  # Malfunction data generator
-                  )
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(seed=1), number_of_agents=10,
+                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
     # reset to initialize agents_static
     env.reset(False, False, False, random_seed=10)
     env.agents[0].target = (0, 0)
@@ -216,14 +202,9 @@ def test_malfunction_values_and_behavior():
     stochastic_data = {'malfunction_rate': 0.001,
                        'min_duration': 10,
                        'max_duration': 10}
-    env = RailEnv(width=25,
-                  height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
-                  stochastic_data=stochastic_data,
-                  number_of_agents=1,
-                  random_seed=1,
-                  )
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
+                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
 
     # reset to initialize agents_static
     env.reset(False, False, activate_agents=True, random_seed=10)
@@ -246,14 +227,9 @@ def test_initial_malfunction():
 
     rail, rail_map = make_simple_rail2()
 
-    env = RailEnv(width=25,
-                  height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(seed=10),
-                  number_of_agents=1,
-                  stochastic_data=stochastic_data,  # Malfunction data generator
-                  obs_builder_object=SingleAgentNavigationObs()
-                  )
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(seed=10), number_of_agents=1,
+                  obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
     # reset to initialize agents_static
     env.reset(False, False, True, random_seed=10)
     print(env.agents[0].malfunction_data)
@@ -318,14 +294,9 @@ def test_initial_malfunction_stop_moving():
 
     rail, rail_map = make_simple_rail2()
 
-    env = RailEnv(width=25,
-                  height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
-                  stochastic_data=stochastic_data,  # Malfunction data generator
-                  obs_builder_object=SingleAgentNavigationObs()
-                  )
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
+                  obs_builder_object=SingleAgentNavigationObs(), malfunction_generator=malfunction_from_params(stochastic_data))
     env.reset()
 
     print(env.agents[0].initial_position, env.agents[0].direction, env.agents[0].position, env.agents[0].status)
@@ -409,13 +380,9 @@ def test_initial_malfunction_do_nothing():
 
     rail, rail_map = make_simple_rail2()
 
-    env = RailEnv(width=25,
-                  height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
-                  stochastic_data=stochastic_data,  # Malfunction data generator
-                  )
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
+                  malfunction_generator=malfunction_from_params(stochastic_data))
     # reset to initialize agents_static
     env.reset()
     set_penalties_for_replay(env)
@@ -492,14 +459,9 @@ def tests_random_interference_from_outside():
                        'max_duration': 10}
 
     rail, rail_map = make_simple_rail2()
-    env = RailEnv(width=25,
-                  height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
-                  number_of_agents=1,
-                  random_seed=1,
-                  stochastic_data=stochastic_data,  # Malfunction data generator
-                  )
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
+                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
     env.reset()
     # reset to initialize agents_static
     env.agents[0].speed_data['speed'] = 0.33
@@ -523,14 +485,9 @@ def tests_random_interference_from_outside():
     rail, rail_map = make_simple_rail2()
     random.seed(47)
     np.random.seed(1234)
-    env = RailEnv(width=25,
-                  height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
-                  number_of_agents=1,
-                  random_seed=1,
-                  stochastic_data=stochastic_data,  # Malfunction data generator
-                  )
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
+                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
     env.reset()
     # reset to initialize agents_static
     env.agents[0].speed_data['speed'] = 0.33
@@ -565,14 +522,9 @@ def test_last_malfunction_step():
 
     rail, rail_map = make_simple_rail2()
 
-    env = RailEnv(width=25,
-                  height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(seed=2),  # seed 12
-                  number_of_agents=1,
-                  random_seed=1,
-                  stochastic_data=stochastic_data,  # Malfunction data generator
-                  )
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(seed=2), number_of_agents=1,
+                  malfunction_generator=malfunction_from_params(stochastic_data), random_seed=1)
     env.reset()
     # reset to initialize agents_static
     env.agents[0].speed_data['speed'] = 1. / 3.
diff --git a/tests/test_flatland_utils_rendertools.py b/tests/test_flatland_utils_rendertools.py
index 6e1fb2441d4428b24da3c37d764a7676f3929a2a..6ed92fefb0c81512fc6006cbf44b6d55a274caf3 100644
--- a/tests/test_flatland_utils_rendertools.py
+++ b/tests/test_flatland_utils_rendertools.py
@@ -37,11 +37,8 @@ def checkFrozenImage(oRT, sFileImage, resave=False):
 
 def test_render_env(save_new_images=False):
     np.random.seed(100)
-    oEnv = RailEnv(width=10, height=10,
-                   rail_generator=empty_rail_generator(),
-                   number_of_agents=0,
-                   obs_builder_object=TreeObsForRailEnv(max_depth=2)
-                   )
+    oEnv = RailEnv(width=10, height=10, rail_generator=empty_rail_generator(), number_of_agents=0,
+                   obs_builder_object=TreeObsForRailEnv(max_depth=2))
     oEnv.reset()
     oEnv.rail.load_transition_map('env_data.tests', "test1.npy")
     oRT = rt.RenderTool(oEnv, gl="PILSVG")
diff --git a/tests/test_generators.py b/tests/test_generators.py
index 1e69223daebd24c52137e12eed9dc43d188a9bbd..83ef0d76360d1c662c9e181c6f88f41adb62ebbf 100644
--- a/tests/test_generators.py
+++ b/tests/test_generators.py
@@ -20,11 +20,7 @@ def test_empty_rail_generator():
     y_dim = 10
 
     # Check that a random level at with correct parameters is generated
-    env = RailEnv(width=x_dim,
-                  height=y_dim,
-                  number_of_agents=n_agents,
-                  rail_generator=empty_rail_generator()
-                  )
+    env = RailEnv(width=x_dim, height=y_dim, rail_generator=empty_rail_generator(), number_of_agents=n_agents)
     env.reset()
     # Check the dimensions
     assert env.rail.grid.shape == (y_dim, x_dim)
@@ -41,11 +37,7 @@ def test_random_rail_generator():
     y_dim = 10
 
     # Check that a random level at with correct parameters is generated
-    env = RailEnv(width=x_dim,
-                  height=y_dim,
-                  number_of_agents=n_agents,
-                  rail_generator=random_rail_generator()
-                  )
+    env = RailEnv(width=x_dim, height=y_dim, rail_generator=random_rail_generator(), number_of_agents=n_agents)
     env.reset()
     assert env.rail.grid.shape == (y_dim, x_dim)
     assert env.get_num_agents() == n_agents
@@ -59,12 +51,9 @@ def test_complex_rail_generator():
     min_dist = 4
 
     # Check that agent number is changed to fit generated level
-    env = RailEnv(width=x_dim,
-                  height=y_dim,
-                  number_of_agents=n_agents,
+    env = RailEnv(width=x_dim, height=y_dim,
                   rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
-                  schedule_generator=complex_schedule_generator()
-                  )
+                  schedule_generator=complex_schedule_generator(), number_of_agents=n_agents)
     env.reset()
     assert env.get_num_agents() == 2
     assert env.rail.grid.shape == (y_dim, x_dim)
@@ -72,12 +61,9 @@ def test_complex_rail_generator():
     min_dist = 2 * x_dim
 
     # Check that no agents are generated when level cannot be generated
-    env = RailEnv(width=x_dim,
-                  height=y_dim,
-                  number_of_agents=n_agents,
+    env = RailEnv(width=x_dim, height=y_dim,
                   rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
-                  schedule_generator=complex_schedule_generator()
-                  )
+                  schedule_generator=complex_schedule_generator(), number_of_agents=n_agents)
     env.reset()
     assert env.get_num_agents() == 0
     assert env.rail.grid.shape == (y_dim, x_dim)
@@ -87,12 +73,9 @@ def test_complex_rail_generator():
     n_start = 5
     n_agents = 5
 
-    env = RailEnv(width=x_dim,
-                  height=y_dim,
-                  number_of_agents=n_agents,
+    env = RailEnv(width=x_dim, height=y_dim,
                   rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist),
-                  schedule_generator=complex_schedule_generator()
-                  )
+                  schedule_generator=complex_schedule_generator(), number_of_agents=n_agents)
     env.reset()
     assert env.get_num_agents() == n_agents
     assert env.rail.grid.shape == (y_dim, x_dim)
@@ -101,12 +84,8 @@ def test_complex_rail_generator():
 def test_rail_from_grid_transition_map():
     rail, rail_map = make_simple_rail()
     n_agents = 3
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=n_agents
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=n_agents)
     env.reset(False, False, True)
     nr_rail_elements = np.count_nonzero(env.rail.grid)
 
@@ -127,13 +106,9 @@ def tests_rail_from_file():
 
     rail, rail_map = make_simple_rail()
 
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=3,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=3,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
     env.save(file_name)
     dist_map_shape = np.shape(env.distance_map.get())
@@ -141,13 +116,9 @@ def tests_rail_from_file():
     rails_initial = env.rail.grid
     agents_initial = env.agents
 
-    env = RailEnv(width=1,
-                  height=1,
-                  rail_generator=rail_from_file(file_name),
-                  schedule_generator=schedule_from_file(file_name),
-                  number_of_agents=1,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
+                  schedule_generator=schedule_from_file(file_name), number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
     rails_loaded = env.rail.grid
     agents_loaded = env.agents
@@ -163,13 +134,9 @@ def tests_rail_from_file():
 
     file_name_2 = "test_without_distance_map.pkl"
 
-    env2 = RailEnv(width=rail_map.shape[1],
-                   height=rail_map.shape[0],
-                   rail_generator=rail_from_grid_transition_map(rail),
-                   schedule_generator=random_schedule_generator(),
-                   number_of_agents=3,
-                   obs_builder_object=GlobalObsForRailEnv(),
-                   )
+    env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
+                   rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(),
+                   number_of_agents=3, obs_builder_object=GlobalObsForRailEnv())
     env2.reset()
     env2.save(file_name_2)
 
@@ -177,13 +144,9 @@ def tests_rail_from_file():
     rails_initial_2 = env2.rail.grid
     agents_initial_2 = env2.agents
 
-    env2 = RailEnv(width=1,
-                   height=1,
-                   rail_generator=rail_from_file(file_name_2),
-                   schedule_generator=schedule_from_file(file_name_2),
-                   number_of_agents=1,
-                   obs_builder_object=GlobalObsForRailEnv(),
-                   )
+    env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2),
+                   schedule_generator=schedule_from_file(file_name_2), number_of_agents=1,
+                   obs_builder_object=GlobalObsForRailEnv())
     env2.reset()
     rails_loaded_2 = env2.rail.grid
     agents_loaded_2 = env2.agents
@@ -194,13 +157,9 @@ def tests_rail_from_file():
 
     # Test to save with distance map and load without
 
-    env3 = RailEnv(width=1,
-                   height=1,
-                   rail_generator=rail_from_file(file_name),
-                   schedule_generator=schedule_from_file(file_name),
-                   number_of_agents=1,
-                   obs_builder_object=GlobalObsForRailEnv(),
-                   )
+    env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
+                   schedule_generator=schedule_from_file(file_name), number_of_agents=1,
+                   obs_builder_object=GlobalObsForRailEnv())
     env3.reset()
     rails_loaded_3 = env3.rail.grid
     agents_loaded_3 = env3.agents
@@ -212,13 +171,9 @@ def tests_rail_from_file():
     # Test to save without distance map and load with generating distance map
 
     # initialize agents_static
-    env4 = RailEnv(width=1,
-                   height=1,
-                   rail_generator=rail_from_file(file_name_2),
-                   schedule_generator=schedule_from_file(file_name_2),
-                   number_of_agents=1,
-                   obs_builder_object=TreeObsForRailEnv(max_depth=2),
-                   )
+    env4 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2),
+                   schedule_generator=schedule_from_file(file_name_2), number_of_agents=1,
+                   obs_builder_object=TreeObsForRailEnv(max_depth=2))
     env4.reset()
     rails_loaded_4 = env4.rail.grid
     agents_loaded_4 = env4.agents
diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py
index d3bcf7779dbc4c8dbafe6e726aeff33c757c25fb..afaf2b7f496ef31c7d5228d1f389c254cc16df2e 100644
--- a/tests/test_global_observation.py
+++ b/tests/test_global_observation.py
@@ -22,16 +22,13 @@ def test_get_global_observation():
                         1. / 3.: 0.25,  # Slow commuter train
                         1. / 4.: 0.25}  # Slow freight train
 
-    env = RailEnv(width=50,
-                  height=50,
-                  rail_generator=sparse_rail_generator(max_num_cities=6,
-                                                       max_rails_between_cities=4,
-                                                       seed=15,
-                                                       grid_mode=False
-                                                       ),
-                  schedule_generator=sparse_schedule_generator(speed_ration_map),
-                  number_of_agents=number_of_agents, stochastic_data=stochastic_data,  # Malfunction data generator
-                  obs_builder_object=GlobalObsForRailEnv())
+    env = RailEnv(width=50, height=50, rail_generator=sparse_rail_generator(max_num_cities=6,
+                                                                            max_rails_between_cities=4,
+                                                                            seed=15,
+                                                                            grid_mode=False
+                                                                            ),
+                  schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents,
+                  obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data)
     env.reset()
 
     obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)})
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index f83990cc39bf73e50719b2291006eed68d1d1360..f5fc66662149816c9345be738d8bd30a8bc8306f 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -49,11 +49,9 @@ class RandomAgent:
 
 
 def test_multi_speed_init():
-    env = RailEnv(width=50,
-                  height=50,
+    env = RailEnv(width=50, height=50,
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
-                                                        seed=1),
-                  schedule_generator=complex_schedule_generator(),
+                                                        seed=1), schedule_generator=complex_schedule_generator(),
                   number_of_agents=5)
     # Initialize the agent with the parameters corresponding to the environment and observation_builder
     agent = RandomAgent(218, 4)
@@ -97,13 +95,9 @@ def test_multi_speed_init():
 def test_multispeed_actions_no_malfunction_no_blocking():
     """Test that actions are correctly performed on cell exit for a single agent."""
     rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
 
     set_penalties_for_replay(env)
@@ -201,13 +195,9 @@ def test_multispeed_actions_no_malfunction_no_blocking():
 def test_multispeed_actions_no_malfunction_blocking():
     """The second agent blocks the first because it is slower."""
     rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=2,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=2,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
     set_penalties_for_replay(env)
     test_configs = [
@@ -389,13 +379,9 @@ def test_multispeed_actions_no_malfunction_blocking():
 def test_multispeed_actions_malfunction_no_blocking():
     """Test on a single agent whether action on cell exit work correctly despite malfunction."""
     rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
 
     set_penalties_for_replay(env)
@@ -527,13 +513,9 @@ def test_multispeed_actions_malfunction_no_blocking():
 def test_multispeed_actions_no_malfunction_invalid_actions():
     """Test that actions are correctly performed on cell exit for a single agent."""
     rail, rail_map = make_simple_rail()
-    env = RailEnv(width=rail_map.shape[1],
-                  height=rail_map.shape[0],
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(),
-                  number_of_agents=1,
-                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
-                  )
+    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(), number_of_agents=1,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
     env.reset()
 
     set_penalties_for_replay(env)
diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py
index a1981a9a0d2ed80b98f8d1e1fc8a49e14624afae..4ce04e5e5e07ca5b5e864efffea951221322f306 100644
--- a/tests/test_random_seeding.py
+++ b/tests/test_random_seeding.py
@@ -14,12 +14,8 @@ def test_random_seeding():
 
     # Move target to unreachable position in order to not interfere with test
     for idx in range(100):
-        env = RailEnv(width=25,
-                      height=30,
-                      rail_generator=rail_from_grid_transition_map(rail),
-                      schedule_generator=random_schedule_generator(seed=12),
-                      number_of_agents=10
-                      )
+        env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                      schedule_generator=random_schedule_generator(seed=12), number_of_agents=10)
         env.reset(True, True, False, random_seed=1)
 
         env.agents[0].target = (0, 0)
@@ -52,21 +48,13 @@ def test_seeding_and_observations():
 
     # Make two seperate envs with different observation builders
     # Global Observation
-    env = RailEnv(width=25,
-                  height=30,
-                  rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(seed=12),
-                  number_of_agents=10,
-                  obs_builder_object=GlobalObsForRailEnv()
-                  )
+    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(seed=12), number_of_agents=10,
+                  obs_builder_object=GlobalObsForRailEnv())
     # Tree Observation
-    env2 = RailEnv(width=25,
-                   height=30,
-                   rail_generator=rail_from_grid_transition_map(rail),
-                   schedule_generator=random_schedule_generator(seed=12),
-                   number_of_agents=10,
-                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
-                   )
+    env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                   schedule_generator=random_schedule_generator(seed=12), number_of_agents=10,
+                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
 
     env.reset(False, False, False, random_seed=12)
     env2.reset(False, False, False, random_seed=12)
@@ -118,24 +106,14 @@ def test_seeding_and_malfunction():
     # Make two seperate envs with different and see if the exhibit the same malfunctions
     # Global Observation
     for tests in range(1, 100):
-        env = RailEnv(width=25,
-                      height=30,
-                      rail_generator=rail_from_grid_transition_map(rail),
-                      schedule_generator=random_schedule_generator(),
-                      number_of_agents=10,
-                      obs_builder_object=GlobalObsForRailEnv(),
-                      stochastic_data=stochastic_data,  # Malfunction data generator
-                      )
+        env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                      schedule_generator=random_schedule_generator(), number_of_agents=10,
+                      obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data)
 
         # Tree Observation
-        env2 = RailEnv(width=25,
-                       height=30,
-                       rail_generator=rail_from_grid_transition_map(rail),
-                       schedule_generator=random_schedule_generator(),
-                       number_of_agents=10,
-                       obs_builder_object=GlobalObsForRailEnv(),
-                       stochastic_data=stochastic_data,  # Malfunction data generator
-                       )
+        env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
+                       schedule_generator=random_schedule_generator(), number_of_agents=10,
+                       obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data)
 
         env.reset(True, False, True, random_seed=tests)
         env2.reset(True, False, True, random_seed=tests)
diff --git a/tests/test_speed_classes.py b/tests/test_speed_classes.py
index 1fcf3b3ef0b7cc176d345fa547e91eeeef0a05bd..c1c03c3676cb4b6847265af02432ba8b37ca4cfe 100644
--- a/tests/test_speed_classes.py
+++ b/tests/test_speed_classes.py
@@ -18,11 +18,9 @@ def test_speed_initialization_helper():
 def test_rail_env_speed_intializer():
     speed_ratio_map = {1: 0.3, 2: 0.4, 3: 0.1, 5: 0.2}
 
-    env = RailEnv(width=50,
-                  height=50,
+    env = RailEnv(width=50, height=50,
                   rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999,
-                                                        seed=1),
-                  schedule_generator=complex_schedule_generator(),
+                                                        seed=1), schedule_generator=complex_schedule_generator(),
                   number_of_agents=10)
     env.reset()
     actual_speeds = list(map(lambda agent: agent.speed_data['speed'], env.agents))