diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 68b0c402e0ad7b954100b05253d24eb32985867f..073c8dfc2b6bf8c3aa2dbac5e742269a1db3038e 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -5,7 +5,7 @@ from collections import deque # make sure the root path is in system path from pathlib import Path -from flatland.envs.malfunction_generators import malfunction_from_params +from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters base_dir = Path(__file__).resolve().parent.parent sys.path.append(str(base_dir))