diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index e131f46a76c7da7c37981d58ba43736437615f53..cb8fa32397f668288e7dd689ea68676ed1cc9592 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -3,7 +3,7 @@ from collections import deque
 
 import numpy as np
 import torch
-from flatland.envs.malfunction_generators import malfunction_from_params
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -37,10 +37,11 @@ n_agents = 10
 observation_builder = TreeObsForRailEnv(max_depth=2)
 
 # Use a the malfunction generator to break agents from time to time
-stochastic_data = {'malfunction_rate': 8000,  # Rate of malfunction occurence of single agent
-                   'min_duration': 15,  # Minimal duration of malfunction
-                   'max_duration': 50  # Max duration of malfunction
-                   }
+stochastic_data = MalfunctionParameters(malfunction_rate=10000,  # Rate of malfunction occurence
+                                        min_duration=15,  # Minimal duration of malfunction
+                                        max_duration=50  # Max duration of malfunction
+                                        )
+
 
 
 # Custom observation builder