diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 2a4af22ad148814f397dd32b7e96f3f6d666c70f..1d12b53ab7d1ee248601c6aeaa782b75c6b5b15d 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -43,13 +43,21 @@ env = RailEnv(width=15,
 env = RailEnv(width=10,
               height=20)
 env.load("./railway/complex_scene.pkl")
+
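+# random 15x15 scene (overrides the envs above): 10 start/goal pairs, 10 extra rail pieces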
+env = RailEnv(width=15,
+              height=15,
+              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
+              number_of_agents=1)
 env.reset(False, False)
 
 env_renderer = RenderTool(env, gl="PILSVG")
 handle = env.get_agent_handles()
 
-state_size = 105 * 2
-action_size = 4
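+# 147 = 21 tree nodes (depth 2, 4 children each) * 7 features; doubled for the two stacked frames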
+state_size = 147 * 2
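+# flatland's 5 discrete actions: do nothing, left, forward, right, stop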
+action_size = 5
 n_trials = 15000
 eps = 1.
 eps_end = 0.005
@@ -61,13 +69,15 @@ done_window = deque(maxlen=100)
 time_obs = deque(maxlen=2)
 scores = []
 dones_list = []
-action_prob = [0] * 4
+action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint10400.pth'))
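+# uncomment to resume from a previously saved checkpoint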
+# agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
 
-demo = True
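+# demo = True renders the environment while episodes run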
+demo = False
 
 
 def max_lt(seq, val):
@@ -119,18 +129,22 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
 for trials in range(1, n_trials + 1):
 
     # Reset environment
-    obs = env.reset(False,False)
-
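+    # regenerate the rail and re-place agents so every trial sees a fresh scene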
+    obs = env.reset(True, True)
+    if demo:
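+        # notify the renderer that a new rail layout exists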
+        env_renderer.set_new_rail()
     final_obs = obs.copy()
     final_obs_next = obs.copy()
 
     for a in range(env.get_num_agents()):
-        data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0)
-
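+        # split the tree observation into node features, distances and per-agent data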
+        data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(obs[a]),
+                                                                 num_features_per_node=7, current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
-        obs[a] = np.concatenate((data, distance))
 
+        obs[a] = np.concatenate((data, distance, agent_data))
     for i in range(2):
         time_obs.append(obs)
     # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
@@ -142,25 +156,26 @@ for trials in range(1, n_trials + 1):
     # Run episode
     for step in range(360):
         if demo:
             env_renderer.renderEnv(show=True,show_observations=False)
         # print(step)
         # Action
         for a in range(env.get_num_agents()):
             if demo:
-                eps = 0
+                eps = 1
             # action = agent.act(np.array(obs[a]), eps=eps)
-            action = agent.act(agent_obs[a])
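+            # epsilon-greedy action on the stacked two-frame observation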
+            action = agent.act(agent_obs[a], eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
 
         # Environment step
         next_obs, all_rewards, done, _ = env.step(action_dict)
         for a in range(env.get_num_agents()):
-            data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5,
-                                                        current_depth=0)
+            data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(next_obs[a]),
+                                                                     num_features_per_node=7, current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
-            next_obs[a] = np.concatenate((data, distance))
+            next_obs[a] = np.concatenate((data, distance, agent_data))
 
         time_obs.append(next_obs)