diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 68b0c402e0ad7b954100b05253d24eb32985867f..073c8dfc2b6bf8c3aa2dbac5e742269a1db3038e 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -5,7 +5,7 @@ from collections import deque
 # make sure the root path is in system path
 from pathlib import Path
 
-from flatland.envs.malfunction_generators import malfunction_from_params
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))