diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py
index 89af545ba26e59b3cc399b4b775ba5846a7b90a5..129da81cb04cc3a3c02fa626036d8e09cb942cd2 100644
--- a/flatland/envs/malfunction_generators.py
+++ b/flatland/envs/malfunction_generators.py
@@ -42,7 +42,7 @@ def malfunction_from_file(filename) -> Tuple[MalfunctionGenerator, MalfunctionPr
         load_data = file_in.read()
     data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
     # TODO: make this better by using namedtuple in the pickle file
-    data['malfunction'] =MalfunctionProcessData._make(data['malfunction'])
+    data['malfunction'] = MalfunctionProcessData._make(data['malfunction'])
     if "malfunction" in data:
         # Mean malfunction in number of time steps
         mean_malfunction_rate = data["malfunction"].malfunction_rate
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 9025d415eef78c3b68495b1a5ae09c45652811bb..81c3a3569642c680b5b0c7246fb6d0b94685d9ed 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -19,8 +19,7 @@ from flatland.core.grid.grid_utils import IntVector2D
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
 from flatland.envs.distance_map import DistanceMap
-from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionGenerator, \
-    MalfunctionProcessData
+from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_generators import random_rail_generator, RailGenerator
 from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
@@ -160,7 +159,7 @@ class RailEnv(Environment):
         """
         super().__init__()
 
-        self.malfunction_generator, self.malfunction_process_data  = malfunction_generator_and_process_data
+        self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
         self.rail_generator: RailGenerator = rail_generator
         self.schedule_generator: ScheduleGenerator = schedule_generator
         self.rail: Optional[GridTransitionMap] = None
diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py
index e01c74830959a392a5892ab8a43336613463e937..075edc139b6786933a32c915998c0fe56cb7a76c 100644
--- a/tests/test_malfunction_generators.py
+++ b/tests/test_malfunction_generators.py
@@ -23,12 +23,17 @@ def test_malfanction_from_params():
 
     rail, rail_map = make_simple_rail2()
 
-    env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
-                  schedule_generator=random_schedule_generator(seed=10), number_of_agents=1)
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=10,
+                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
+                  )
     env.reset()
-    assert env.mean_malfunction_rate == 1000
-    assert env.min_number_of_steps_broken == 2
-    assert env.max_number_of_steps_broken == 5
+    assert env.malfunction_process_data.malfunction_rate == 1000
+    assert env.malfunction_process_data.min_duration == 2
+    assert env.malfunction_process_data.max_duration == 5
 
 
 def test_malfanction_to_and_from_file():