diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index f2458c20c2e47ea56e577f229f4221b1bfe4e195..18664437ebef0dbf4261cfdb3ba692dd5fab7505 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -30,7 +30,7 @@ y_dim = env.height
 # Parameters for the Environment
 x_dim = 25
 y_dim = 25
-n_agents = 1
+n_agents = 10
 
 # We are training an Agent using the Tree Observation with depth 2
 observation_builder = TreeObsForRailEnv(max_depth=2)
@@ -43,13 +43,13 @@ stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                    }
 
 # Custom observation builder
-TreeObservation = TreeObsForRailEnv(max_depth=2)
+TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 
 # Different agent types (trains) with different speeds.
-speed_ration_map = {1.: 1.,  # Fast passenger train
-                    1. / 2.: 0.0,  # Fast freight train
-                    1. / 3.: 0.0,  # Slow commuter train
-                    1. / 4.: 0.0}  # Slow freight train
+speed_ration_map = {1.: 0.25,  # Fast passenger train
+                    1. / 2.: 0.25,  # Fast freight train
+                    1. / 3.: 0.25,  # Slow commuter train
+                    1. / 4.: 0.25}  # Slow freight train
 
 env = RailEnv(width=x_dim,
               height=y_dim,
@@ -93,7 +93,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size)
-with path(torch_training.Nets, "avoider_checkpoint1000.pth") as file_in:
+with path(torch_training.Nets, "avoider_checkpoint100.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index ed20ea69c2042a7ab51722df6ba553aae741d2c5..0dacc9c25b9bc3e01a6b3e530fa6b081cb8d91df 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -14,6 +14,8 @@ import torch
 from torch_training.dueling_double_dqn import Agent
 
 from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
@@ -48,7 +50,7 @@ def main(argv):
                        }
 
     # Custom observation builder
-    TreeObservation = TreeObsForRailEnv(max_depth=2)
+    TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
 
     # Different agent types (trains) with different speeds.
     speed_ration_map = {1.: 0.25,  # Fast passenger train
diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index d599bcf5ad5306ee70670bc672f20be45dbb40dd..62cfa494072d2ac3e8d164beb54f06ef202544d0 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -4,7 +4,6 @@ from collections import deque
 import numpy as np
 import torch
 from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
@@ -67,7 +66,6 @@ env = RailEnv(width=x_dim,
               obs_builder_object=TreeObservation)
 env.reset(True, True)
 
-observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
 env_renderer = RenderTool(env, gl="PILSVG", )
 num_features_per_node = env.obs_builder.observation_dim
 
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 8e467968bcdad4a7b6ce39b2d70d096e32d332dd..b0942ee84a63f6b0d97212fb53c59377f8e7d285 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -13,8 +13,8 @@ import numpy as np
 import torch
 from torch_training.dueling_double_dqn import Agent
 
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.rail_env import RailEnv
+flatland.envs.rail_env
+import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool