diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth
index 14882a37a86085b137f4422b6bba75f387a2d3b5..d3081e88f97ac75641c0d94c7cf794f34d436581 100644
Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 8c30d72010d95389dd22af06971170aa9c4b4480..f16a2c43fe9bee0a4f75d5a7ae4ad0931e1bbf7e 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -5,6 +5,8 @@ import numpy as np
 import torch
 from dueling_double_dqn import Agent
 from flatland.envs.generators import complex_rail_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import DummyPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
@@ -43,10 +45,12 @@ env = RailEnv(width=10,
               height=20)
 env.load("./railway/complex_scene.pkl")
 """
-env = RailEnv(width=8,
-              height=8,
-              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=5, max_dist=99999, seed=0),
-              number_of_agents=1)
+
+env = RailEnv(width=20,
+              height=20,
+              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv()),
+              number_of_agents=10)
 env.reset(True, True)
 
 env_renderer = RenderTool(env, gl="PILSVG")
@@ -69,9 +73,9 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-#agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
 
-demo = False
+demo = True
 
 def max_lt(seq, val):
     """
@@ -143,14 +147,14 @@ for trials in range(1, n_trials + 1):
     score = 0
     env_done = 0
     # Run episode
-    for step in range(100):
+    for step in range(env.height * env.width):
         if demo:
             env_renderer.renderEnv(show=True, show_observations=False)
         # print(step)
         # Action
         for a in range(env.get_num_agents()):
             if demo:
-                eps = 1
+                eps = 0
             # action = agent.act(np.array(obs[a]), eps=eps)
             action = agent.act(agent_obs[a], eps=eps)
             action_prob[action] += 1