diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index 0f9d1e5da7cfb35ca16f4b243f9eb0013f3e824b..e131f46a76c7da7c37981d58ba43736437615f53 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -44,7 +44,7 @@ stochastic_data = {'malfunction_rate': 8000,  # Rate of malfunction occurence of
 
 
 # Custom observation builder
-TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
+TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
 
 # Different agent types (trains) with different speeds.
 speed_ration_map = {1.: 0.25,  # Fast passenger train
@@ -80,7 +80,7 @@ action_size = 5
 # We set the number of episodes we would like to train on
 if 'n_trials' not in locals():
     n_trials = 60000
-max_steps = int(3 * (env.height + env.width))
+max_steps = int(4 * 2 * (20 + env.height + env.width))
 eps = 1.
 eps_end = 0.005
 eps_decay = 0.9995
@@ -94,7 +94,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size)
-with path(torch_training.Nets, "avoider_checkpoint100.pth") as file_in:
+with path(torch_training.Nets, "navigator_checkpoint1200.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
@@ -119,7 +119,6 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             if info['action_required'][a]:
                 action = agent.act(agent_obs[a], eps=0.)
-
             else:
                 action = 0
 
@@ -130,7 +129,8 @@ for trials in range(1, n_trials + 1):
         env_renderer.render_env(show=True, show_predictions=True, show_observations=False)
         # Build agent specific observations and normalize
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+            if obs[a]:
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
 
 
         if done['__all__']:
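
Note on the hunk above (annotation, not part of the patch): the new "if obs[a]:" guard
matters because, once an agent has reached its target and been removed from the grid, its
tree observation can come back as None, and normalize_observation would fail on it. A
minimal sketch of the same idea as a standalone helper follows; the name
refresh_normalized_obs is illustrative and not part of the patch.

    from utils.observation_utils import normalize_observation

    def refresh_normalized_obs(obs, agent_obs, tree_depth, observation_radius=10):
        """Re-normalize observations in place, skipping agents whose obs is None.

        Agents that are already done/removed yield no tree observation, so their
        last normalized observation in agent_obs is left unchanged.
        """
        for handle, tree in obs.items():
            if tree:
                agent_obs[handle] = normalize_observation(
                    tree, tree_depth, observation_radius=observation_radius)
        return agent_obs

Calling refresh_normalized_obs(obs, agent_obs, tree_depth) right after env.step() keeps
the last valid observation for finished agents, which is what the guarded loop in the
patch does inline.
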
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 93c7eff6ed7ea82fbe8cb4b16cbb1102b36f0b38..0b5ccc7d112e635267041871ffbf56d0290a0391 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -22,6 +22,7 @@ from flatland.utils.rendertools import RenderTool
 from utils.observation_utils import normalize_observation
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.agent_utils import RailAgentStatus
 
 def main(argv):
     try:
@@ -39,7 +40,7 @@ def main(argv):
     # Parameters for the Environment
     x_dim = 35
     y_dim = 35
-    n_agents = 5
+    n_agents = 10
 
 
     # Use the malfunction generator to break agents from time to time
@@ -52,10 +53,10 @@ def main(argv):
     TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
 
     # Different agent types (trains) with different speeds.
-    speed_ration_map = {1.: 0.,  # Fast passenger train
-                        1. / 2.: 1.0,  # Fast freight train
-                        1. / 3.: 0.0,  # Slow commuter train
-                        1. / 4.: 0.0}  # Slow freight train
+    speed_ration_map = {1.: 0.25,  # Fast passenger train
+                        1. / 2.: 0.25,  # Fast freight train
+                        1. / 3.: 0.25,  # Slow commuter train
+                        1. / 4.: 0.25}  # Slow freight train
 
     env = RailEnv(width=x_dim,
                   height=y_dim,
@@ -88,7 +89,7 @@ def main(argv):
         n_trials = 15000
 
     # And the max number of steps we want to take per episode
-    max_steps = int(3 * (env.height + env.width))
+    max_steps = int(4 * 2 * (20 + env.height + env.width))
 
     # Define training parameters
     eps = 1.
@@ -108,7 +109,7 @@ def main(argv):
     agent_obs_buffer = [None] * env.get_num_agents()
     agent_action_buffer = [2] * env.get_num_agents()
     cummulated_reward = np.zeros(env.get_num_agents())
-    update_values = False
+    update_values = [False] * env.get_num_agents()
     # Now we load a Double dueling DQN agent
     agent = Agent(state_size, action_size)
 
@@ -128,16 +129,16 @@ def main(argv):
         env_done = 0
 
         # Run episode
-        for step in range(max_steps):
+        while True:
             # Action
             for a in range(env.get_num_agents()):
                 if info['action_required'][a]:
                     # If an action is required, we want to store the obs at that step as well as the action
-                    update_values = True
+                    update_values[a] = True
                     action = agent.act(agent_obs[a], eps=eps)
                     action_prob[action] += 1
                 else:
-                    update_values = False
+                    update_values[a] = False
                     action = 0
                 action_dict.update({a: action})
 
@@ -146,7 +147,7 @@ def main(argv):
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
                 # Only update the values when we are done or when an action was taken and thus relevant information is present
-                if update_values or done[a]:
+                if update_values[a] or done[a]:
                     agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
                                agent_obs[a], done[a])
                     cummulated_reward[a] = 0.
@@ -168,8 +169,8 @@ def main(argv):
 
         # Collect information about training
         tasks_finished = 0
-        for _idx in range(env.get_num_agents()):
-            if done[_idx] == 1:
+        for current_agent in env.agents:
+            if current_agent.status == RailAgentStatus.DONE_REMOVED:
                 tasks_finished += 1
         done_window.append(tasks_finished / max(1, env.get_num_agents()))
         scores_window.append(score / max_steps)  # save most recent score
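
Note on the last hunk (annotation, not part of the patch): counting completions via
RailAgentStatus.DONE_REMOVED instead of "done[_idx] == 1" counts only agents that actually
reached their target and were removed from the grid, whereas the per-agent done flag can
also become true when the episode simply runs out of steps. A minimal sketch of the same
bookkeeping as a helper follows; the name fraction_finished is illustrative and assumes a
flatland-rl version that exposes RailAgentStatus (as imported at the top of the patched
file).

    from flatland.envs.agent_utils import RailAgentStatus

    def fraction_finished(env):
        """Share of agents that reached their target and were removed from the grid."""
        finished = sum(1 for agent in env.agents
                       if agent.status == RailAgentStatus.DONE_REMOVED)
        return finished / max(1, env.get_num_agents())

Inside the training loop this would replace the explicit counter, e.g.
done_window.append(fraction_finished(env)).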