diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
index d06506383544df5ef03b4128c37213d0a1a2ac69..989800044250d60b68c68b5d2e702b5625964024 100644
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ b/RLLib_training/RailEnvRLLibWrapper.py
@@ -1,9 +1,10 @@
 import numpy as np
-from flatland.envs.generators import complex_rail_generator, random_rail_generator
-from flatland.envs.rail_env import RailEnv
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.utils.seed import seed as set_seed
 
+from flatland.envs.generators import complex_rail_generator, random_rail_generator
+from flatland.envs.rail_env import RailEnv
+
 
 class RailEnvRLLibWrapper(MultiAgentEnv):
 
@@ -63,7 +64,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
 
         for i_agent in range(len(self.env.agents)):
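+            # The observation builder now determines the number of features per node itself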
             data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
-                                                                         num_features_per_node=8, current_depth=0)
+                                                                         current_depth=0)
             o[i_agent] = [data, distance, agent_data]
 
         # needed for the renderer
@@ -72,8 +73,6 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         self.agents_static = self.env.agents_static
         self.dev_obs_dict = self.env.dev_obs_dict
 
-
-
         # If step_memory > 1, we need to concatenate it the observations in memory, only works for
         # step_memory = 1 or 2 for the moment
         if self.step_memory < 2:
@@ -96,7 +95,7 @@ class RailEnvRLLibWrapper(MultiAgentEnv):
         for i_agent in range(len(self.env.agents)):
             if i_agent not in self.agents_done:
                 data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
-                                                                             num_features_per_node=8, current_depth=0)
+                                                                             current_depth=0)
 
                 o[i_agent] = [data, distance, agent_data]
                 r[i_agent] = rewards[i_agent]
diff --git a/requirements_torch_training.txt b/requirements_torch_training.txt
index 2bce630587233b4c771ef7a43bc3aaf7f78fbb07..d8c4a46b356cde45be5563ff77cac2042cfa3ff1 100644
--- a/requirements_torch_training.txt
+++ b/requirements_torch_training.txt
@@ -1 +1,4 @@
-torch==1.1.0
\ No newline at end of file
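+# flatland is installed directly from the upstream git repository (master branch)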
+git+https://gitlab.aicrowd.com/flatland/flatland.git@master
+importlib-metadata>=0.17
+importlib_resources>=1.0.2
+torch>=1.1.0
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 723e1a6f701150f5853a0199057bc234137c2aa2..2b9b731ea02a0c9bdbea7602ea1dfa2ad6e194e2 100644
--- a/setup.py
+++ b/setup.py
@@ -1,13 +1,7 @@
-import os
-
 from setuptools import setup, find_packages
 
-# TODO: setup does not support installation from url, move to requirements*.txt
-# TODO: @master as soon as mr is merged on flatland.
-os.system(
-    'pip install git+https://gitlab.aicrowd.com/flatland/flatland.git@57-access-resources-through-importlib_resources')
-
 install_reqs = []
+dependency_links = []
 # TODO: include requirements_RLLib_training.txt
 requirements_paths = ['requirements_torch_training.txt']  # , 'requirements_RLLib_training.txt']
 for requirements_path in requirements_paths:
@@ -15,8 +9,15 @@ for requirements_path in requirements_paths:
         install_reqs += [
             s for s in [
                 line.strip(' \n') for line in f
-            ] if not s.startswith('#') and s != ''
+            ] if not s.startswith('#') and s != '' and not s.startswith('git+')
         ]
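+    # git+ requirements cannot go into install_requires, so collect them as dependency_links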
+    with open(requirements_path, 'r') as f:
+        dependency_links += [
+            s for s in [
+                line.strip(' \n') for line in f
+            ] if s.startswith('git+')
+        ]
+
 requirements = install_reqs
 setup_requirements = install_reqs
 test_requirements = install_reqs
@@ -47,6 +48,7 @@ setup(
     setup_requires=setup_requirements,
     test_suite='tests',
     tests_require=test_requirements,
+    dependency_links=dependency_links,
     url='https://gitlab.aicrowd.com/flatland/baselines',
     version='0.1.1',
     zip_safe=False,
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 79355763c74b8270430770b3f5613f1068f20284..7b9470d3b980657073af02e87b92edfea98bd879 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -1,16 +1,18 @@
-import random
 from collections import deque
 
 import matplotlib.pyplot as plt
 import numpy as np
+import random
 import torch
 from dueling_double_dqn import Agent
+from importlib_resources import path
+
+import torch_training.Nets
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
-
 from utils.observation_utils import norm_obs_clip, split_tree
 
 random.seed(1)
@@ -62,7 +64,8 @@ action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint30000.pth'))
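+# Load the pretrained checkpoint shipped as package data in torch_training.Nets via importlib_resources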
+with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
+    agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 demo = False
 record_images = False
@@ -97,7 +100,7 @@ for trials in range(1, n_trials + 1):
     final_obs = obs.copy()
     final_obs_next = obs.copy()
     for a in range(env.get_num_agents()):
-        data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
+        data, distance, agent_data = split_tree(tree=np.array(obs[a]),
                                                 current_depth=0)
         data = norm_obs_clip(data)
         distance = norm_obs_clip(distance)
@@ -138,7 +141,7 @@ for trials in range(1, n_trials + 1):
         # print(all_rewards,action)
         obs_original = next_obs.copy()
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), num_features_per_node=features_per_node,
+            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index dd4d4799b9e72a825be040f722d078d04afacbdf..3096c8766e7fb953c855bb5b95d40794f2ee00f2 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -1,8 +1,10 @@
-import random
+import sys
 from collections import deque
+
+import getopt
 import matplotlib.pyplot as plt
 import numpy as np
-
+import random
 import torch
 from dueling_double_dqn import Agent
 
@@ -12,83 +14,188 @@ from flatland.envs.rail_env import RailEnv
 from flatland.utils.rendertools import RenderTool
 from utils.observation_utils import norm_obs_clip, split_tree
 
-random.seed(1)
-np.random.seed(1)
-
-# Parameters for the Environment
-x_dim = 10
-y_dim = 10
-n_agents = 1
-n_goals = 5
-min_dist = 5
-
-# We are training an Agent using the Tree Observation with depth 2
-observation_builder = TreeObsForRailEnv(max_depth=2)
-
-# Load the Environment
-env = RailEnv(width=x_dim,
-              height=y_dim,
-              rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
-                                                    max_dist=99999,
-                                                    seed=0),
-              obs_builder_object=observation_builder,
-              number_of_agents=n_agents)
-env.reset(True, True)
-
-# After training we want to render the results so we also load a renderer
-env_renderer = RenderTool(env, gl="PILSVG", )
-
-# Given the depth of the tree observation and the number of features per node we get the following state_size
-features_per_node = 9
-tree_depth = 2
-nr_nodes = 0
-for i in range(tree_depth + 1):
-    nr_nodes += np.power(4, i)
-state_size = features_per_node * nr_nodes
-
-# The action space of flatland is 5 discrete actions
-action_size = 5
-
-# We set the number of episodes we would like to train on
-n_trials = 6000
-
-# And the max number of steps we want to take per episode
-max_steps = int(3 * (env.height + env.width))
-
-# Define training parameters
-eps = 1.
-eps_end = 0.005
-eps_decay = 0.998
-
-# And some variables to keep track of the progress
-action_dict = dict()
-final_action_dict = dict()
-scores_window = deque(maxlen=100)
-done_window = deque(maxlen=100)
-time_obs = deque(maxlen=2)
-scores = []
-dones_list = []
-action_prob = [0] * action_size
-agent_obs = [None] * env.get_num_agents()
-agent_next_obs = [None] * env.get_num_agents()
-
-# Now we load a Double dueling DQN agent
-agent = Agent(state_size, action_size, "FC", 0)
-
-Training = True
-
-for trials in range(1, n_trials + 1):
+
+def main(argv):
+
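+    # Parse optional command line arguments (currently only the number of training episodes)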
+    try:
+        opts, args = getopt.getopt(argv, "n:", ["n_trials="])
+    except getopt.GetoptError:
+        print('training_navigation.py -n <n_trials>')
+        sys.exit(2)
+
+    # We set the number of episodes we would like to train on
+    n_trials = 6000
+    for opt, arg in opts:
+        if opt in ('-n', '--n_trials'):
+            n_trials = int(arg)
+
+    random.seed(1)
+    np.random.seed(1)
+
+    # Parameters for the Environment
+    x_dim = 10
+    y_dim = 10
+    n_agents = 1
+    n_goals = 5
+    min_dist = 5
+
+    # We are training an Agent using the Tree Observation with depth 2
+    observation_builder = TreeObsForRailEnv(max_depth=2)
+
+    # Load the Environment
+    env = RailEnv(width=x_dim,
+                  height=y_dim,
+                  rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
+                                                        max_dist=99999,
+                                                        seed=0),
+                  obs_builder_object=observation_builder,
+                  number_of_agents=n_agents)
+    env.reset(True, True)
+
+    # After training we want to render the results so we also load a renderer
+    env_renderer = RenderTool(env, gl="PILSVG", )
+
+    # Given the depth of the tree observation and the number of features per node we get the following state_size
+    features_per_node = 9
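+    # note: keep this value in sync with TreeObsForRailEnv.observation_dim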
+    tree_depth = 2
+    nr_nodes = 0
+    for i in range(tree_depth + 1):
+        nr_nodes += np.power(4, i)
+    state_size = features_per_node * nr_nodes
+
+    # The action space of flatland is 5 discrete actions
+    action_size = 5
+
+    # The max number of steps we want to take per episode
+    max_steps = int(3 * (env.height + env.width))
+
+    # Define training parameters
+    eps = 1.
+    eps_end = 0.005
+    eps_decay = 0.998
+
+    # And some variables to keep track of the progress
+    action_dict = dict()
+    final_action_dict = dict()
+    scores_window = deque(maxlen=100)
+    done_window = deque(maxlen=100)
+    time_obs = deque(maxlen=2)
+    scores = []
+    dones_list = []
+    action_prob = [0] * action_size
+    agent_obs = [None] * env.get_num_agents()
+    agent_next_obs = [None] * env.get_num_agents()
+
+    # Now we load a Double dueling DQN agent
+    agent = Agent(state_size, action_size, "FC", 0)
+
+    Training = True
+
+    for trials in range(1, n_trials + 1):
+
+        # Reset environment
+        obs = env.reset(True, True)
+        if not Training:
+            env_renderer.set_new_rail()
+
+        # Split the observation tree into its parts and normalize the observation using the utility functions.
+        # Build agent specific local observation
+        for a in range(env.get_num_agents()):
+            rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]),
+                                                              current_depth=0)
+            rail_data = norm_obs_clip(rail_data)
+            distance_data = norm_obs_clip(distance_data)
+            agent_data = np.clip(agent_data, -1, 1)
+            agent_obs[a] = np.concatenate((np.concatenate((rail_data, distance_data)), agent_data))
+
+        # Reset score and done
+        score = 0
+        env_done = 0
+
+        # Run episode
+        for step in range(max_steps):
+
+            # Only render when not training
+            if not Training:
+                env_renderer.renderEnv(show=True, show_observations=True)
+
+            # Choose the actions
+            for a in range(env.get_num_agents()):
+                if not Training:
+                    eps = 0
+
+                action = agent.act(agent_obs[a], eps=eps)
+                action_dict.update({a: action})
+
+                # Count the number of actions taken for statistics
+                action_prob[action] += 1
+
+            # Environment step
+            next_obs, all_rewards, done, _ = env.step(action_dict)
+
+            for a in range(env.get_num_agents()):
+                rail_data, distance_data, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                                  current_depth=0)
+                rail_data = norm_obs_clip(rail_data)
+                distance_data = norm_obs_clip(distance_data)
+                agent_data = np.clip(agent_data, -1, 1)
+                agent_next_obs[a] = np.concatenate((np.concatenate((rail_data, distance_data)), agent_data))
+
+            # Update replay buffer and train agent
+            for a in range(env.get_num_agents()):
+
+                # Remember and train agent
+                if Training:
+                    agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
+
+                # Update the current score
+                score += all_rewards[a] / env.get_num_agents()
+
+            agent_obs = agent_next_obs.copy()
+            if done['__all__']:
+                env_done = 1
+                break
+
+        # Epsilon decay
+        eps = max(eps_end, eps_decay * eps)  # decrease epsilon
+
+        # Store the information about training progress
+        done_window.append(env_done)
+        scores_window.append(score / max_steps)  # save most recent score
+        scores.append(np.mean(scores_window))
+        dones_list.append((np.mean(done_window)))
+
+        print(
+            '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
+                env.get_num_agents(), x_dim, y_dim,
+                trials,
+                np.mean(scores_window),
+                100 * np.mean(done_window),
+                eps, action_prob / np.sum(action_prob)), end=" ")
+
+        if trials % 100 == 0:
+            print(
+                '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
+                    env.get_num_agents(), x_dim, y_dim,
+                    trials,
+                    np.mean(scores_window),
+                    100 * np.mean(done_window),
+                    eps, action_prob / np.sum(action_prob)))
+            torch.save(agent.qnetwork_local.state_dict(),
+                       './Nets/navigator_checkpoint' + str(trials) + '.pth')
+            action_prob = [1] * action_size
+
+    # Render the trained agent
 
     # Reset environment
     obs = env.reset(True, True)
-    if not Training:
-        env_renderer.set_new_rail()
+    env_renderer.set_new_rail()
 
     # Split the observation tree into its parts and normalize the observation using the utility functions.
     # Build agent specific local observation
     for a in range(env.get_num_agents()):
         rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]),
-                                                          num_features_per_node=features_per_node,
                                                           current_depth=0)
         rail_data = norm_obs_clip(rail_data)
         distance_data = norm_obs_clip(distance_data)
@@ -101,123 +208,32 @@ for trials in range(1, n_trials + 1):
 
     # Run episode
     for step in range(max_steps):
-
-        # Only render when not triaing
-        if not Training:
-            env_renderer.renderEnv(show=True, show_observations=True)
+        env_renderer.renderEnv(show=True, show_observations=False)
 
         # Chose the actions
         for a in range(env.get_num_agents()):
-            if not Training:
-                eps = 0
-
+            eps = 0
             action = agent.act(agent_obs[a], eps=eps)
             action_dict.update({a: action})
 
-            # Count number of actions takes for statistics
-            action_prob[action] += 1
-
         # Environment step
         next_obs, all_rewards, done, _ = env.step(action_dict)
 
         for a in range(env.get_num_agents()):
             rail_data, distance_data, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                              num_features_per_node=features_per_node,
                                                               current_depth=0)
             rail_data = norm_obs_clip(rail_data)
             distance_data = norm_obs_clip(distance_data)
             agent_data = np.clip(agent_data, -1, 1)
             agent_next_obs[a] = np.concatenate((np.concatenate((rail_data, distance_data)), agent_data))
 
-        # Update replay buffer and train agent
-        for a in range(env.get_num_agents()):
-
-            # Remember and train agent
-            if Training:
-                agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
-
-            # Update the current score
-            score += all_rewards[a] / env.get_num_agents()
-
         agent_obs = agent_next_obs.copy()
         if done['__all__']:
-            env_done = 1
             break
+    # Plot overall training progress at the end
+    plt.plot(scores)
+    plt.show()
 
-    # Epsilon decay
-    eps = max(eps_end, eps_decay * eps)  # decrease epsilon
-
-    # Store the information about training progress
-    done_window.append(env_done)
-    scores_window.append(score / max_steps)  # save most recent score
-    scores.append(np.mean(scores_window))
-    dones_list.append((np.mean(done_window)))
-
-    print(
-        '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
-            env.get_num_agents(), x_dim, y_dim,
-            trials,
-            np.mean(scores_window),
-            100 * np.mean(done_window),
-            eps, action_prob / np.sum(action_prob)), end=" ")
 
-    if trials % 100 == 0:
-        print(
-            '\rTraining {} Agents on ({},{}).\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
-                env.get_num_agents(), x_dim, y_dim,
-                trials,
-                np.mean(scores_window),
-                100 * np.mean(done_window),
-                eps, action_prob / np.sum(action_prob)))
-        torch.save(agent.qnetwork_local.state_dict(),
-                   './Nets/navigator_checkpoint' + str(trials) + '.pth')
-        action_prob = [1] * action_size
-
-# Render the trained agent
-
-# Reset environment
-obs = env.reset(True, True)
-env_renderer.set_new_rail()
-
-# Split the observation tree into its parts and normalize the observation using the utility functions.
-# Build agent specific local observation
-for a in range(env.get_num_agents()):
-    rail_data, distance_data, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=features_per_node,
-                                                      current_depth=0)
-    rail_data = norm_obs_clip(rail_data)
-    distance_data = norm_obs_clip(distance_data)
-    agent_data = np.clip(agent_data, -1, 1)
-    agent_obs[a] = np.concatenate((np.concatenate((rail_data, distance_data)), agent_data))
-
-# Reset score and done
-score = 0
-env_done = 0
-
-# Run episode
-for step in range(max_steps):
-    env_renderer.renderEnv(show=True, show_observations=False)
-
-    # Chose the actions
-    for a in range(env.get_num_agents()):
-        eps = 0
-        action = agent.act(agent_obs[a], eps=eps)
-        action_dict.update({a: action})
-
-    # Environment step
-    next_obs, all_rewards, done, _ = env.step(action_dict)
-
-    for a in range(env.get_num_agents()):
-        rail_data, distance_data, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                          num_features_per_node=features_per_node,
-                                                          current_depth=0)
-        rail_data = norm_obs_clip(rail_data)
-        distance_data = norm_obs_clip(distance_data)
-        agent_data = np.clip(agent_data, -1, 1)
-        agent_next_obs[a] = np.concatenate((np.concatenate((rail_data, distance_data)), agent_data))
-
-    agent_obs = agent_next_obs.copy()
-    if done['__all__']:
-        break
-# Plot overall training progress at the end
-plt.plot(scores)
-plt.show()
+if __name__ == '__main__':
+    main(sys.argv[1:])
diff --git a/tox.ini b/tox.ini
index 3c22b56780ffa59d41f64cad3f9698c3f62a204d..3363d92b214bdeb74c06a2f958533d306b424cbf 100644
--- a/tox.ini
+++ b/tox.ini
@@ -15,13 +15,14 @@ setenv =
     PYTHONPATH = {toxinidir}
 passenv =
     DISPLAY
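+; XAUTHORITY is needed alongside DISPLAY so the renderer can authenticate with the X server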
+    XAUTHORITY
 ; HTTP_PROXY+HTTPS_PROXY required behind corporate proxies
     HTTP_PROXY
     HTTPS_PROXY
 deps =
     -r{toxinidir}/requirements_torch_training.txt
 commands =
-    python torch_training/training_navigation.py
+    python torch_training/training_navigation.py --n_trials=100
 
 [flake8]
 max-line-length = 120
@@ -29,7 +30,12 @@ ignore = E121 E126 E123 E128 E133 E226 E241 E242 E704 W291 W293 W391 W503 W504 W
 
 [testenv:flake8]
 basepython = python
-passenv = DISPLAY
+passenv =
+    DISPLAY
+    XAUTHORITY
+; HTTP_PROXY+HTTPS_PROXY required behind corporate proxies
+    HTTP_PROXY
+    HTTPS_PROXY
 deps =
     -r{toxinidir}/requirements_torch_training.txt
 commands =
diff --git a/utils/misc_utils.py b/utils/misc_utils.py
index 03c9fdde9368bf324f7e10841b2d30b993858fd6..5b29c6b15f61b46062bac8d4fb6c4130fe61c6ec 100644
--- a/utils/misc_utils.py
+++ b/utils/misc_utils.py
@@ -101,7 +101,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
         lp_reset(True, True)
         obs = env.reset(True, True)
         for a in range(env.get_num_agents()):
-            data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=9,
+            data, distance, agent_data = split_tree(tree=np.array(obs[a]),
                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
@@ -127,7 +127,6 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3):
 
             for a in range(env.get_num_agents()):
                 data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                        num_features_per_node=features_per_node,
                                                         current_depth=0)
                 data = norm_obs_clip(data)
                 distance = norm_obs_clip(distance)
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
index 0c97b186a9331f185cf1a1d3f99685581cb551f7..5e01121fe84c24d8d5a46d92bb30578e1dcda2b0 100644
--- a/utils/observation_utils.py
+++ b/utils/observation_utils.py
@@ -1,5 +1,7 @@
 import numpy as np
 
+from flatland.envs.observations import TreeObsForRailEnv
+
 
 def max_lt(seq, val):
     """
@@ -48,7 +50,7 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1):
     return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
 
 
-def split_tree(tree, num_features_per_node=9, current_depth=0):
+def split_tree(tree, current_depth=0):
     """
     Splits the tree observation into different sub groups that need the same normalization.
     This is necessary because the tree observation includes two different distance:
@@ -64,6 +66,7 @@ def split_tree(tree, num_features_per_node=9, current_depth=0):
     :param current_depth: Keeping track of the current depth in the tree
     :return: Returns the three different groups of distance and binary values.
     """
+    num_features_per_node = TreeObsForRailEnv.observation_dim
 
     if len(tree) < num_features_per_node:
         return [], [], []
@@ -88,7 +91,6 @@ def split_tree(tree, num_features_per_node=9, current_depth=0):
         child_tree = tree[(num_features_per_node + children * child_size):
                           (num_features_per_node + (children + 1) * child_size)]
         tmp_tree_data, tmp_distance_data, tmp_agent_data = split_tree(child_tree,
-                                                                      num_features_per_node,
                                                                       current_depth=current_depth + 1)
         if len(tmp_tree_data) > 0:
             tree_data.extend(tmp_tree_data)