diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py
index f8d1bc66b4f9a66c9657902aaa67ae42c9fd8c71..99082ddd764c62e16ad6b71bef7901333732cc4b 100644
--- a/flatland/envs/malfunction_generators.py
+++ b/flatland/envs/malfunction_generators.py
@@ -6,7 +6,7 @@ import numpy as np
 from numpy.random.mtrand import RandomState
 
 from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
-from flatland.envs import persistence 
+from flatland.envs import persistence
 
 Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)])
 MalfunctionParameters = NamedTuple('MalfunctionParameters',
@@ -25,7 +25,7 @@ def _malfunction_prob(rate: float) -> float:
     if rate <= 0:
         return 0.
     else:
-        return 1 - np.exp(- (1 / rate))
+        return 1 - np.exp(-rate)
 
 
 def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]:
@@ -42,7 +42,7 @@ def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[Malfun
     """
     # with open(filename, "rb") as file_in:
     #     load_data = file_in.read()
-    
+
     # if filename.endswith("mpk"):
     #     data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
     # elif filename.endswith("pkl"):
@@ -52,7 +52,7 @@ def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[Malfun
     if "malfunction" in env_dict:
         env_dict['malfunction'] = oMPD = MalfunctionProcessData._make(env_dict['malfunction'])
     else:
-        oMPD=None
+        oMPD = None
     if oMPD is not None:
         # Mean malfunction in number of time steps
         mean_malfunction_rate = oMPD.malfunction_rate
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 464e7523db805bd1a45441c0666f5d37245439c1..eaa3112708f3f0e5d255b7e454078d9a59e7ca22 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -119,9 +119,9 @@ def test_malfunction_process():
 
 
 def test_malfunction_process_statistically():
-    """Tests hat malfunctions are produced by stochastic_data!"""
+    """Tests that malfunctions are produced by stochastic_data!"""
     # Set fixed malfunction duration for this test
-    stochastic_data = MalfunctionParameters(malfunction_rate=5,  # Rate of malfunction occurence
+    stochastic_data = MalfunctionParameters(malfunction_rate=1/5,  # Rate of malfunction occurence
                                             min_duration=5,  # Minimal duration of malfunction
                                             max_duration=5  # Max duration of malfunction
                                             )
@@ -168,7 +168,7 @@ def test_malfunction_process_statistically():
 def test_malfunction_before_entry():
     """Tests that malfunctions are working properly for agents before entering the environment!"""
     # Set fixed malfunction duration for this test
-    stochastic_data = MalfunctionParameters(malfunction_rate=2,  # Rate of malfunction occurence
+    stochastic_data = MalfunctionParameters(malfunction_rate=1/2,  # Rate of malfunction occurrence
                                             min_duration=10,  # Minimal duration of malfunction
                                             max_duration=10  # Max duration of malfunction
                                             )
@@ -215,7 +215,7 @@ def test_malfunction_values_and_behavior():
 
     rail, rail_map = make_simple_rail2()
     action_dict: Dict[int, RailEnvActions] = {}
-    stochastic_data = MalfunctionParameters(malfunction_rate=0.001,  # Rate of malfunction occurence
+    stochastic_data = MalfunctionParameters(malfunction_rate=1/0.001,  # Rate of malfunction occurence
                                             min_duration=10,  # Minimal duration of malfunction
                                             max_duration=10  # Max duration of malfunction
                                             )
@@ -241,7 +241,7 @@ def test_malfunction_values_and_behavior():
 
 
 def test_initial_malfunction():
-    stochastic_data = MalfunctionParameters(malfunction_rate=1000,  # Rate of malfunction occurence
+    stochastic_data = MalfunctionParameters(malfunction_rate=1/1000,  # Rate of malfunction occurence
                                             min_duration=2,  # Minimal duration of malfunction
                                             max_duration=5  # Max duration of malfunction
                                             )
@@ -390,7 +390,7 @@ def test_initial_malfunction_stop_moving():
 
 
 def test_initial_malfunction_do_nothing():
-    stochastic_data = MalfunctionParameters(malfunction_rate=70,  # Rate of malfunction occurence
+    stochastic_data = MalfunctionParameters(malfunction_rate=1/70,  # Rate of malfunction occurence
                                             min_duration=2,  # Minimal duration of malfunction
                                             max_duration=5  # Max duration of malfunction
                                             )