diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 33fe287d6ef70ebda52c57e8b3a110541d28244e..3797d1cf2dcc02dbed231051811bebe8b2098c0d 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -1,7 +1,10 @@
 from flatland.envs.rail_env import *
 from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import *
-from flatland.agents.dueling_double_dqn import Agent
+from flatland.baselines.dueling_double_dqn import Agent
+from collections import deque
+import torch
+
 random.seed(1)
 np.random.seed(1)
 
@@ -36,6 +39,16 @@ handle = env.get_agent_handles()
 
 state_size = 105
 action_size = 4
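+# Training loop hyperparameters: episode budget, epsilon-greedy schedule, and rolling windows for recent scores/completions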
+n_trials = 5000
+eps = 1.
+eps_end = 0.005
+eps_decay = 0.998
+action_dict = dict()
+scores_window = deque(maxlen=100)
+done_window = deque(maxlen=100)
+scores = []
+dones_list = []
+
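+# Dueling Double DQN agent (see flatland.baselines.dueling_double_dqn) with a fully connected ("FC") network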
 agent = Agent(state_size, action_size, "FC", 0)
 
 # Example generate a rail given a manual specification,
@@ -49,27 +62,69 @@ env = RailEnv(width=6,
               number_of_agents=1,
               obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
-
 env.agents_position[0] = [1, 4]
 env.agents_target[0] = [1, 1]
 env.agents_direction[0] = 1
 # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
 env.obs_builder.reset()
 
-# TODO: delete next line
-#for i in range(4):
-#    print(env.obs_builder.distance_map[0, :, :, i])
 
-obs, all_rewards, done, _ = env.step({0:0})
-#env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+for trials in range(1, n_trials + 1):
 
-env_renderer = RenderTool(env)
-action_dict = {0: 0}
+    # Take an initial step (action 0 for agent 0) to obtain the first observation;
+    # note the environment itself is not reset between trials
+    obs, all_rewards, done, _ = env.step({0: 0})
+    # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+
+    score = 0
+    env_done = 0
+
+    # Run episode for at most 100 environment steps
+    for step in range(100):
+
+        # Epsilon-greedy action selection for each agent
+        for a in range(env.number_of_agents):
+            action = agent.act(np.array(obs[a]), eps=eps)
+            action_dict.update({a: action})
+
+        # Environment step
+        next_obs, all_rewards, done, _ = env.step(action_dict)
+
+        # Update replay buffer and train agent
+        for a in range(env.number_of_agents):
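+            # agent.step adds the transition to the replay buffer and triggers periodic learning updates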
+            agent.step(obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a])
+            score += all_rewards[a]
+
+        obs = next_obs.copy()
 
-for step in range(100):
-    obs, all_rewards, done, _ = env.step(action_dict)
-    action = agent.act(np.array(obs[0]),eps=1)
+        # Episode is complete once every agent reports done
+        if all(done[a] for a in range(env.number_of_agents)):
+            env_done = 1
+            break
+    # Epsilon decay: anneal exploration towards eps_end
+    eps = max(eps_end, eps_decay * eps)
 
-    action_dict = {0 :action}
-    print("Rewards: ", all_rewards, "  [done=", done, "]")
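+    # Track rolling statistics over the last 100 episodes (completion rate and mean score)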
+    done_window.append(env_done)
+    scores_window.append(score)  # save most recent score
+    scores.append(np.mean(scores_window))
+    dones_list.append(np.mean(done_window))
 
+    print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(
+        env.number_of_agents, trials, np.mean(scores_window), 100 * np.mean(done_window), eps), end=" ")
+    if trials % 100 == 0:
+        print('\rTraining {} Agents.\tEpisode {}\tAverage Score: {:.0f}\tDones: {:.2f}%\tEpsilon: {:.2f}'.format(
+            env.number_of_agents, trials, np.mean(scores_window), 100 * np.mean(done_window), eps))
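+        # Every 100 episodes, persist the local Q-network weights as a checkpoint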
+        torch.save(agent.qnetwork_local.state_dict(), '../flatland/baselines/Nets/avoid_checkpoint' + str(trials) + '.pth')
diff --git a/flatland/baselines/dueling_double_dqn.py b/flatland/baselines/dueling_double_dqn.py
index 3eacf4c9a66612c87f64e4ae65b7714313ffcf64..084d0a22b8c3ae39a2503e849ed346f50a7c8aa9 100644
--- a/flatland/baselines/dueling_double_dqn.py
+++ b/flatland/baselines/dueling_double_dqn.py
@@ -2,7 +2,7 @@ import numpy as np
 import random
 from collections import namedtuple, deque
 import os
-from flatland.agents.model import QNetwork, QNetwork2
+from flatland.baselines.model import QNetwork, QNetwork2
 import torch
 import torch.nn.functional as F
 import torch.optim as optim