diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index 18664437ebef0dbf4261cfdb3ba692dd5fab7505..e39a16699eb123b36bcc925bbd9e255988a9bb44 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -43,7 +43,7 @@ stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
                    }
 
 # Custom observation builder
-TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
+TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
 
 # Different agent types (trains) with different speeds.
 speed_ration_map = {1.: 0.25,  # Fast passenger train
@@ -79,7 +79,7 @@ action_size = 5
 # We set the number of episodes we would like to train on
 if 'n_trials' not in locals():
     n_trials = 60000
-max_steps = int(3 * (env.height + env.width))
+max_steps = int(4 * 2 * (20 + env.height + env.width))
 eps = 1.
 eps_end = 0.005
 eps_decay = 0.9995
@@ -93,7 +93,7 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size)
-with path(torch_training.Nets, "avoider_checkpoint100.pth") as file_in:
+with path(torch_training.Nets, "navigator_checkpoint1200.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
@@ -118,7 +118,6 @@ for trials in range(1, n_trials + 1):
         for a in range(env.get_num_agents()):
             if info['action_required'][a]:
                 action = agent.act(agent_obs[a], eps=0.)
-
             else:
                 action = 0
 
@@ -129,7 +128,8 @@ for trials in range(1, n_trials + 1):
         env_renderer.render_env(show=True, show_predictions=True, show_observations=False)
         # Build agent specific observations and normalize
         for a in range(env.get_num_agents()):
-            agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
+            if obs[a]:  # skip agents with no observation (e.g. agents already done and removed)
+                agent_obs[a] = normalize_observation(obs[a], tree_depth, observation_radius=10)
 
 
         if done['__all__']:
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 2e20c63b293b355afec2be33cbd9acca209039d4..0e928d0751395e15eb83ab67266880850d4ffe9e 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -20,6 +20,7 @@ from flatland.utils.rendertools import RenderTool
 from utils.observation_utils import normalize_observation
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.agent_utils import RailAgentStatus
 
 def main(argv):
     try:
@@ -37,24 +38,24 @@ def main(argv):
     # Parameters for the Environment
     x_dim = 35
     y_dim = 35
-    n_agents = 5
+    n_agents = 10
 
 
     # Use a the malfunction generator to break agents from time to time
-    stochastic_data = {'prop_malfunction': 0.0,  # Percentage of defective agents
-                       'malfunction_rate': 30,  # Rate of malfunction occurence
-                       'min_duration': 3,  # Minimal duration of malfunction
-                       'max_duration': 20  # Max duration of malfunction
+    stochastic_data = {'prop_malfunction': 0.05,  # Percentage of defective agents
+                       'malfunction_rate': 100,  # Rate of malfunction occurrence
+                       'min_duration': 20,  # Minimal duration of malfunction
+                       'max_duration': 50  # Max duration of malfunction
                        }
 
     # Custom observation builder
     TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
 
     # Different agent types (trains) with different speeds.
-    speed_ration_map = {1.: 0.,  # Fast passenger train
-                        1. / 2.: 1.0,  # Fast freight train
-                        1. / 3.: 0.0,  # Slow commuter train
-                        1. / 4.: 0.0}  # Slow freight train
+    speed_ration_map = {1.: 0.25,  # Fast passenger train
+                        1. / 2.: 0.25,  # Fast freight train
+                        1. / 3.: 0.25,  # Slow commuter train
+                        1. / 4.: 0.25}  # Slow freight train
 
     env = RailEnv(width=x_dim,
                   height=y_dim,
@@ -87,7 +88,7 @@ def main(argv):
         n_trials = 15000
 
     # And the max number of steps we want to take per episode
-    max_steps = int(3 * (env.height + env.width))
+    max_steps = int(4 * 2 * (20 + env.height + env.width))
 
     # Define training parameters
     eps = 1.
@@ -107,7 +108,7 @@ def main(argv):
     agent_obs_buffer = [None] * env.get_num_agents()
     agent_action_buffer = [2] * env.get_num_agents()
     cummulated_reward = np.zeros(env.get_num_agents())
-    update_values = False
+    update_values = [False] * env.get_num_agents()  # per-agent flag: was an action taken this step?
     # Now we load a Double dueling DQN agent
     agent = Agent(state_size, action_size)
 
@@ -127,16 +128,16 @@ def main(argv):
         env_done = 0
 
         # Run episode
-        for step in range(max_steps):
+        while True:
             # Action
             for a in range(env.get_num_agents()):
                 if info['action_required'][a]:
                     # If an action is require, we want to store the obs a that step as well as the action
-                    update_values = True
+                    update_values[a] = True
                     action = agent.act(agent_obs[a], eps=eps)
                     action_prob[action] += 1
                 else:
-                    update_values = False
+                    update_values[a] = False
                     action = 0
                 action_dict.update({a: action})
 
@@ -145,7 +146,7 @@ def main(argv):
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
                 # Only update the values when we are done or when an action was taken and thus relevant information is present
-                if update_values or done[a]:
+                if update_values[a] or done[a]:
                     agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
                                agent_obs[a], done[a])
                     cummulated_reward[a] = 0.
@@ -167,8 +168,8 @@ def main(argv):
 
         # Collection information about training
         tasks_finished = 0
-        for _idx in range(env.get_num_agents()):
-            if done[_idx] == 1:
+        for current_agent in env.agents:
+            if current_agent.status == RailAgentStatus.DONE_REMOVED:
                 tasks_finished += 1
         done_window.append(tasks_finished / max(1, env.get_num_agents()))
         scores_window.append(score / max_steps)  # save most recent score