diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth
index d3081e88f97ac75641c0d94c7cf794f34d436581..77c2ad680533b909a245d7d089402520eb55efcb 100644
Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 96593864a31cf45864b3cf8f52c29a0d5a241ef0..4b82d637b400acbb5927ac0ee5a683f5f9c04fa4 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -1,12 +1,13 @@
 import random
 from collections import deque
 
+import matplotlib.pyplot as plt
 import numpy as np
 import torch
 from dueling_double_dqn import Agent
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import DummyPredictorForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 
@@ -46,10 +47,10 @@ env = RailEnv(width=10,
 env.load("./railway/complex_scene.pkl")
 """
 
-env = RailEnv(width=8,
-              height=8,
-              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=1, min_dist=4, max_dist=99999, seed=0),
-              obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv()),
+env = RailEnv(width=12,
+              height=12,
+              rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=5, min_dist=10, max_dist=99999, seed=0),
+              obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
               number_of_agents=3)
 
 env.reset(True, True)
@@ -59,8 +60,8 @@ handle = env.get_agent_handles()
 
 state_size = 168 * 2
 action_size = 5
-n_trials = 15000
-max_steps = int(1.5 * (env.height + env.width))
+n_trials = 20000
+max_steps = int(3 * (env.height + env.width))
 eps = 1.
 eps_end = 0.005
 eps_decay = 0.9995
@@ -75,7 +76,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-# agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
 
 demo = False
 
@@ -220,3 +221,5 @@ for trials in range(1, n_trials + 1):
         torch.save(agent.qnetwork_local.state_dict(),
                    './Nets/avoid_checkpoint' + str(trials) + '.pth')
         action_prob = [1] * action_size
+plt.plot(scores)
+plt.show()