diff --git a/parameters.txt b/parameters.txt
index 80ad8c2457f753f0ce7cbc82a9c06b561853b2f2..173c5a784f389c6b52af59fa2ef3e4fa4fedcc13 100644
--- a/parameters.txt
+++ b/parameters.txt
@@ -1,4 +1,4 @@
-{'Test_0':[100,100,5,3],
+{'Test_0':[20,20,20,3],
 'Test_1':[10,10,3,4321],
 'Test_2':[10,10,5,123],
 'Test_3':[50,50,5,21],
diff --git a/score_test.py b/score_test.py
index bf309d2eb5a366807962701a5b8f166a4ddb77de..ff4a94c5e1b82c90eec0c5bf129bad496046e595 100644
--- a/score_test.py
+++ b/score_test.py
@@ -1,9 +1,7 @@
 import time
 
 import numpy as np
-import torch
 
-from torch_training.dueling_double_dqn import Agent
 from utils.misc_utils import RandomAgent, run_test
 
 with open('parameters.txt','r') as inf:
@@ -23,8 +21,8 @@ test_results = []
 test_times = []
 test_dones = []
 # Load agent
-agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./torch_training/Nets/avoid_checkpoint1700.pth'))
+# agent = Agent(state_size, action_size, "FC", 0)
+# agent.qnetwork_local.load_state_dict(torch.load('./torch_training/Nets/avoid_checkpoint1700.pth'))
 agent = RandomAgent(state_size, action_size)
 start_time_scoring = time.time()
 test_idx = 0
diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth
index ba488f1a074af3f62ad87d54f61a251e8292ae50..833de82752968a507b8e5397e76f55c26558d946 100644
Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and b/torch_training/Nets/avoid_checkpoint15000.pth differ
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb11fb1b183d7041e35186f40693b378c80eec68
--- /dev/null
+++ b/torch_training/multi_agent_training.py
@@ -0,0 +1,194 @@
+import random
+from collections import deque
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from dueling_double_dqn import Agent
+from flatland.envs.generators import complex_rail_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool
+
+from utils.observation_utils import norm_obs_clip, split_tree
+
+random.seed(1)
+np.random.seed(1)
+
+"""
+env = RailEnv(width=10,
+              height=20, obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
+env.load("./railway/complex_scene.pkl")
+file_load = True
+"""
+
+x_dim = np.random.randint(8, 20)
+y_dim = np.random.randint(8, 20)
+n_agents = np.random.randint(3, 8)
+n_goals = n_agents + np.random.randint(0, 3)
+min_dist = int(0.75 * min(x_dim, y_dim))
+env = RailEnv(width=x_dim,
+              height=y_dim,
+              rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
+                                                    max_dist=99999,
+                                                    seed=0),
+              obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
+              number_of_agents=n_agents)
+env.reset(True, True)
+file_load = False
+"""
+
+"""
+observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
+env_renderer = RenderTool(env, gl="PILSVG", )
+handle = env.get_agent_handles()
+features_per_node = 9
+state_size = features_per_node * 85 * 2
+action_size = 5
+n_trials = 30000
+max_steps = int(3 * (env.height + env.width))
+eps = 1.
+eps_end = 0.005
+eps_decay = 0.9995
+action_dict = dict()
+final_action_dict = dict()
+scores_window = deque(maxlen=100)
+done_window = deque(maxlen=100)
+time_obs = deque(maxlen=2)
+scores = []
+dones_list = []
+action_prob = [0] * action_size
+agent_obs = [None] * env.get_num_agents()
+agent_next_obs = [None] * env.get_num_agents()
+agent = Agent(state_size, action_size, "FC", 0)
+agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth'))
+
+demo = True
+record_images = False
+
+for trials in range(1, n_trials + 1):
+
+    if trials % 50 == 0 and not demo:
+        x_dim = np.random.randint(8, 20)
+        y_dim = np.random.randint(8, 20)
+        n_agents = np.random.randint(3, 8)
+        n_goals = n_agents + np.random.randint(0, 3)
+        min_dist = int(0.75 * min(x_dim, y_dim))
+        env = RailEnv(width=x_dim,
+                      height=y_dim,
+                      rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
+                                                            max_dist=99999,
+                                                            seed=0),
+                      obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
+                      number_of_agents=n_agents)
+        env.reset(True, True)
+        max_steps = int(3 * (env.height + env.width))
+        agent_obs = [None] * env.get_num_agents()
+        agent_next_obs = [None] * env.get_num_agents()
+    # Reset environment
+    if file_load:
+        obs = env.reset(False, False)
+    else:
+        obs = env.reset(True, True)
+    if demo:
+        env_renderer.set_new_rail()
+    obs_original = obs.copy()
+    final_obs = obs.copy()
+    final_obs_next = obs.copy()
+    for a in range(env.get_num_agents()):
+        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
+                                                current_depth=0)
+        data = norm_obs_clip(data)
+        distance = norm_obs_clip(distance)
+        agent_data = np.clip(agent_data, -1, 1)
+        obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+        agent_data = env.agents[a]
+        speed = 1  # np.random.randint(1,5)
+        agent_data.speed_data['speed'] = 1. / speed
+
+    for i in range(2):
+        time_obs.append(obs)
+    # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
+    for a in range(env.get_num_agents()):
+        agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+
+    score = 0
+    env_done = 0
+    # Run episode
+    for step in range(max_steps):
+        if demo:
+            env_renderer.renderEnv(show=True, show_observations=True)
+            # observation_helper.util_print_obs_subtree(obs_original[0])
+            if record_images:
+                env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(step))
+        # print(step)
+        # Action
+        for a in range(env.get_num_agents()):
+            if demo:
+                eps = 0
+            # action = agent.act(np.array(obs[a]), eps=eps)
+            action = agent.act(agent_obs[a], eps=eps)
+            action_prob[action] += 1
+            action_dict.update({a: action})
+        # Environment step
+
+        next_obs, all_rewards, done, _ = env.step(action_dict)
+        # print(all_rewards,action)
+        obs_original = next_obs.copy()
+        for a in range(env.get_num_agents()):
+            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=features_per_node,
+                                                    current_depth=0)
+            data = norm_obs_clip(data)
+            distance = norm_obs_clip(distance)
+            agent_data = np.clip(agent_data, -1, 1)
+            next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+        time_obs.append(next_obs)
+
+        # Update replay buffer and train agent
+        for a in range(env.get_num_agents()):
+            agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+            if done[a]:
+                final_obs[a] = agent_obs[a].copy()
+                final_obs_next[a] = agent_next_obs[a].copy()
+                final_action_dict.update({a: action_dict[a]})
+            if not demo and not done[a]:
+                agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
+            score += all_rewards[a] / env.get_num_agents()
+
+        agent_obs = agent_next_obs.copy()
+        if done['__all__']:
+            env_done = 1
+            for a in range(env.get_num_agents()):
+                agent.step(final_obs[a], final_action_dict[a], all_rewards[a], final_obs_next[a], done[a])
+            break
+    # Epsilon decay
+    eps = max(eps_end, eps_decay * eps)  # decrease epsilon
+
+    done_window.append(env_done)
+    scores_window.append(score / max_steps)  # save most recent score
+    scores.append(np.mean(scores_window))
+    dones_list.append((np.mean(done_window)))
+
+    print(
+        '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
+            env.get_num_agents(), x_dim, y_dim,
+            trials,
+            np.mean(scores_window),
+            100 * np.mean(done_window),
+            eps, action_prob / np.sum(action_prob)), end=" ")
+
+    if trials % 100 == 0:
+        print(
+            '\rTraining {} Agents.\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
+                env.get_num_agents(),
+                trials,
+                np.mean(scores_window),
+                100 * np.mean(done_window),
+                eps,
+                action_prob / np.sum(action_prob)))
+        torch.save(agent.qnetwork_local.state_dict(),
+                   './Nets/avoid_checkpoint' + str(trials) + '.pth')
+        action_prob = [1] * action_size
+plt.plot(scores)
+plt.show()
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 0356b531a5c9518820754e04f8e16a450256f5d2..3007ce59bed2e75adc0672511e397c39eea1e2cb 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -44,7 +44,7 @@ env = RailEnv(width=15,
 
 
 env = RailEnv(width=10,
-              height=20, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
+              height=20, obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
 env.load("./railway/complex_scene.pkl")
 file_load = True
 """
@@ -62,8 +62,8 @@ env = RailEnv(width=x_dim,
               number_of_agents=n_agents)
 env.reset(True, True)
 file_load = False
-
 """
+
 """
 observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
 env_renderer = RenderTool(env, gl="PILSVG",)
@@ -87,9 +87,9 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-# agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth'))
+agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth'))
 
-demo = False
+demo = True
 record_images = False
 
 
@@ -146,7 +146,7 @@ for trials in range(1, n_trials + 1):
     for step in range(max_steps):
         if demo:
             env_renderer.renderEnv(show=True, show_observations=True)
-            observation_helper.util_print_obs_subtree(obs_original[0])
+            # observation_helper.util_print_obs_subtree(obs_original[0])
             if record_images:
                 env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(step))
         # print(step)
diff --git a/utils/misc_utils.py b/utils/misc_utils.py
index d4c6ef8280d3c5b32201f6e5d5ecbbb4cb075e85..03c9fdde9368bf324f7e10841b2d30b993858fd6 100644
--- a/utils/misc_utils.py
+++ b/utils/misc_utils.py
@@ -4,8 +4,7 @@ from collections import deque
 
 import numpy as np
 from flatland.envs.generators import complex_rail_generator
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
 from line_profiler import LineProfiler
 
@@ -69,7 +68,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
     features_per_node = 9
     start_time_scoring = time.time()
     action_dict = dict()
-    nr_trials_per_test = 100
+    nr_trials_per_test = 5
     print('Running Test {} with (x_dim,y_dim) = ({},{}) and {} Agents.'.format(test_nr, parameters[0], parameters[1],
                                                                                parameters[2]))
 
@@ -88,8 +87,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
                   rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist,
                                                         max_dist=99999,
                                                         seed=parameters[3]),
-                  obs_builder_object=TreeObsForRailEnv(max_depth=tree_depth,
-                                                       predictor=ShortestPathPredictorForRailEnv()),
+                  obs_builder_object=GlobalObsForRailEnv(),
                   number_of_agents=parameters[2])
     max_steps = int(3 * (env.height + env.width))
     lp_step = lp(env.step)