diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 0dacc9c25b9bc3e01a6b3e530fa6b081cb8d91df..2e20c63b293b355afec2be33cbd9acca209039d4 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -13,15 +13,13 @@ import numpy as np
 import torch
 from torch_training.dueling_double_dqn import Agent
 
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
-
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.utils.rendertools import RenderTool
 from utils.observation_utils import normalize_observation
-
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 
 def main(argv):
     try:
@@ -37,26 +35,26 @@ def main(argv):
     np.random.seed(1)
 
     # Parameters for the Environment
-    x_dim = 40
-    y_dim = 40
-    n_agents = 4
+    x_dim = 35
+    y_dim = 35
+    n_agents = 5
 
 
     # Use the malfunction generator to break agents from time to time
-    stochastic_data = {'prop_malfunction': 0.05,  # Percentage of defective agents
-                       'malfunction_rate': 50,  # Rate of malfunction occurence
+    stochastic_data = {'prop_malfunction': 0.0,  # Fraction of defective agents
+                       'malfunction_rate': 30,  # Rate of malfunction occurrence
                        'min_duration': 3,  # Minimal duration of malfunction
                        'max_duration': 20  # Max duration of malfunction
                        }
 
     # Custom observation builder
-    TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
+    TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))  # 30-step prediction horizon
 
     # Different agent types (trains) with different speeds.
-    speed_ration_map = {1.: 0.25,  # Fast passenger train
-                        1. / 2.: 0.25,  # Fast freight train
-                        1. / 3.: 0.25,  # Slow commuter train
-                        1. / 4.: 0.25}  # Slow freight train
+    speed_ration_map = {1.: 0.0,  # Fast passenger train
+                        1. / 2.: 1.0,  # Fast freight train
+                        1. / 3.: 0.0,  # Slow commuter train
+                        1. / 4.: 0.0}  # Slow freight train
 
     env = RailEnv(width=x_dim,
                   height=y_dim,
@@ -120,8 +118,9 @@ def main(argv):
         env_renderer.reset()
         # Build agent specific observations
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
-            agent_obs_buffer[a] = agent_obs[a].copy()
+            if obs[a]:
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+                agent_obs_buffer[a] = agent_obs[a].copy()
 
         # Reset score and done
         score = 0
@@ -153,7 +152,8 @@ def main(argv):
 
                     agent_obs_buffer[a] = agent_obs[a].copy()
                     agent_action_buffer[a] = action_dict[a]
-                agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
+                if next_obs[a]:
+                    agent_obs[a] = normalize_observation(next_obs[a], tree_depth, observation_radius=10)
 
                 score += all_rewards[a] / env.get_num_agents()
 
@@ -192,7 +192,7 @@ def main(argv):
                     100 * np.mean(done_window),
                     eps, action_prob / np.sum(action_prob)))
             torch.save(agent.qnetwork_local.state_dict(),
-                       './Nets/avoider_checkpoint' + str(trials) + '.pth')
+                       './Nets/navigator_checkpoint' + str(trials) + '.pth')
             action_prob = [1] * action_size
 
     # Plot overall training progress at the end
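
Note on the new `if obs[a]:` / `if next_obs[a]:` guards: in Flatland the tree observation for an agent can come back as None (e.g. for agents that are already done or not yet active on the grid), so the guards skip normalization for those agents and keep the previously buffered observation. A minimal sketch of the same pattern, assuming normalize_observation from utils.observation_utils behaves as in the script; the helper name refresh_agent_obs is illustrative and not part of the patch:

    from utils.observation_utils import normalize_observation

    def refresh_agent_obs(obs, agent_obs, tree_depth, observation_radius=10):
        # Normalize only the observations that are actually present; agents whose
        # observation is None keep their previous normalized vector in agent_obs.
        for a, tree_obs in obs.items():
            if tree_obs:
                agent_obs[a] = normalize_observation(tree_obs, tree_depth,
                                                     observation_radius=observation_radius)
        return agent_obs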