diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 3007ce59bed2e75adc0672511e397c39eea1e2cb..3152ecbf865e7be5c697b13e5b8e10f0948922dd 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -7,7 +7,6 @@ import torch
 from dueling_double_dqn import Agent
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
@@ -16,66 +15,52 @@ from utils.observation_utils import norm_obs_clip, split_tree
 random.seed(1)
 np.random.seed(1)
 
-# Example generate a rail given a manual specification,
-# a map of tuples (cell_type, rotation)
-transition_probability = [15,  # empty cell - Case 0
-                          5,  # Case 1 - straight
-                          5,  # Case 2 - simple switch
-                          1,  # Case 3 - diamond crossing
-                          1,  # Case 4 - single slip
-                          1,  # Case 5 - double slip
-                          1,  # Case 6 - symmetrical
-                          0,  # Case 7 - dead end
-                          1,  # Case 1b (8)  - simple turn right
-                          1,  # Case 1c (9)  - simple turn left
-                          1]  # Case 2b (10) - simple switch mirrored
-
-# Example generate a random rail
-"""
-env = RailEnv(width=20,
-              height=20,
-              rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
-              number_of_agents=1)
-
-env = RailEnv(width=15,
-              height=15,
-              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
-              number_of_agents=1)
-
-
-env = RailEnv(width=10,
-              height=20, obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
-env.load("./railway/complex_scene.pkl")
-file_load = True
-"""
-x_dim = np.random.randint(8, 20)
-y_dim = np.random.randint(8, 20)
-n_agents = np.random.randint(3, 8)
-n_goals = n_agents + np.random.randint(0, 3)
-min_dist = int(0.75 * min(x_dim, y_dim))
+# Parameters for the Environment
+x_dim = 10
+y_dim = 10
+n_agents = 1
+n_goals = 5
+min_dist = 5
+
+# We are training an Agent using the Tree Observation with depth 2
+observation_builder = TreeObsForRailEnv(max_depth=2)
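+# Each node of the tree observation branches into up to 4 children (one per possible
+# transition direction), which is why the node count below grows as powers of 4.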
+
+# Set up the environment
 env = RailEnv(width=x_dim,
               height=y_dim,
               rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
                                                     max_dist=99999,
                                                     seed=0),
-              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
+              obs_builder_object=observation_builder,
               number_of_agents=n_agents)
 env.reset(True, True)
-file_load = False
-"""
 
-"""
-observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
-env_renderer = RenderTool(env, gl="PILSVG",)
-handle = env.get_agent_handles()
+# After training we want to render the results, so we also set up a renderer
+env_renderer = RenderTool(env, gl="PILSVG")
+
+# Given the depth of the tree observation and the number of features per node we get the following state_size
 features_per_node = 9
-state_size = features_per_node * 85 * 2
+tree_depth = 2
+nr_nodes = 0
+for i in range(tree_depth + 1):
+    nr_nodes += np.power(4, i)
+state_size = features_per_node * nr_nodes
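+# (With max_depth = 2 this gives nr_nodes = 1 + 4 + 16 = 21 and state_size = 9 * 21 = 189.)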
+
+# The action space of flatland is 5 discrete actions
 action_size = 5
-n_trials = 30000
+
+# We set the number of episodes we would like to train on
+n_trials = 6000
+
+# And the max number of steps we want to take per episode
 max_steps = int(3 * (env.height + env.width))
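+# (For the 10x10 environment above this evaluates to 3 * (10 + 10) = 60 steps.)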
+
+# Define training parameters
 eps = 1.
 eps_end = 0.005
-eps_decay = 0.9995
+eps_decay = 0.998
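+# (With this decay epsilon reaches eps_end after roughly log(0.005) / log(0.998), i.e. about 2650 episodes.)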
+
+# And some variables to keep track of the progress
 action_dict = dict()
 final_action_dict = dict()
 scores_window = deque(maxlen=100)
@@ -86,112 +71,83 @@ dones_list = []
 action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
-agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth'))
 
-demo = True
-record_images = False
+# Now we load a Dueling Double DQN agent
+agent = Agent(state_size, action_size, "FC", 0)
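+# ("FC" presumably selects the fully connected network variant in dueling_double_dqn and 0 is the random seed.)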
 
+Training = True
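+# (With Training = False the agent acts greedily, the environment is rendered and no learning updates are performed.)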
 
 for trials in range(1, n_trials + 1):
 
-    if trials % 50 == 0 and not demo:
-        x_dim = np.random.randint(8, 20)
-        y_dim = np.random.randint(8, 20)
-        n_agents = np.random.randint(3, 8)
-        n_goals = n_agents + np.random.randint(0, 3)
-        min_dist = int(0.75 * min(x_dim, y_dim))
-        env = RailEnv(width=x_dim,
-                      height=y_dim,
-                      rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
-                                                            max_dist=99999,
-                                                            seed=0),
-                      obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
-                      number_of_agents=n_agents)
-        env.reset(True, True)
-        max_steps = int(3 * (env.height + env.width))
-        agent_obs = [None] * env.get_num_agents()
-        agent_next_obs = [None] * env.get_num_agents()
     # Reset environment
-    if file_load:
-        obs = env.reset(False, False)
-    else:
-        obs = env.reset(True, True)
-    if demo:
+    obs = env.reset(True, True)
+    if not Training:
         env_renderer.set_new_rail()
-    obs_original = obs.copy()
-    final_obs = obs.copy()
-    final_obs_next = obs.copy()
+
+    # Split the observation tree into its parts and normalize the observation using the utility functions.
+    # Build agent-specific local observations
     for a in range(env.get_num_agents()):
-        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
-                                                current_depth=0)
-        data = norm_obs_clip(data)
-        distance = norm_obs_clip(distance)
+        rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]),
+                                                          num_features_per_node=features_per_node,
+                                                          current_depth=0)
+        rail_data = norm_obs_clip(rail_data)
+        distance_data = norm_obs_clip(distance_data)
         agent_data = np.clip(agent_data, -1, 1)
-        obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
-        agent_data = env.agents[a]
-        speed = 1 #np.random.randint(1,5)
-        agent_data.speed_data['speed'] = 1. / speed
-
-    for i in range(2):
-        time_obs.append(obs)
-    # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
-    for a in range(env.get_num_agents()):
-        agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
-
+        agent_obs[a] = np.concatenate((np.concatenate((rail_data, distance_data)), agent_data))
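+        # Each agent_obs[a] is now a flat vector with state_size entries.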
 
+    # Reset score and done
     score = 0
     env_done = 0
+
     # Run episode
     for step in range(max_steps):
-        if demo:
+
+        # Only render when not training
+        if not Training:
             env_renderer.renderEnv(show=True, show_observations=True)
-            # observation_helper.util_print_obs_subtree(obs_original[0])
-            if record_images:
-                env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(step))
-        # print(step)
-        # Action
+
+        # Choose the actions
         for a in range(env.get_num_agents()):
-            if demo:
+            if not Training:
                 eps = 0
-            # action = agent.act(np.array(obs[a]), eps=eps)
+
             action = agent.act(agent_obs[a], eps=eps)
-            action_prob[action] += 1
             action_dict.update({a: action})
-        # Environment step
 
+            # Count the number of actions taken for statistics
+            action_prob[action] += 1
+
+        # Environment step
         next_obs, all_rewards, done, _ = env.step(action_dict)
-        # print(all_rewards,action)
-        obs_original = next_obs.copy()
+
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=features_per_node,
-                                                    current_depth=0)
-            data = norm_obs_clip(data)
-            distance = norm_obs_clip(distance)
+            rail_data, distance_data, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                              num_features_per_node=features_per_node,
+                                                              current_depth=0)
+            rail_data = norm_obs_clip(rail_data)
+            distance_data = norm_obs_clip(distance_data)
             agent_data = np.clip(agent_data, -1, 1)
-            next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
-        time_obs.append(next_obs)
+            agent_next_obs[a] = np.concatenate((np.concatenate((rail_data, distance_data)), agent_data))
 
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
-            agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
-            if done[a]:
-                final_obs[a] = agent_obs[a].copy()
-                final_obs_next[a] = agent_next_obs[a].copy()
-                final_action_dict.update({a: action_dict[a]})
-            if not demo and not done[a]:
+
+            # Remember and train agent
+            if Training:
                 agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
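+                # agent.step presumably stores the transition in the replay buffer and
+                # trains the network on sampled mini-batches.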
+
+            # Update the current score
             score += all_rewards[a] / env.get_num_agents()
 
         agent_obs = agent_next_obs.copy()
         if done['__all__']:
             env_done = 1
-            for a in range(env.get_num_agents()):
-                agent.step(final_obs[a], final_action_dict[a], all_rewards[a], final_obs_next[a], done[a])
             break
+
     # Epsilon decay
     eps = max(eps_end, eps_decay * eps)  # decrease epsilon
 
+    # Store the information about training progress
     done_window.append(env_done)
     scores_window.append(score / max_steps)  # save most recent score
     scores.append(np.mean(scores_window))
@@ -200,22 +156,68 @@ for trials in range(1, n_trials + 1):
     print(
         '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
             env.get_num_agents(), x_dim, y_dim,
-              trials,
-              np.mean(scores_window),
-              100 * np.mean(done_window),
-              eps, action_prob / np.sum(action_prob)), end=" ")
+            trials,
+            np.mean(scores_window),
+            100 * np.mean(done_window),
+            eps, action_prob / np.sum(action_prob)), end=" ")
 
     if trials % 100 == 0:
         print(
-            '\rTraining {} Agents.\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
-                env.get_num_agents(),
+            '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
+                env.get_num_agents(), x_dim, y_dim,
                 trials,
                 np.mean(scores_window),
                 100 * np.mean(done_window),
-                eps,
-                action_prob / np.sum(action_prob)))
+                eps, action_prob / np.sum(action_prob)))
         torch.save(agent.qnetwork_local.state_dict(),
-                   './Nets/avoid_checkpoint' + str(trials) + '.pth')
+                   './Nets/navigator_checkpoint' + str(trials) + '.pth')
         action_prob = [1] * action_size
+
+# Render the trained agent
+
+# Reset environment
+obs = env.reset(True, True)
+env_renderer.set_new_rail()
+
+# Split the observation tree into its parts and normalize the observation using the utility functions.
+# Build agent-specific local observations
+for a in range(env.get_num_agents()):
+    rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
+                                                      current_depth=0)
+    rail_data = norm_obs_clip(rail_data)
+    distance_data = norm_obs_clip(distance_data)
+    agent_data = np.clip(agent_data, -1, 1)
+    agent_obs[a] = np.concatenate((np.concatenate((rail_data, distance_data)), agent_data))
+
+# Reset score and done
+score = 0
+env_done = 0
+
+# Run episode
+for step in range(max_steps):
+    env_renderer.renderEnv(show=True, show_observations=False)
+
+    # Choose the actions
+    for a in range(env.get_num_agents()):
+        eps = 0
+        action = agent.act(agent_obs[a], eps=eps)
+        action_dict.update({a: action})
+
+    # Environment step
+    next_obs, all_rewards, done, _ = env.step(action_dict)
+
+    for a in range(env.get_num_agents()):
+        rail_data, distance_data, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                          num_features_per_node=features_per_node,
+                                                          current_depth=0)
+        rail_data = norm_obs_clip(rail_data)
+        distance_data = norm_obs_clip(distance_data)
+        agent_data = np.clip(agent_data, -1, 1)
+        agent_next_obs[a] = np.concatenate((np.concatenate((rail_data, distance_data)), agent_data))
+
+    agent_obs = agent_next_obs.copy()
+    if done['__all__']:
+        break
+# Plot overall training progress at the end (each entry of scores is the mean score over the last 100 episodes)
 plt.plot(scores)
 plt.show()