diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth
index 14882a37a86085b137f4422b6bba75f387a2d3b5..d3081e88f97ac75641c0d94c7cf794f34d436581 100644
Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 1fbe1495c0c5c11684f8a192ed3ced73d6cf2520..0c4ed8547591fe6cdd106438a4be28037baba216 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -6,6 +6,8 @@ import numpy as np
 import torch
 
 from flatland.envs.generators import complex_rail_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import DummyPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 from torch_training.dueling_double_dqn import Agent
@@ -47,10 +49,12 @@ env = RailEnv(width=10,
               height=20)
 env.load_resource('torch_training.railway', "complex_scene.pkl")
 """
-env = RailEnv(width=8,
-              height=8,
-              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=5, max_dist=99999, seed=0),
-              number_of_agents=1)
+
+env = RailEnv(width=20,
+              height=20,
+              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv()),
+              number_of_agents=10)
 env.reset(True, True)
 
 env_renderer = RenderTool(env, gl="PILSVG")
@@ -73,9 +77,9 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-# agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint1500.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))  # load the checkpoint updated in this diff
 
-demo = False
+demo = True  # demo mode: render every step and fix eps = 1
 
 
 def max_lt(seq, val):
@@ -149,7 +153,7 @@ for trials in range(1, n_trials + 1):
     score = 0
     env_done = 0
     # Run episode
-    for step in range(100):
+    for step in range(env.height * env.width):  # cap each episode at height * width steps
         if demo:
             env_renderer.renderEnv(show=True, show_observations=False)
         # print(step)
@@ -157,6 +161,7 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             if demo:
                 eps = 1
+            # action = agent.act(np.array(obs[a]), eps=eps)
             action = agent.act(agent_obs[a], eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
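
The net effect of the Python-side changes is a multi-agent demo configuration: a 20x20 grid with 10 agents, depth-2 tree observations backed by DummyPredictorForRailEnv, the retrained avoid_checkpoint15000.pth weights loaded into the local Q-network, and an episode cap of height * width steps. The sketch below reconstructs that setup as a standalone script; it is a minimal sketch assuming the flatland 0.x API imported in the diff. The values of state_size and action_size are placeholder assumptions (the full script derives them elsewhere), and the observation normalization performed in the full script is elided.

import numpy as np
import torch

from flatland.envs.generators import complex_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import DummyPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from torch_training.dueling_double_dqn import Agent

# Demo environment as configured in the diff: 20x20 grid, 10 agents,
# depth-2 tree observations augmented with a dummy predictor.
env = RailEnv(width=20,
              height=20,
              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1,
                                                    min_dist=8, max_dist=99999, seed=0),
              obs_builder_object=TreeObsForRailEnv(max_depth=2,
                                                   predictor=DummyPredictorForRailEnv()),
              number_of_agents=10)
env_renderer = RenderTool(env, gl="PILSVG")

state_size = 147   # placeholder: the full script derives this from the tree-obs layout
action_size = 5    # assumption: five discrete rail actions, matching action_prob above
agent = Agent(state_size, action_size, "FC", 0)
agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))

# Demo rollout: regenerate the scene, render every step, and cap the episode
# at height * width steps so larger grids get proportionally more time.
obs = env.reset(True, True)  # assuming reset returns the initial observation dict
for step in range(env.height * env.width):
    env_renderer.renderEnv(show=True, show_observations=False)
    action_dict = {}
    for a in range(env.get_num_agents()):
        # eps=1 mirrors the demo branch of the diff; normalization is elided.
        action_dict[a] = agent.act(np.array(obs[a]), eps=1)
    obs, all_rewards, done, _ = env.step(action_dict)
    if done['__all__']:
        break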