diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 216ddc66c439027fdf4a9603f5dc7a4091c246bf..ba9f144ac2edbebdd24706129bb339f1614b3c67 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -7,7 +7,7 @@ import random
 import torch
 from dueling_double_dqn import Agent
 
-import torch_training
+import torch_training.Nets
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 095f20b4a91c0cab1e97fdec55cf419d3cd77be7..1857b676f0fd4fd65f90ed6e37930988519c6cda 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from dueling_double_dqn import Agent
 
-import torch_training
+import torch_training.Nets
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv