diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index d273bbfdc6ec98da469c7c5c120add4a38026f56..0862e30e1761375bcb35026ea2ec4cfe74886df5 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -1,18 +1,20 @@
+# Import standard packages for plotting, randomness and system utilities
+import getopt
+import random
 import sys
 from collections import deque
 
-import getopt
 import matplotlib.pyplot as plt
 import numpy as np
-import random
 import torch
+# Import the Flatland environment, observation builders and predictors
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
-from flatland.utils.rendertools import RenderTool
 from importlib_resources import path
 
+# Import the network definitions, the dueling double DQN agent and utility functions to normalize observations
 import torch_training.Nets
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import norm_obs_clip, split_tree
@@ -20,52 +22,53 @@ from utils.observation_utils import norm_obs_clip, split_tree
 
 def main(argv):
     try:
-        opts, args = getopt.getopt(argv, "n:", ["n_trials="])
+        opts, args = getopt.getopt(argv, "n:", ["n_episodes="])
     except getopt.GetoptError:
-        print('training_navigation.py -n <n_trials>')
+        print('training_navigation.py -n <n_episodes>')
         sys.exit(2)
     for opt, arg in opts:
-        if opt in ('-n', '--n_trials'):
-            n_trials = int(arg)
+        if opt in ('-n', '--n_episodes'):
+            n_episodes = int(arg)
+
+    # Seed the random number generators for reproducibility
     random.seed(1)
     np.random.seed(1)
-    """
-
-    file_name = "./railway/complex_scene.pkl"
-    env = RailEnv(width=10,
-                  height=20,
-                  rail_generator=rail_from_data(file_name),
-                  obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()))
-    x_dim = env.width
-    y_dim = env.height
-    """
 
+    # Initialize a random map with a random number of agents
     x_dim = np.random.randint(8, 20)
     y_dim = np.random.randint(8, 20)
     n_agents = np.random.randint(3, 8)
     n_goals = n_agents + np.random.randint(0, 3)
     min_dist = int(0.75 * min(x_dim, y_dim))
+    tree_depth = 3
     print("main2")
 
+    # Get an observation builder and predictor
+    predictor = ShortestPathPredictorForRailEnv()
+    observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor)
+
     env = RailEnv(width=x_dim,
                   height=y_dim,
                   rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
                                                         max_dist=99999,
                                                         seed=0),
-                  obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()),
+                  obs_builder_object=observation_helper,
                   number_of_agents=n_agents)
     env.reset(True, True)
 
-    observation_helper = TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv())
-    env_renderer = RenderTool(env, gl="PILSVG", )
     handle = env.get_agent_handles()
-    features_per_node = 9
-    state_size = features_per_node * 85 * 2
+    num_features_per_node = env.obs_builder.observation_dim
+    nr_nodes = 0
+    for i in range(tree_depth + 1):
+        nr_nodes += np.power(4, i)
+    state_size = num_features_per_node * nr_nodes
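+    # For tree_depth = 3 this sums the geometric series 1 + 4 + 16 + 64 = 85 nodes,
+    # so the network input has observation_dim * 85 entries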
     action_size = 5
 
     # We set the number of episodes we would like to train on
-    if 'n_trials' not in locals():
-        n_trials = 60000
+    if 'n_episodes' not in locals():
+        n_episodes = 60000
+
+    # Set the max number of steps per episode as well as other training-relevant parameters
     max_steps = int(3 * (env.height + env.width))
     eps = 1.
     eps_end = 0.005
@@ -74,23 +77,28 @@ def main(argv):
     final_action_dict = dict()
     scores_window = deque(maxlen=100)
     done_window = deque(maxlen=100)
-    time_obs = deque(maxlen=2)
     scores = []
     dones_list = []
     action_prob = [0] * action_size
     agent_obs = [None] * env.get_num_agents()
     agent_next_obs = [None] * env.get_num_agents()
-    agent = Agent(state_size, action_size, "FC", 0)
-    with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
-        agent.qnetwork_local.load_state_dict(torch.load(file_in))
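+    # Radius passed to norm_obs_clip as fixed_radius; assuming distance features are then
+    # scaled by this constant instead of the per-observation maximum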
+    observation_radius = 10
 
-    demo = False
-    record_images = False
-    frame_step = 0
-
-    for trials in range(1, n_trials + 1):
+    # Initialize the agent
+    agent = Agent(state_size, action_size, "FC", 0)
 
-        if trials % 50 == 0 and not demo:
+    # Optionally pre-load a trained agent by setting this flag to True
+    if False:
+        with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
+            agent.qnetwork_local.load_state_dict(torch.load(file_in))
+
+    # Do training over n_episodes
+    for episodes in range(1, n_episodes + 1):
+        """
+        Training curriculum: to encourage generalization, we change the number of agents
+        and the size of the levels every 50 episodes.
+        """
+        if episodes % 50 == 0:
             x_dim = np.random.randint(8, 20)
             y_dim = np.random.randint(8, 20)
             n_agents = np.random.randint(3, 8)
@@ -101,90 +109,78 @@ def main(argv):
                           rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
                                                                 max_dist=99999,
                                                                 seed=0),
-                          obs_builder_object=TreeObsForRailEnv(max_depth=3,
-                                                               predictor=ShortestPathPredictorForRailEnv()),
+                          obs_builder_object=observation_helper,
                           number_of_agents=n_agents)
-            env.reset(True, True)
+
+            # Adjust the parameters according to the new env.
             max_steps = int(3 * (env.height + env.width))
             agent_obs = [None] * env.get_num_agents()
             agent_next_obs = [None] * env.get_num_agents()
+
         # Reset environment
         obs = env.reset(True, True)
-        if demo:
-            env_renderer.set_new_rail()
-        obs_original = obs.copy()
-        final_obs = obs.copy()
-        final_obs_next = obs.copy()
+
+        # Set up placeholders for the final observation of each agent. This is necessary because agents
+        # terminate at different times during an episode
+        final_obs = agent_obs.copy()
+        final_obs_next = agent_next_obs.copy()
+
+        # Build agent specific observations
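+        # Each tree observation is split into tree data, distance data and agent data,
+        # normalized separately and concatenated into a single vector of length state_size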
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(obs[a]),
+            data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=num_features_per_node,
                                                     current_depth=0)
-            data = norm_obs_clip(data)
+            data = norm_obs_clip(data, fixed_radius=observation_radius)
             distance = norm_obs_clip(distance)
             agent_data = np.clip(agent_data, -1, 1)
-            obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
-            agent_data = env.agents[a]
-            speed = 1  # np.random.randint(1,5)
-            agent_data.speed_data['speed'] = 1. / speed
-
-        for i in range(2):
-            time_obs.append(obs)
-        # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
-        for a in range(env.get_num_agents()):
-            agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+            agent_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
 
         score = 0
         env_done = 0
+
         # Run episode
         for step in range(max_steps):
-            if demo:
-                env_renderer.renderEnv(show=True, show_observations=False)
-                # observation_helper.util_print_obs_subtree(obs_original[0])
-                if record_images:
-                    env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(frame_step))
-                    frame_step += 1
-            # print(step)
+
             # Action
             for a in range(env.get_num_agents()):
-                if demo:
-                    eps = 0
-                # action = agent.act(np.array(obs[a]), eps=eps)
                 action = agent.act(agent_obs[a], eps=eps)
                 action_prob[action] += 1
                 action_dict.update({a: action})
-            # Environment step
 
+            # Environment step
             next_obs, all_rewards, done, _ = env.step(action_dict)
-            # print(all_rewards,action)
-            obs_original = next_obs.copy()
+
+            # Build agent specific observations and normalize
             for a in range(env.get_num_agents()):
                 data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                        current_depth=0)
-                data = norm_obs_clip(data)
+                                                        num_features_per_node=num_features_per_node, current_depth=0)
+                data = norm_obs_clip(data, fixed_radius=observation_radius)
                 distance = norm_obs_clip(distance)
                 agent_data = np.clip(agent_data, -1, 1)
-                next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
-            time_obs.append(next_obs)
+                agent_next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
 
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
-                agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
                 if done[a]:
                     final_obs[a] = agent_obs[a].copy()
                     final_obs_next[a] = agent_next_obs[a].copy()
                     final_action_dict.update({a: action_dict[a]})
-                if not demo and not done[a]:
+                if not done[a]:
                     agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
                 score += all_rewards[a] / env.get_num_agents()
 
+            # Copy observation
             agent_obs = agent_next_obs.copy()
+
             if done['__all__']:
                 env_done = 1
                 for a in range(env.get_num_agents()):
                     agent.step(final_obs[a], final_action_dict[a], all_rewards[a], final_obs_next[a], done[a])
                 break
+
         # Epsilon decay
         eps = max(eps_end, eps_decay * eps)  # decrease epsilon
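+        # With eps_decay = 0.9995, epsilon roughly halves every ~1400 episodes and hits the
+        # eps_end = 0.005 floor after about 10600 episodes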
 
+        # Collect information about training progress
         done_window.append(env_done)
         scores_window.append(score / max_steps)  # save most recent score
         scores.append(np.mean(scores_window))
@@ -193,22 +189,22 @@ def main(argv):
         print(
             '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
                 env.get_num_agents(), x_dim, y_dim,
-                trials,
+                episodes,
                 np.mean(scores_window),
                 100 * np.mean(done_window),
                 eps, action_prob / np.sum(action_prob)), end=" ")
 
-        if trials % 100 == 0:
+        if episodes % 100 == 0:
             print(
                 '\rTraining {} Agents.\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
                     env.get_num_agents(),
-                    trials,
+                    episodes,
                     np.mean(scores_window),
                     100 * np.mean(done_window),
                     eps,
                     action_prob / np.sum(action_prob)))
             torch.save(agent.qnetwork_local.state_dict(),
-                       './Nets/avoid_checkpoint' + str(trials) + '.pth')
+                       './Nets/avoid_checkpoint' + str(episodes) + '.pth')
             action_prob = [1] * action_size
     plt.plot(scores)
     plt.show()
diff --git a/torch_training/multi_agent_two_time_step_training.py b/torch_training/multi_agent_two_time_step_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c3225c24dba4c52756b42f41e53656b1f698be3
--- /dev/null
+++ b/torch_training/multi_agent_two_time_step_training.py
@@ -0,0 +1,221 @@
+# Import standard packages for plotting, randomness and system utilities
+import getopt
+import random
+import sys
+from collections import deque
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+# Import the Flatland environment, observation builders and predictors
+from flatland.envs.generators import complex_rail_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from importlib_resources import path
+
+# Import the network definitions, the dueling double DQN agent and utility functions to normalize observations
+import torch_training.Nets
+from torch_training.dueling_double_dqn import Agent
+from utils.observation_utils import norm_obs_clip, split_tree
+
+
+def main(argv):
+    try:
+        opts, args = getopt.getopt(argv, "n:", ["n_episodes="])
+    except getopt.GetoptError:
+        print('training_navigation.py -n <n_episodes>')
+        sys.exit(2)
+    for opt, arg in opts:
+        if opt in ('-n', '--n_episodes'):
+            n_episodes = int(arg)
+
+    # Seed the random number generators for reproducibility
+    random.seed(1)
+    np.random.seed(1)
+
+    # Initialize a random map with a random number of agents
+    x_dim = np.random.randint(8, 20)
+    y_dim = np.random.randint(8, 20)
+    n_agents = np.random.randint(3, 8)
+    n_goals = n_agents + np.random.randint(0, 3)
+    min_dist = int(0.75 * min(x_dim, y_dim))
+    tree_depth = 3
+    print("main2")
+
+    # Get an observation builder and predictor
+    predictor = ShortestPathPredictorForRailEnv()
+    observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor)
+
+    env = RailEnv(width=x_dim,
+                  height=y_dim,
+                  rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
+                                                        max_dist=99999,
+                                                        seed=0),
+                  obs_builder_object=observation_helper,
+                  number_of_agents=n_agents)
+    env.reset(True, True)
+
+    handle = env.get_agent_handles()
+    features_per_node = env.obs_builder.observation_dim
+    nr_nodes = 0
+    for i in range(tree_depth + 1):
+        nr_nodes += np.power(4, i)
+    state_size = 2 * features_per_node * nr_nodes  # We will use two time steps per observation --> 2x state_size
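+    # With tree_depth = 3 the tree has 1 + 4 + 16 + 64 = 85 nodes, so the stacked input
+    # has 2 * observation_dim * 85 entries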
+    action_size = 5
+
+    # We set the number of episodes we would like to train on
+    if 'n_episodes' not in locals():
+        n_episodes = 60000
+
+    # Set the max number of steps per episode as well as other training-relevant parameters
+    max_steps = int(3 * (env.height + env.width))
+    eps = 1.
+    eps_end = 0.005
+    eps_decay = 0.9995
+    action_dict = dict()
+    final_action_dict = dict()
+    scores_window = deque(maxlen=100)
+    done_window = deque(maxlen=100)
+    time_obs = deque(maxlen=2)
+    scores = []
+    dones_list = []
+    action_prob = [0] * action_size
+    agent_obs = [None] * env.get_num_agents()
+    agent_next_obs = [None] * env.get_num_agents()
+    # Initialize the agent
+    agent = Agent(state_size, action_size, "FC", 0)
+
+    # Optionally pre-load a trained agent by setting this flag to True
+    if False:
+        with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
+            agent.qnetwork_local.load_state_dict(torch.load(file_in))
+
+    # Do training over n_episodes
+    for episodes in range(1, n_episodes + 1):
+        """
+        Training curriculum: to encourage generalization, we change the number of agents
+        and the size of the levels every 50 episodes.
+        """
+        if episodes % 50 == 0:
+            x_dim = np.random.randint(8, 20)
+            y_dim = np.random.randint(8, 20)
+            n_agents = np.random.randint(3, 8)
+            n_goals = n_agents + np.random.randint(0, 3)
+            min_dist = int(0.75 * min(x_dim, y_dim))
+            env = RailEnv(width=x_dim,
+                          height=y_dim,
+                          rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
+                                                                max_dist=99999,
+                                                                seed=0),
+                          obs_builder_object=observation_helper,
+                          number_of_agents=n_agents)
+
+            # Adjust the parameters according to the new env.
+            max_steps = int(3 * (env.height + env.width))
+            agent_obs = [None] * env.get_num_agents()
+            agent_next_obs = [None] * env.get_num_agents()
+
+        # Reset environment
+        obs = env.reset(True, True)
+
+        # Set up placeholders for the final observation of each agent. This is necessary because agents
+        # terminate at different times during an episode
+        final_obs = agent_obs.copy()
+        final_obs_next = agent_next_obs.copy()
+
+        # Build agent specific observations
+        for a in range(env.get_num_agents()):
+            data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
+                                                    current_depth=0)
+            data = norm_obs_clip(data)
+            distance = norm_obs_clip(distance)
+            agent_data = np.clip(agent_data, -1, 1)
+            obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+
+        # Accumulate two time steps of observation (Here just twice the first state)
+        for i in range(2):
+            time_obs.append(obs)
+
+        # Build the agent-specific observation by concatenating the two time steps
+        for a in range(env.get_num_agents()):
+            agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+
+        score = 0
+        env_done = 0
+        # Run episode
+        for step in range(max_steps):
+
+            # Action
+            for a in range(env.get_num_agents()):
+                action = agent.act(agent_obs[a], eps=eps)
+                action_prob[action] += 1
+                action_dict.update({a: action})
+
+            # Environment step
+            next_obs, all_rewards, done, _ = env.step(action_dict)
+            for a in range(env.get_num_agents()):
+                data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                        num_features_per_node=features_per_node, current_depth=0)
+                data = norm_obs_clip(data)
+                distance = norm_obs_clip(distance)
+                agent_data = np.clip(agent_data, -1, 1)
+                next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+            time_obs.append(next_obs)
+
+            # Update replay buffer and train agent
+            for a in range(env.get_num_agents()):
+                agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+                if done[a]:
+                    final_obs[a] = agent_obs[a].copy()
+                    final_obs_next[a] = agent_next_obs[a].copy()
+                    final_action_dict.update({a: action_dict[a]})
+                if not done[a]:
+                    agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
+                score += all_rewards[a] / env.get_num_agents()
+
+            agent_obs = agent_next_obs.copy()
+            if done['__all__']:
+                env_done = 1
+                for a in range(env.get_num_agents()):
+                    agent.step(final_obs[a], final_action_dict[a], all_rewards[a], final_obs_next[a], done[a])
+                break
+        # Epsilon decay
+        eps = max(eps_end, eps_decay * eps)  # decrease epsilon
+
+        done_window.append(env_done)
+        scores_window.append(score / max_steps)  # save most recent score
+        scores.append(np.mean(scores_window))
+        dones_list.append((np.mean(done_window)))
+
+        print(
+            '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
+                env.get_num_agents(), x_dim, y_dim,
+                episodes,
+                np.mean(scores_window),
+                100 * np.mean(done_window),
+                eps, action_prob / np.sum(action_prob)), end=" ")
+
+        if episodes % 100 == 0:
+            print(
+                '\rTraining {} Agents.\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
+                    env.get_num_agents(),
+                    episodes,
+                    np.mean(scores_window),
+                    100 * np.mean(done_window),
+                    eps,
+                    action_prob / np.sum(action_prob)))
+            torch.save(agent.qnetwork_local.state_dict(),
+                       './Nets/avoid_checkpoint' + str(episodes) + '.pth')
+            action_prob = [1] * action_size
+    plt.plot(scores)
+    plt.show()
+
+
+if __name__ == '__main__':
+    main(sys.argv[1:])
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 6ea6b5672a6939c290d0395d7d8795d47b14b508..2b836087611f5a5a0b01e47d6d91b78c16da9d42 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -52,12 +52,12 @@ def main(argv):
     env_renderer = RenderTool(env, gl="PILSVG", )
 
     # Given the depth of the tree observation and the number of features per node we get the following state_size
-    features_per_node = env.obs_builder.observation_dim
+    num_features_per_node = env.obs_builder.observation_dim
     tree_depth = 2
     nr_nodes = 0
     for i in range(tree_depth + 1):
         nr_nodes += np.power(4, i)
-    state_size = features_per_node * nr_nodes
+    state_size = num_features_per_node * nr_nodes
 
     # The action space of flatland is 5 discrete actions
     action_size = 5
@@ -102,6 +102,7 @@ def main(argv):
         # Build agent specific local observation
         for a in range(env.get_num_agents()):
             rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]),
+                                                              num_features_per_node=num_features_per_node,
                                                               current_depth=0)
             rail_data = norm_obs_clip(rail_data)
             distance_data = norm_obs_clip(distance_data)
@@ -135,6 +136,7 @@ def main(argv):
 
             for a in range(env.get_num_agents()):
                 rail_data, distance_data, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                                  num_features_per_node=num_features_per_node,
                                                                   current_depth=0)
                 rail_data = norm_obs_clip(rail_data)
                 distance_data = norm_obs_clip(distance_data)
@@ -195,6 +197,7 @@ def main(argv):
     # Build agent specific local observation
     for a in range(env.get_num_agents()):
         rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]),
+                                                          num_features_per_node=num_features_per_node,
                                                           current_depth=0)
         rail_data = norm_obs_clip(rail_data)
         distance_data = norm_obs_clip(distance_data)
@@ -220,6 +223,7 @@ def main(argv):
 
         for a in range(env.get_num_agents()):
             rail_data, distance_data, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                              num_features_per_node=num_features_per_node,
                                                               current_depth=0)
             rail_data = norm_obs_clip(rail_data)
             distance_data = norm_obs_clip(distance_data)
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
index 121c6eb593043d81dfed61ed5b37f65eaef9af4d..c5f0d5dba7dddc00b9d981325c078b78428ecf48 100644
--- a/utils/observation_utils.py
+++ b/utils/observation_utils.py
@@ -1,7 +1,5 @@
 import numpy as np
 
-from flatland.envs.observations import TreeObsForRailEnv
-
 
 def max_lt(seq, val):
     """
@@ -54,7 +52,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0):
     return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
 
 
-def split_tree(tree, current_depth=0):
+def split_tree(tree, num_features_per_node, current_depth=0):
     """
     Splits the tree observation into different sub groups that need the same normalization.
     This is necessary because the tree observation includes two different distance:
@@ -70,7 +68,6 @@ def split_tree(tree, current_depth=0):
     :param current_depth: Keeping track of the current depth in the tree
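+    :param num_features_per_node: Number of features stored per node of the tree
+                                  (typically env.obs_builder.observation_dim)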
     :return: Returns the three different groups of distance and binary values.
     """
-    num_features_per_node = TreeObsForRailEnv.observation_dim
     if len(tree) < num_features_per_node:
         return [], [], []
 
@@ -93,7 +90,7 @@ def split_tree(tree, current_depth=0):
     for children in range(4):
         child_tree = tree[(num_features_per_node + children * child_size):
                           (num_features_per_node + (children + 1) * child_size)]
-        tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree,
+        tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree, num_features_per_node,
                                                                       current_depth=current_depth + 1)
         if len(tmp_tree_data) > 0:
             tree_data.extend(tmp_tree_data)