diff --git a/MANIFEST.in b/MANIFEST.in
index 15ae22aa6240e9fc05354c6be3fe120383a8f91c..11453ed84ca71670f795228796a2058964549b8c 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -4,7 +4,6 @@ include changelog.md
 include LICENSE
 include README.md
 include requirements_torch_training.txt
-include requirements_RLLib_training.txt
 
 
 
diff --git a/README.md b/README.md
index cdd54b5229cd88c94618782bf242595b453413f8..d8dc09bde2ace60244a901506eb25d1790acb74a 100644
--- a/README.md
+++ b/README.md
@@ -15,9 +15,6 @@ With the above introductions you will solve tasks like these and even more...
 ![Conflict_Avoidance](https://i.imgur.com/AvBHKaD.gif)
 
 
-# RLLib Training
-The `RLLib_training` folder shows an example of how to train agents with algorithms implemented in the RLlib library, available at: <https://github.com/ray-project/ray/tree/master/python/ray/rllib>
-
 # Sequential Agent
 This is a very simple baseline to show you how the `complex_level_generator` generates feasible network configurations.
 If you run the `run_test.py` file you will see a simple agent that solves the level by sequentially running each agent along its shortest path.
diff --git a/RLLib_training/README.md b/RLLib_training/README.md
deleted file mode 100644
index 8bda956f226af1c7ef4c7e1237b447cf7af4327a..0000000000000000000000000000000000000000
--- a/RLLib_training/README.md
+++ /dev/null
@@ -1,78 +0,0 @@
-This repository allows you to run multi-agent training on the Rail Environment with the RLlib library.
-
-## Installation:
-
-To run the scripts in this repository, the deep learning library TensorFlow should be installed, along with the following packages:
-```sh
-pip install gym ray==0.7.0 gin-config opencv-python lz4 psutil
-```
-
-To start a training with different parameters, you can create a folder containing a config.gin file (see the example in `experiment_configs/config_example/config.gin`).
-
-Then, you can modify the config.gin file path at the end of the `train_experiment.py` file.
-
-The results will be stored inside the folder, and the learning curves can be visualized in 
-tensorboard:
-
-```
-tensorboard --logdir=/path/to/folder_containing_config_gin_file
-```
-
-## Gin config files
-
-In each config.gin file, all the parameters of the `run_experiment` function have to be specified.
-For example, to indicate the number of agents that have to be initialized at the beginning of each simulation, the following line should be added:
-
-```
-run_experiment.n_agents = 2
-```
-
-If several numbers of agents have to be explored during the experiment, one can pass the following value to the `n_agents` parameter:
-
-```
-run_experiment.n_agents = {"grid_search": [2,5]}
-```
-
-which tells the tune library to run the experiment with several values of this parameter.
-
-To reference a class or an object within gin, you should first register it in the `train_experiment.py` script by adding the following line:
-
-```
-gin.external_configurable(TreeObsForRailEnv)
-```
-
-and then a `TreeObsForRailEnv` object can be referenced in the `config.gin` file:
-
-```
-run_experiment.obs_builder = {"grid_search": [@TreeObsForRailEnv(), @GlobalObsForRailEnv()]}
-TreeObsForRailEnv.max_depth = 2
-```
-
-Note that `@TreeObsForRailEnv` references the class, while `@TreeObsForRailEnv()` instantiates an object of this class.
-
-
-
-
-More documentation on how to use gin-config can be found in the GitHub repository: https://github.com/google/gin-config
-
-## Run an example:
-To start a training on a 20x20 map, with a different number of agents initialized at each episode, one can run the train_experiment.py script:
-```
-python RLLib_training/train_experiment.py
-```
-This will load the gin config file in the folder `experiment_configs/config_example`.
-
-To visualize the result of a training run, one can load a training checkpoint and use the learned policy.
-This is done in the `render_training_result.py` script. One has to modify the `CHECKPOINT_PATH` at the beginning of this script:
-
-```
-CHECKPOINT_PATH = os.path.join(__file_dirname__, 'experiment_configs', 'config_example', 'ppo_policy_two_obs_with_predictions_n_agents_4_map_size_20q58l5_f7',
-                               'checkpoint_101', 'checkpoint-101')
-```
-and load the corresponding gin config file:
-
-```
-gin.parse_config_file(os.path.join(__file_dirname__, 'experiment_configs', 'config_example', 'config.gin'))
-```
-
-
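For reference, the gin workflow the deleted README describes can be reproduced in a few lines. The sketch below is a minimal, self-contained illustration of the same pattern; the toy `TreeObs` class and the bindings are placeholders, not the real Flatland/RLlib classes, and `gin.parse_config_file` would be used instead of `gin.parse_config` to load a `config.gin` from disk.

```python
import gin


class TreeObs:  # stand-in for TreeObsForRailEnv
    def __init__(self, max_depth=1):
        self.max_depth = max_depth


# Register the class so it can be referenced from a config as @TreeObs / @TreeObs().
gin.external_configurable(TreeObs)


@gin.configurable
def run_experiment(n_agents, obs_builder):
    print(f"n_agents={n_agents}, obs_builder depth={obs_builder.max_depth}")


# Equivalent to the config.gin lines shown above.
gin.parse_config([
    "run_experiment.n_agents = 2",
    "run_experiment.obs_builder = @TreeObs()",
    "TreeObs.max_depth = 2",
])

run_experiment()  # -> n_agents=2, obs_builder depth=2
```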
diff --git a/RLLib_training/RailEnvRLLibWrapper.py b/RLLib_training/RailEnvRLLibWrapper.py
deleted file mode 100644
index f82cd42d9bbd836b681ff284a82f357b2760bb0c..0000000000000000000000000000000000000000
--- a/RLLib_training/RailEnvRLLibWrapper.py
+++ /dev/null
@@ -1,135 +0,0 @@
-import numpy as np
-from ray.rllib.env.multi_agent_env import MultiAgentEnv
-from ray.rllib.utils.seed import seed as set_seed
-
-from flatland.envs.rail_env import RailEnv
-from flatland.envs.rail_generators import complex_rail_generator, random_rail_generator
-from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
-
-
-class RailEnvRLLibWrapper(MultiAgentEnv):
-
-    def __init__(self, config):
-
-        super(MultiAgentEnv, self).__init__()
-
-        # Environment ID if num_envs_per_worker > 1
-        if hasattr(config, "vector_index"):
-            vector_index = config.vector_index
-        else:
-            vector_index = 1
-
-        self.predefined_env = False
-
-        if config['rail_generator'] == "complex_rail_generator":
-            self.rail_generator = complex_rail_generator(nr_start_goal=config['number_of_agents'],
-                                                         min_dist=config['min_dist'],
-                                                         nr_extra=config['nr_extra'],
-                                                         seed=config['seed'] * (1 + vector_index))
-            self.schedule_generator = complex_schedule_generator()
-
-        elif config['rail_generator'] == "random_rail_generator":
-            self.rail_generator = random_rail_generator()
-            self.schedule_generator = random_schedule_generator()
-        elif config['rail_generator'] == "load_env":
-            self.predefined_env = True
-            self.rail_generator = random_rail_generator()
-            self.schedule_generator = random_schedule_generator()
-        else:
-            raise ValueError(f'Unknown rail generator: {config["rail_generator"]}')
-
-        set_seed(config['seed'] * (1 + vector_index))
-        self.env = RailEnv(width=config["width"], height=config["height"],
-                           number_of_agents=config["number_of_agents"],
-                           obs_builder_object=config['obs_builder'],
-                           rail_generator=self.rail_generator,
-                           schedule_generator=self.schedule_generator
-                           )
-
-        if self.predefined_env:
-            self.env.load_resource('torch_training.railway', 'complex_scene.pkl')
-
-        self.width = self.env.width
-        self.height = self.env.height
-        self.step_memory = config["step_memory"]
-
-        # needed for the renderer
-        self.rail = self.env.rail
-        self.agents = self.env.agents
-        self.agents_static = self.env.agents_static
-        self.dev_obs_dict = self.env.dev_obs_dict
-
-    def reset(self):
-        self.agents_done = []
-        if self.predefined_env:
-            obs = self.env.reset(False, False)
-        else:
-            obs = self.env.reset()
-
-        # RLLib only receives observation of agents that are not done.
-        o = dict()
-
-        for i_agent in range(len(self.env.agents)):
-            data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
-                                                                         current_depth=0)
-            o[i_agent] = [data, distance, agent_data]
-
-        # needed for the renderer
-        self.rail = self.env.rail
-        self.agents = self.env.agents
-        self.agents_static = self.env.agents_static
-        self.dev_obs_dict = self.env.dev_obs_dict
-
-        # If step_memory > 1, we need to concatenate the current observations with those kept in memory;
-        # this only works for step_memory = 1 or 2 for the moment.
-        if self.step_memory < 2:
-            return o
-        else:
-            self.old_obs = o
-            oo = dict()
-
-            for i_agent in range(len(self.env.agents)):
-                oo[i_agent] = [o[i_agent], o[i_agent]]
-            return oo
-
-    def step(self, action_dict):
-        obs, rewards, dones, infos = self.env.step(action_dict)
-
-        d = dict()
-        r = dict()
-        o = dict()
-
-        for i_agent in range(len(self.env.agents)):
-            if i_agent not in self.agents_done:
-                data, distance, agent_data = self.env.obs_builder.split_tree(tree=np.array(obs[i_agent]),
-                                                                             current_depth=0)
-
-                o[i_agent] = [data, distance, agent_data]
-                r[i_agent] = rewards[i_agent]
-                d[i_agent] = dones[i_agent]
-
-        d['__all__'] = dones['__all__']
-
-        if self.step_memory >= 2:
-            oo = dict()
-
-            for i_agent in range(len(self.env.agents)):
-                if i_agent not in self.agents_done:
-                    oo[i_agent] = [o[i_agent], self.old_obs[i_agent]]
-
-            self.old_obs = o
-
-        for agent, done in dones.items():
-            if done and agent != '__all__':
-                self.agents_done.append(agent)
-
-        if self.step_memory < 2:
-            return o, r, d, infos
-        else:
-            return oo, r, d, infos
-
-    def get_agent_handles(self):
-        return self.env.get_agent_handles()
-
-    def get_num_agents(self):
-        return self.env.get_num_agents()
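For anyone pinning the old code, a hedged sketch of how the removed wrapper was driven from an `env_config` dictionary follows; the key names come from its `__init__` above, the concrete values mirror `config_example/config.gin`, and it assumes ray 0.7.0 and flatland are still installed.

```python
from flatland.envs.observations import TreeObsForRailEnv
from RailEnvRLLibWrapper import RailEnvRLLibWrapper  # the removed RLLib_training/RailEnvRLLibWrapper.py

env_config = {
    "width": 20,
    "height": 20,
    "rail_generator": "complex_rail_generator",  # or "random_rail_generator" / "load_env"
    "nr_extra": 5,
    "number_of_agents": 4,
    "seed": 123,
    "obs_builder": TreeObsForRailEnv(max_depth=2),
    "min_dist": 10,
    "step_memory": 2,
}

env = RailEnvRLLibWrapper(env_config)
obs = env.reset()  # dict keyed by agent handle
# Action 2 is MOVE_FORWARD in RailEnv; with step_memory=2 each obs entry is [current, previous].
obs, rewards, dones, infos = env.step({handle: 2 for handle in obs})
```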
diff --git a/RLLib_training/__init__.py b/RLLib_training/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/RLLib_training/custom_preprocessors.py b/RLLib_training/custom_preprocessors.py
deleted file mode 100644
index d4c81a83f1c05317315a3f71f99565006e9311e1..0000000000000000000000000000000000000000
--- a/RLLib_training/custom_preprocessors.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import numpy as np
-from ray.rllib.models.preprocessors import Preprocessor
-from utils.observation_utils import norm_obs_clip
-
-class TreeObsPreprocessor(Preprocessor):
-    def _init_shape(self, obs_space, options):
-        print(options)
-        self.step_memory = options["custom_options"]["step_memory"]
-        return sum([space.shape[0] for space in obs_space]),
-
-    def transform(self, observation):
-        # With step_memory == 2 the observation is a pair [current, previous]; normalize and flatten both.
-        if self.step_memory == 2:
-            data = norm_obs_clip(observation[0][0])
-            distance = norm_obs_clip(observation[0][1])
-            agent_data = np.clip(observation[0][2], -1, 1)
-            data2 = norm_obs_clip(observation[1][0])
-            distance2 = norm_obs_clip(observation[1][1])
-            agent_data2 = np.clip(observation[1][2], -1, 1)
-            return np.concatenate((data, distance, agent_data, data2, distance2, agent_data2))
-
-        # Single time step: only the current observation is normalized and flattened.
-        data = norm_obs_clip(observation[0])
-        distance = norm_obs_clip(observation[1])
-        agent_data = np.clip(observation[2], -1, 1)
-        return np.concatenate((data, distance, agent_data))
-
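The removed preprocessor was hooked into RLlib through the model catalog. A short sketch of that registration pattern, as used by the (also removed) training scripts, under the ray 0.7.0 API pinned in the old README; the hidden sizes and option values are taken from `config_example/config.gin` and the training script:

```python
from ray.rllib.models import ModelCatalog

from custom_preprocessors import TreeObsPreprocessor

# Register the preprocessor under a name, then reference that name in the trainer's model config.
ModelCatalog.register_custom_preprocessor("tree_obs_prep", TreeObsPreprocessor)

model_config = {
    "fcnet_hiddens": [32, 32],               # hidden layer sizes from config_example/config.gin
    "custom_preprocessor": "tree_obs_prep",  # the name registered above
    "custom_options": {"step_memory": 2, "obs_size": 231},
}
```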
diff --git a/RLLib_training/experiment_configs/__init__.py b/RLLib_training/experiment_configs/__init__.py
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/RLLib_training/experiment_configs/config_example/config.gin b/RLLib_training/experiment_configs/config_example/config.gin
deleted file mode 100644
index 59d2dfb508f13cccf4b9152f24ab06d44c290450..0000000000000000000000000000000000000000
--- a/RLLib_training/experiment_configs/config_example/config.gin
+++ /dev/null
@@ -1,25 +0,0 @@
-run_experiment.name = "experiment_example"
-run_experiment.num_iterations = 1002
-run_experiment.save_every = 100
-run_experiment.hidden_sizes = [32, 32]
-
-run_experiment.map_width = 20
-run_experiment.map_height = 20
-run_experiment.n_agents = {"grid_search": [3, 4, 5, 6, 7, 8]}
-run_experiment.rail_generator = "complex_rail_generator" # Change this to "load_env" in order to load a predefined complex scene
-run_experiment.nr_extra = 5
-run_experiment.policy_folder_name = "ppo_policy_two_obs_with_predictions_n_agents_{config[n_agents]}_"
-
-run_experiment.seed = 123
-
-run_experiment.conv_model = False
-
-run_experiment.obs_builder = @TreeObsForRailEnv()
-TreeObsForRailEnv.predictor = @ShortestPathPredictorForRailEnv()
-TreeObsForRailEnv.max_depth = 2
-
-run_experiment.entropy_coeff = 0.001
-run_experiment.kl_coeff = 0.2 
-run_experiment.lambda_gae = 0.9
-run_experiment.step_memory = 2
-run_experiment.min_dist = 10
diff --git a/RLLib_training/render_training_result.py b/RLLib_training/render_training_result.py
deleted file mode 100644
index 1ee7cc1ce394f3b40791706871aa180ec0510b52..0000000000000000000000000000000000000000
--- a/RLLib_training/render_training_result.py
+++ /dev/null
@@ -1,169 +0,0 @@
-from RailEnvRLLibWrapper import RailEnvRLLibWrapper
-from custom_preprocessors import TreeObsPreprocessor
-import gym
-import os
-
-from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
-from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
-from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
-
-from ray.rllib.models import ModelCatalog
-
-import ray
-import numpy as np
-
-import gin
-
-from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
-gin.external_configurable(DummyPredictorForRailEnv)
-gin.external_configurable(ShortestPathPredictorForRailEnv)
-
-from ray.rllib.utils.seed import seed as set_seed
-from flatland.envs.observations import TreeObsForRailEnv
-
-from flatland.utils.rendertools import RenderTool
-import time
-
-gin.external_configurable(TreeObsForRailEnv)
-
-ModelCatalog.register_custom_preprocessor("tree_obs_prep", TreeObsPreprocessor)
-ray.init()  # object_store_memory=150000000000, redis_max_memory=30000000000)
-
-__file_dirname__ = os.path.dirname(os.path.realpath(__file__))
-
-CHECKPOINT_PATH = os.path.join(__file_dirname__, 'experiment_configs', 'config_example', 'ppo_policy_two_obs_with_predictions_n_agents_4_map_size_20q58l5_f7',
-                               'checkpoint_101', 'checkpoint-101')  # To Modify
-N_EPISODES = 10
-N_STEPS_PER_EPISODE = 50
-
-
-def render_training_result(config):
-    print('Init Env')
-
-    set_seed(config['seed'], config['seed'], config['seed'])
-
-    # Example configuration to generate a random rail
-    env_config = {"width": config['map_width'],
-                  "height": config['map_height'],
-                  "rail_generator": config["rail_generator"],
-                  "nr_extra": config["nr_extra"],
-                  "number_of_agents": config['n_agents'],
-                  "seed": config['seed'],
-                  "obs_builder": config['obs_builder'],
-                  "min_dist": config['min_dist'],
-                  "step_memory": config["step_memory"]}
-
-    # Observation space and action space definitions
-    if isinstance(config["obs_builder"], TreeObsForRailEnv):
-        obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(168,)),) * 2)
-        preprocessor = TreeObsPreprocessor
-
-    else:
-        raise ValueError("Undefined observation space")
-
-    act_space = gym.spaces.Discrete(5)
-
-    # Dict with the different policies to train
-    policy_graphs = {
-        "ppo_policy": (PolicyGraph, obs_space, act_space, {})
-    }
-
-    def policy_mapping_fn(agent_id):
-        return "ppo_policy"
-
-    # Trainer configuration
-    trainer_config = DEFAULT_CONFIG.copy()
-
-    trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes']}
-
-    trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
-                                    "policy_mapping_fn": policy_mapping_fn,
-                                    "policies_to_train": list(policy_graphs.keys())}
-
-    trainer_config["num_workers"] = 0
-    trainer_config["num_cpus_per_worker"] = 4
-    trainer_config["num_gpus"] = 0.2
-    trainer_config["num_gpus_per_worker"] = 0.2
-    trainer_config["num_cpus_for_driver"] = 1
-    trainer_config["num_envs_per_worker"] = 1
-    trainer_config['entropy_coeff'] = config['entropy_coeff']
-    trainer_config["env_config"] = env_config
-    trainer_config["batch_mode"] = "complete_episodes"
-    trainer_config['simple_optimizer'] = False
-    trainer_config['postprocess_inputs'] = True
-    trainer_config['log_level'] = 'WARN'
-    trainer_config['num_sgd_iter'] = 10
-    trainer_config['clip_param'] = 0.2
-    trainer_config['kl_coeff'] = config['kl_coeff']
-    trainer_config['lambda'] = config['lambda_gae']
-
-    env = RailEnvRLLibWrapper(env_config)
-
-    trainer = Trainer(env=RailEnvRLLibWrapper, config=trainer_config)
-
-    trainer.restore(CHECKPOINT_PATH)
-
-    policy = trainer.get_policy("ppo_policy")
-
-    preprocessor = preprocessor(obs_space, {"custom_options": {"step_memory": config["step_memory"]}})
-    env_renderer = RenderTool(env, gl="PILSVG")
-    for episode in range(N_EPISODES):
-
-        observation = env.reset()
-        for i in range(N_STEPS_PER_EPISODE):
-            preprocessed_obs = []
-            for obs in observation.values():
-                preprocessed_obs.append(preprocessor.transform(obs))
-            action, _, infos = policy.compute_actions(preprocessed_obs, [])
-            logits = infos['behaviour_logits']
-            actions = dict()
-
-            # We select the greedy action.
-            for j, logit in enumerate(logits):
-                actions[j] = np.argmax(logit)
-
-            # In case we prefer to sample an action stochastically according to the policy graph.
-            # for j, act in enumerate(action):
-                # actions[j] = act
-
-            # Time to see the rendering at one step
-            time.sleep(1)
-
-            env_renderer.renderEnv(show=True, frames=True, iEpisode=episode, iStep=i,
-                                   action_dict=list(actions.values()))
-
-            observation, _, _, _ = env.step(actions)
-
-    env_renderer.close_window()
-
-
-@gin.configurable
-def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
-                   map_width, map_height, policy_folder_name, obs_builder,
-                   entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae,
-                   step_memory, min_dist):
-
-    render_training_result(
-        config={"n_agents": n_agents,
-                "hidden_sizes": hidden_sizes,  # Array containing the sizes of the network layers
-                "save_every": save_every,
-                "map_width": map_width,
-                "map_height": map_height,
-                'policy_folder_name': policy_folder_name,
-                "obs_builder": obs_builder,
-                "entropy_coeff": entropy_coeff,
-                "seed": seed,
-                "conv_model": conv_model,
-                "rail_generator": rail_generator,
-                "nr_extra": nr_extra,
-                "kl_coeff": kl_coeff,
-                "lambda_gae": lambda_gae,
-                "min_dist": min_dist,
-                "step_memory": step_memory
-                }
-    )
-
-
-if __name__ == '__main__':
-    gin.parse_config_file(os.path.join(__file_dirname__, 'experiment_configs', 'config_example', 'config.gin'))  # To Modify
-    run_experiment()
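The rendering script above turns the policy's `behaviour_logits` into a greedy action per agent via `np.argmax`. A tiny self-contained illustration of that selection step (the logits are made-up numbers):

```python
import numpy as np

# Made-up logits for two agents over the 5 RailEnv actions
# (0 = do nothing, 1 = left, 2 = forward, 3 = right, 4 = stop).
logits = np.array([[0.1, -0.3, 2.4, 0.0, -1.0],
                   [0.2, 1.7, -0.5, 0.3, 0.0]])

# Greedy selection, as in render_training_result.py (instead of sampling stochastically).
actions = {agent_id: int(np.argmax(agent_logits)) for agent_id, agent_logits in enumerate(logits)}
print(actions)  # {0: 2, 1: 1}
```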
diff --git a/RLLib_training/train_experiment.py b/RLLib_training/train_experiment.py
deleted file mode 100644
index 7435a8fed728ec363321ba7a2bcf04b186513559..0000000000000000000000000000000000000000
--- a/RLLib_training/train_experiment.py
+++ /dev/null
@@ -1,210 +0,0 @@
-import os
-
-import gin
-import gym
-from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
-
-# Import PPO trainer: we can replace these imports by any other trainer from RLLib.
-from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG
-from ray.rllib.agents.ppo.ppo import PPOTrainer as Trainer
-from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph as PolicyGraph
-from ray.rllib.models import ModelCatalog
-
-gin.external_configurable(DummyPredictorForRailEnv)
-gin.external_configurable(ShortestPathPredictorForRailEnv)
-
-import ray
-
-from ray.tune.logger import UnifiedLogger
-from ray.tune.logger import pretty_print
-import os
-
-from RailEnvRLLibWrapper import RailEnvRLLibWrapper
-import tempfile
-
-from ray import tune
-
-from ray.rllib.utils.seed import seed as set_seed
-from flatland.envs.observations import TreeObsForRailEnv
-
-gin.external_configurable(TreeObsForRailEnv)
-
-import numpy as np
-from custom_preprocessors import TreeObsPreprocessor
-
-ModelCatalog.register_custom_preprocessor("tree_obs_prep", TreeObsPreprocessor)
-ray.init()  # object_store_memory=150000000000, redis_max_memory=30000000000)
-
-__file_dirname__ = os.path.dirname(os.path.realpath(__file__))
-
-
-def on_episode_start(info):
-    episode = info['episode']
-    map_width = info['env'].envs[0].width
-    map_height = info['env'].envs[0].height
-    episode.horizon = 3*(map_width + map_height)
-
-
-def on_episode_end(info):
-    episode = info['episode']
-
-    # Calculation of a custom score metric: sum of all accumulated rewards, divided by the number of agents
-    # and the maximum number of time steps of the episode.
-    score = 0
-    for k, v in episode._agent_reward_history.items():
-        score += np.sum(v)
-    score /= (len(episode._agent_reward_history) * episode.horizon)
-    
-    # Calculation of the proportion of solved episodes before the maximum time step
-    done = 0
-    if len(episode._agent_reward_history[0]) <= episode.horizon-5:
-        done = 1
-
-    episode.custom_metrics["score"] = score
-    episode.custom_metrics["proportion_episode_solved"] = done
-
-
-def train(config, reporter):
-    print('Init Env')
-
-    set_seed(config['seed'], config['seed'], config['seed'])
-
-    # Given the depth of the tree observation and the number of features per node we get the following state_size
-    num_features_per_node = config['obs_builder'].observation_dim
-    tree_depth = 2
-    nr_nodes = 0
-    for i in range(tree_depth + 1):
-        nr_nodes += np.power(4, i)
-    obs_size = num_features_per_node * nr_nodes
-
-
-    # Environment parameters
-    env_config = {"width": config['map_width'],
-                  "height": config['map_height'],
-                  "rail_generator": config["rail_generator"],
-                  "nr_extra": config["nr_extra"],
-                  "number_of_agents": config['n_agents'],
-                  "seed": config['seed'],
-                  "obs_builder": config['obs_builder'],
-                  "min_dist": config['min_dist'],
-                  "step_memory": config["step_memory"]}
-
-    # Observation space and action space definitions
-    if isinstance(config["obs_builder"], TreeObsForRailEnv):
-        obs_space = gym.spaces.Tuple((gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(obs_size,)),) * 2)
-        preprocessor = "tree_obs_prep"
-    else:
-        raise ValueError("Undefined observation space") # Only TreeObservation implemented for now.
-
-    act_space = gym.spaces.Discrete(5)
-
-    # Dict with the different policies to train. In this case, all trains follow the same policy
-    policy_graphs = {
-        "ppo_policy": (PolicyGraph, obs_space, act_space, {})
-    }
-
-    # Function that maps an agent id to the name of its respective policy.
-    def policy_mapping_fn(agent_id):
-        return "ppo_policy"
-
-    # Trainer configuration
-    trainer_config = DEFAULT_CONFIG.copy()
-    trainer_config['model'] = {"fcnet_hiddens": config['hidden_sizes'], "custom_preprocessor": preprocessor,
-            "custom_options": {"step_memory": config["step_memory"], "obs_size": obs_size}}
-
-    trainer_config['multiagent'] = {"policy_graphs": policy_graphs,
-                                    "policy_mapping_fn": policy_mapping_fn,
-                                    "policies_to_train": list(policy_graphs.keys())}
-
-    # Maximum time steps for an episode is set to 3 * (map_width + map_height)
-    trainer_config["horizon"] = 3 * (config['map_width'] + config['map_height'])
-
-    # Parameters for calculation parallelization
-    trainer_config["num_workers"] = 0
-    trainer_config["num_cpus_per_worker"] = 8
-    trainer_config["num_gpus"] = 0.2
-    trainer_config["num_gpus_per_worker"] = 0.2
-    trainer_config["num_cpus_for_driver"] = 1
-    trainer_config["num_envs_per_worker"] = 1
-
-    # Parameters for PPO training
-    trainer_config['entropy_coeff'] = config['entropy_coeff']
-    trainer_config["env_config"] = env_config
-    trainer_config["batch_mode"] = "complete_episodes"
-    trainer_config['simple_optimizer'] = False
-    trainer_config['log_level'] = 'WARN'
-    trainer_config['num_sgd_iter'] = 10
-    trainer_config['clip_param'] = 0.2
-    trainer_config['kl_coeff'] = config['kl_coeff']
-    trainer_config['lambda'] = config['lambda_gae']
-    trainer_config['callbacks'] = {
-            "on_episode_start": tune.function(on_episode_start),
-            "on_episode_end": tune.function(on_episode_end)
-        }
-
-
-    def logger_creator(conf):
-        """Creates a Unified logger with a default logdir prefix."""
-        logdir = config['policy_folder_name'].format(**locals())
-        logdir = tempfile.mkdtemp(
-            prefix=logdir, dir=config['local_dir'])
-        return UnifiedLogger(conf, logdir, None)
-
-    logger = logger_creator
-
-    trainer = Trainer(env=RailEnvRLLibWrapper, config=trainer_config, logger_creator=logger)
-
-    for i in range(100000 + 2):
-        print("== Iteration", i, "==")
-
-        print(pretty_print(trainer.train()))
-
-        if i % config['save_every'] == 0:
-            checkpoint = trainer.save()
-            print("checkpoint saved at", checkpoint)
-
-        reporter(num_iterations_trained=trainer._iteration)
-
-
-@gin.configurable
-def run_experiment(name, num_iterations, n_agents, hidden_sizes, save_every,
-                   map_width, map_height, policy_folder_name, local_dir, obs_builder,
-                   entropy_coeff, seed, conv_model, rail_generator, nr_extra, kl_coeff, lambda_gae,
-                   step_memory, min_dist):
-    tune.run(
-        train,
-        name=name,
-        stop={"num_iterations_trained": num_iterations},
-        config={"n_agents": n_agents,
-                "hidden_sizes": hidden_sizes,  # Array containing the sizes of the network layers
-                "save_every": save_every,
-                "map_width": map_width,
-                "map_height": map_height,
-                "local_dir": local_dir,
-                'policy_folder_name': policy_folder_name,
-                "obs_builder": obs_builder,
-                "entropy_coeff": entropy_coeff,
-                "seed": seed,
-                "conv_model": conv_model,
-                "rail_generator": rail_generator,
-                "nr_extra": nr_extra,
-                "kl_coeff": kl_coeff,
-                "lambda_gae": lambda_gae,
-                "min_dist": min_dist,
-                "step_memory": step_memory  # If equal to two, the current observation plus
-                                            # the observation of the last time step will be given as input to the model.
-                },
-        resources_per_trial={
-            "cpu": 8,
-            "gpu": 0.2
-        },
-        verbose=2,
-        local_dir=local_dir
-    )
-
-
-if __name__ == '__main__':
-    folder_name = 'config_example'  # To Modify
-    gin.parse_config_file(os.path.join(__file_dirname__, 'experiment_configs', folder_name, 'config.gin'))
-    dir = os.path.join(__file_dirname__, 'experiment_configs', folder_name)
-    run_experiment(local_dir=dir)
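As a quick standalone check of the observation-size arithmetic in `train()` above (11 is the `observation_dim` of `TreeObsForRailEnv`, and the tree depth is 2 as in the script):

```python
import numpy as np

num_features_per_node = 11  # TreeObsForRailEnv.observation_dim
tree_depth = 2

# Nodes in a quaternary tree of depth 2: 4^0 + 4^1 + 4^2 = 21
nr_nodes = sum(np.power(4, i) for i in range(tree_depth + 1))
obs_size = num_features_per_node * nr_nodes
print(nr_nodes, obs_size)  # 21 231
```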
diff --git a/score_test.py b/score_test.py
index ff4a94c5e1b82c90eec0c5bf129bad496046e595..f52a1e010aefa085d42e46155eb0618e27f05702 100644
--- a/score_test.py
+++ b/score_test.py
@@ -20,9 +20,6 @@ nr_trials_per_test = 100
 test_results = []
 test_times = []
 test_dones = []
-# Load agent
-# agent = Agent(state_size, action_size, "FC", 0)
-# agent.qnetwork_local.load_state_dict(torch.load('./torch_training/Nets/avoid_checkpoint1700.pth'))
 agent = RandomAgent(state_size, action_size)
 start_time_scoring = time.time()
 test_idx = 0
diff --git a/scoring/score_test.py b/scoring/score_test.py
index 5665d446047d2c2dd0f7504de6de391ade98f1b3..4baee4a176e80a3c947cbd9402a78b76c735f839 100644
--- a/scoring/score_test.py
+++ b/scoring/score_test.py
@@ -28,8 +28,8 @@ test_dones = []
 sequential_agent_test = False
 
 # Load your agent
-agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('../torch_training/Nets/avoid_checkpoint60000.pth'))
+agent = Agent(state_size, action_size)
+agent.qnetwork_local.load_state_dict(torch.load('../torch_training/Nets/avoid_checkpoint500.pth'))
 
 # Load the necessary Observation Builder and Predictor
 predictor = ShortestPathPredictorForRailEnv()
diff --git a/setup.py b/setup.py
index 2b9b731ea02a0c9bdbea7602ea1dfa2ad6e194e2..5bc77c5188d799da8898f959c180c78f9c1496f6 100644
--- a/setup.py
+++ b/setup.py
@@ -2,8 +2,7 @@ from setuptools import setup, find_packages
 
 install_reqs = []
 dependency_links = []
-# TODO: include requirements_RLLib_training.txt
-requirements_paths = ['requirements_torch_training.txt']  # , 'requirements_RLLib_training.txt']
+requirements_paths = ['requirements_torch_training.txt']
 for requirements_path in requirements_paths:
     with open(requirements_path, 'r') as f:
         install_reqs += [
diff --git a/torch_training/Nets/avoid_checkpoint15000.pth b/torch_training/Nets/avoid_checkpoint15000.pth
deleted file mode 100644
index e1daf228b7f1f6b108329715c3cdbd67805e28ae..0000000000000000000000000000000000000000
Binary files a/torch_training/Nets/avoid_checkpoint15000.pth and /dev/null differ
diff --git a/torch_training/Nets/avoid_checkpoint30000.pth b/torch_training/Nets/avoid_checkpoint30000.pth
deleted file mode 100644
index 0e2c1b28c1655bc16c9339066b8d105282f14418..0000000000000000000000000000000000000000
Binary files a/torch_training/Nets/avoid_checkpoint30000.pth and /dev/null differ
diff --git a/torch_training/Nets/avoid_checkpoint60000.pth b/torch_training/Nets/avoid_checkpoint60000.pth
deleted file mode 100644
index b4fef60542f50419353047721fae31f5382e7bd4..0000000000000000000000000000000000000000
Binary files a/torch_training/Nets/avoid_checkpoint60000.pth and /dev/null differ
diff --git a/torch_training/dueling_double_dqn.py b/torch_training/dueling_double_dqn.py
index cf2f7d512b99aafb9fe0477bf048441efa0bff9e..b7bb4bcdc7c72fa09b352fdf5cf99258f8f9ad0c 100644
--- a/torch_training/dueling_double_dqn.py
+++ b/torch_training/dueling_double_dqn.py
@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 import torch.optim as optim
 
-from torch_training.model import QNetwork, QNetwork2
+from torch_training.model import QNetwork
 
 BUFFER_SIZE = int(1e5)  # replay buffer size
 BATCH_SIZE = 512  # minibatch size
@@ -16,43 +16,33 @@ GAMMA = 0.99  # discount factor 0.99
 TAU = 1e-3  # for soft update of target parameters
 LR = 0.5e-4  # learning rate 0.5e-4 works
 UPDATE_EVERY = 10  # how often to update the network
-double_dqn = True  # If using double dqn algorithm
-input_channels = 5  # Number of Input channels
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-#device = torch.device("cpu")
 print(device)
 
 
 class Agent:
     """Interacts with and learns from the environment."""
 
-    def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5):
+    def __init__(self, state_size, action_size, double_dqn=True):
         """Initialize an Agent object.
 
         Params
         ======
             state_size (int): dimension of each state
             action_size (int): dimension of each action
-            seed (int): random seed
         """
         self.state_size = state_size
         self.action_size = action_size
-        self.seed = random.seed(seed)
-        self.version = net_type
         self.double_dqn = double_dqn
         # Q-Network
-        if self.version == "Conv":
-            self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device)
-            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
-        else:
-            self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
-            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+        self.qnetwork_local = QNetwork(state_size, action_size).to(device)
+        self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
 
         self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)
 
         # Replay memory
-        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
+        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE)
         # Initialize time step (for updating every UPDATE_EVERY steps)
         self.t_step = 0
 
@@ -152,7 +142,7 @@ class Agent:
 class ReplayBuffer:
     """Fixed-size buffer to store experience tuples."""
 
-    def __init__(self, action_size, buffer_size, batch_size, seed):
+    def __init__(self, action_size, buffer_size, batch_size):
         """Initialize a ReplayBuffer object.
 
         Params
@@ -160,13 +150,11 @@ class ReplayBuffer:
             action_size (int): dimension of each action
             buffer_size (int): maximum size of buffer
             batch_size (int): size of each training batch
-            seed (int): random seed
         """
         self.action_size = action_size
         self.memory = deque(maxlen=buffer_size)
         self.batch_size = batch_size
         self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
-        self.seed = random.seed(seed)
 
     def add(self, state, action, reward, next_state, done):
         """Add a new experience to memory."""
@@ -188,7 +176,7 @@ class ReplayBuffer:
         dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
             .float().to(device)
 
-        return (states, actions, rewards, next_states, dones)
+        return states, actions, rewards, next_states, dones
 
     def __len__(self):
         """Return the current size of internal memory."""
diff --git a/torch_training/model.py b/torch_training/model.py
index 7a5b3d613342a4fba8e2c8f1f45df21381e12684..9a5afccfda50e63271b1f5d8ed5a2d74b5e169e7 100644
--- a/torch_training/model.py
+++ b/torch_training/model.py
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 
 
 class QNetwork(nn.Module):
-    def __init__(self, state_size, action_size, seed, hidsize1=128, hidsize2=128):
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
         super(QNetwork, self).__init__()
 
         self.fc1_val = nn.Linear(state_size, hidsize1)
@@ -24,38 +24,3 @@ class QNetwork(nn.Module):
         adv = F.relu(self.fc2_adv(adv))
         adv = self.fc3_adv(adv)
         return val + adv - adv.mean()
-
-
-class QNetwork2(nn.Module):
-    def __init__(self, state_size, action_size, seed, input_channels, hidsize1=128, hidsize2=64):
-        super(QNetwork2, self).__init__()
-        self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1)
-        self.bn1 = nn.BatchNorm2d(16)
-        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=3)
-        self.bn2 = nn.BatchNorm2d(32)
-        self.conv3 = nn.Conv2d(32, 64, kernel_size=5, stride=3)
-        self.bn3 = nn.BatchNorm2d(64)
-
-        self.fc1_val = nn.Linear(6400, hidsize1)
-        self.fc2_val = nn.Linear(hidsize1, hidsize2)
-        self.fc3_val = nn.Linear(hidsize2, 1)
-
-        self.fc1_adv = nn.Linear(6400, hidsize1)
-        self.fc2_adv = nn.Linear(hidsize1, hidsize2)
-        self.fc3_adv = nn.Linear(hidsize2, action_size)
-
-    def forward(self, x):
-        x = F.relu(self.conv1(x))
-        x = F.relu(self.conv2(x))
-        x = F.relu(self.conv3(x))
-
-        # value function approximation
-        val = F.relu(self.fc1_val(x.view(x.size(0), -1)))
-        val = F.relu(self.fc2_val(val))
-        val = self.fc3_val(val)
-
-        # advantage calculation
-        adv = F.relu(self.fc1_adv(x.view(x.size(0), -1)))
-        adv = F.relu(self.fc2_adv(adv))
-        adv = self.fc3_adv(adv)
-        return val + adv - adv.mean()
diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py
index 2fbbe61f7fd61eed4d8589695a06da238091717a..580886b1db73ba34d539e14968deea384b5b98be 100644
--- a/torch_training/multi_agent_inference.py
+++ b/torch_training/multi_agent_inference.py
@@ -4,8 +4,8 @@ from collections import deque
 import numpy as np
 import torch
 from importlib_resources import path
-from observation_builders.observations import TreeObsForRailEnv
-from predictors.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 
 import torch_training.Nets
 from flatland.envs.rail_env import RailEnv
@@ -87,8 +87,8 @@ dones_list = []
 action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
-agent = Agent(state_size, action_size, "FC", 0)
-with path(torch_training.Nets, "avoid_checkpoint100.pth") as file_in:
+agent = Agent(state_size, action_size)
+with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 222430dd0af0f239ddd99d127b21349a13a2e892..fe9e27969f129313c19f54fd36738188ba376082 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -121,11 +121,11 @@ def main(argv):
     observation_radius = 10
 
     # Initialize the agent
-    agent = Agent(state_size, action_size, "FC", 0)
+    agent = Agent(state_size, action_size)
 
     # Here you can pre-load an agent
     if False:
-        with path(torch_training.Nets, "avoid_checkpoint2400.pth") as file_in:
+        with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in:
             agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
     # Do training over n_episodes
diff --git a/torch_training/multi_agent_two_time_step_training.py b/torch_training/multi_agent_two_time_step_training.py
index 08cd84c379fe54cd4d6b71140a96623ebe2a8cbf..57f4a619d3bd3ae797fa7fe9aff7c799064bc6f6 100644
--- a/torch_training/multi_agent_two_time_step_training.py
+++ b/torch_training/multi_agent_two_time_step_training.py
@@ -7,16 +7,16 @@ from collections import deque
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-from importlib_resources import path
-
-# Import Torch and utility functions to normalize observation
-import torch_training.Nets
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 # Import Flatland/ Observations and Predictors
 from flatland.envs.schedule_generators import complex_schedule_generator
+from importlib_resources import path
+
+# Import Torch and utility functions to normalize observation
+import torch_training.Nets
 from torch_training.dueling_double_dqn import Agent
 from utils.observation_utils import norm_obs_clip, split_tree
 
@@ -41,12 +41,12 @@ def main(argv):
     n_agents = np.random.randint(3, 8)
     n_goals = n_agents + np.random.randint(0, 3)
     min_dist = int(0.75 * min(x_dim, y_dim))
-    tree_depth = 3
+    tree_depth = 2
     print("main2")
+    demo = False
 
     # Get an observation builder and predictor
-    predictor = ShortestPathPredictorForRailEnv()
-    observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=predictor())
+    observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
 
     env = RailEnv(width=x_dim,
                   height=y_dim,
@@ -60,7 +60,6 @@ def main(argv):
 
     handle = env.get_agent_handles()
     features_per_node = env.obs_builder.observation_dim
-    tree_depth = 2
     nr_nodes = 0
     for i in range(tree_depth + 1):
         nr_nodes += np.power(4, i)
@@ -87,11 +86,11 @@ def main(argv):
     agent_obs = [None] * env.get_num_agents()
     agent_next_obs = [None] * env.get_num_agents()
     # Initialize the agent
-    agent = Agent(state_size, action_size, "FC", 0)
+    agent = Agent(state_size, action_size)
 
     # Here you can pre-load an agent
     if False:
-        with path(torch_training.Nets, "avoid_checkpoint30000.pth") as file_in:
+        with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in:
             agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
     # Do training over n_episodes
@@ -132,6 +131,7 @@ def main(argv):
         # Build agent specific observations
         for a in range(env.get_num_agents()):
             data, distance, agent_data = split_tree(tree=np.array(obs[a]),
+                                                    num_features_per_node=features_per_node,
                                                     current_depth=0)
             data = norm_obs_clip(data)
             distance = norm_obs_clip(distance)
@@ -164,6 +164,7 @@ def main(argv):
             next_obs, all_rewards, done, _ = env.step(action_dict)
             for a in range(env.get_num_agents()):
                 data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                        num_features_per_node=features_per_node,
                                                         current_depth=0)
                 data = norm_obs_clip(data)
                 distance = norm_obs_clip(distance)
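`split_tree` is now called with an explicit `num_features_per_node`. A hedged sketch of the per-agent preprocessing used in these training scripts, assuming `utils.observation_utils` is importable; the `np.clip` on the agent data mirrors how the scripts treat the third part of the split observation:

```python
import numpy as np

from utils.observation_utils import norm_obs_clip, split_tree


def preprocess(tree_obs, num_features_per_node):
    """Split a raw tree observation and normalize its parts, as done per agent in the loop above."""
    data, distance, agent_data = split_tree(tree=np.array(tree_obs),
                                            num_features_per_node=num_features_per_node,
                                            current_depth=0)
    data = norm_obs_clip(data)
    distance = norm_obs_clip(distance)
    agent_data = np.clip(agent_data, -1, 1)
    return np.concatenate((data, distance, agent_data))


# Inside the training loop (obs comes from env.reset() / env.step()):
# agent_obs[a] = preprocess(obs[a], features_per_node)
```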
diff --git a/torch_training/observation_builders/observations.py b/torch_training/observation_builders/observations.py
deleted file mode 100644
index 66c38301d034151aa24e306d2e432d54d2802018..0000000000000000000000000000000000000000
--- a/torch_training/observation_builders/observations.py
+++ /dev/null
@@ -1,865 +0,0 @@
-"""
-Collection of environment-specific ObservationBuilder.
-"""
-import pprint
-from collections import deque
-
-import numpy as np
-
-from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.core.grid.grid_utils import coordinate_to_position
-
-
-class TreeObsForRailEnv(ObservationBuilder):
-    """
-    TreeObsForRailEnv object.
-
-    This object returns observation vectors for agents in the RailEnv environment.
-    The information is local to each agent and exploits the graph structure of the rail
-    network to simplify the representation of the state of the environment for each agent.
-
-    For details about the features in the tree observation see the get() function.
-    """
-
-    def __init__(self, max_depth, predictor=None):
-        super().__init__()
-        self.max_depth = max_depth
-        self.observation_dim = 11
-        # Compute the size of the returned observation vector
-        size = 0
-        pow4 = 1
-        for i in range(self.max_depth + 1):
-            size += pow4
-            pow4 *= 4
-        self.observation_space = [size * self.observation_dim]
-        self.location_has_agent = {}
-        self.location_has_agent_direction = {}
-        self.predictor = predictor
-        self.agents_previous_reset = None
-        self.tree_explored_actions = [1, 2, 3, 0]
-        self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
-        self.distance_map = None
-        self.distance_map_computed = False
-
-    def reset(self):
-        agents = self.env.agents
-        nb_agents = len(agents)
-        compute_distance_map = True
-        if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
-            compute_distance_map = False
-            for i in range(nb_agents):
-                if agents[i].target != self.agents_previous_reset[i].target:
-                    compute_distance_map = True
-        # Don't compute the distance map if it was loaded
-        if self.agents_previous_reset is None and self.distance_map is not None:
-            self.location_has_target = {tuple(agent.target): 1 for agent in agents}
-            compute_distance_map = False
-
-        if compute_distance_map:
-            self._compute_distance_map()
-
-        self.agents_previous_reset = agents
-
-    def _compute_distance_map(self):
-        agents = self.env.agents
-        # For testing only --> to assert whether a distance map needs to be recomputed.
-        self.distance_map_computed = True
-        nb_agents = len(agents)
-        self.distance_map = np.inf * np.ones(shape=(nb_agents,
-                                                    self.env.height,
-                                                    self.env.width,
-                                                    4))
-        self.max_dist = np.zeros(nb_agents)
-        self.max_dist = [self._distance_map_walker(agent.target, i) for i, agent in enumerate(agents)]
-        # Update local lookup table for all agents' target locations
-        self.location_has_target = {tuple(agent.target): 1 for agent in agents}
-
-    def _distance_map_walker(self, position, target_nr):
-        """
-        Utility function to compute distance maps from each cell in the rail network (and each possible
-        orientation within it) to each agent's target cell.
-        """
-        # Returns max distance to target, from the farthest away node, while filling in distance_map
-        self.distance_map[target_nr, position[0], position[1], :] = 0
-
-        # Fill in the (up to) 4 neighboring nodes
-        # direction is the direction of movement, meaning that at least a possible orientation of an agent
-        # in cell (row,col) allows a movement in direction `direction`
-        nodes_queue = deque(self._get_and_update_neighbors(position, target_nr, 0, enforce_target_direction=-1))
-
-        # BFS from target `position` to all the reachable nodes in the grid
-        # Stop the search if the target position is re-visited, in any direction
-        visited = {(position[0], position[1], 0), (position[0], position[1], 1), (position[0], position[1], 2),
-                   (position[0], position[1], 3)}
-
-        max_distance = 0
-
-        while nodes_queue:
-            node = nodes_queue.popleft()
-
-            node_id = (node[0], node[1], node[2])
-
-            if node_id not in visited:
-                visited.add(node_id)
-
-                # From the list of possible neighbors that have at least a path to the current node, only keep those
-                # whose new orientation in the current cell would allow a transition to direction node[2]
-                valid_neighbors = self._get_and_update_neighbors((node[0], node[1]), target_nr, node[3], node[2])
-
-                for n in valid_neighbors:
-                    nodes_queue.append(n)
-
-                if len(valid_neighbors) > 0:
-                    max_distance = max(max_distance, node[3] + 1)
-
-        return max_distance
-
-    def _get_and_update_neighbors(self, position, target_nr, current_distance, enforce_target_direction=-1):
-        """
-        Utility function used by _distance_map_walker to perform a BFS walk over the rail, filling in the
-        minimum distances from each target cell.
-        """
-        neighbors = []
-
-        possible_directions = [0, 1, 2, 3]
-        if enforce_target_direction >= 0:
-            # The agent must land into the current cell with orientation `enforce_target_direction`.
-            # This is only possible if the agent has arrived from the cell in the opposite direction!
-            possible_directions = [(enforce_target_direction + 2) % 4]
-
-        for neigh_direction in possible_directions:
-            new_cell = get_new_position(position, neigh_direction)
-
-            if new_cell[0] >= 0 and new_cell[0] < self.env.height and new_cell[1] >= 0 and new_cell[1] < self.env.width:
-
-                desired_movement_from_new_cell = (neigh_direction + 2) % 4
-
-                # Check all possible transitions in new_cell
-                for agent_orientation in range(4):
-                    # Is a transition along movement `desired_movement_from_new_cell` to the current cell possible?
-                    is_valid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
-                                                            desired_movement_from_new_cell)
-
-                    if is_valid:
-                        """
-                        # TODO: check that it works with deadends! -- still bugged!
-                        movement = desired_movement_from_new_cell
-                        if isNextCellDeadEnd:
-                            movement = (desired_movement_from_new_cell+2) % 4
-                        """
-                        new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
-                                           current_distance + 1)
-                        neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
-                        self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
-
-        return neighbors
-
-    def get_many(self, handles=None):
-        """
-        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
-        in the `handles` list.
-        """
-
-        if handles is None:
-            handles = []
-        if self.predictor:
-            self.max_prediction_depth = 0
-            self.predicted_pos = {}
-            self.predicted_dir = {}
-            self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map})
-            if self.predictions:
-
-                for t in range(len(self.predictions[0])):
-                    pos_list = []
-                    dir_list = []
-                    for a in handles:
-                        pos_list.append(self.predictions[a][t][1:3])
-                        dir_list.append(self.predictions[a][t][3])
-                    self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
-                    self.predicted_dir.update({t: dir_list})
-                self.max_prediction_depth = len(self.predicted_pos)
-        observations = {}
-        for h in handles:
-            observations[h] = self.get(h)
-        return observations
-
-    def get(self, handle):
-        """
-        Computes the current observation for agent `handle` in env
-
-        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
-        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
-        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
-        the transitions. The order is:
-            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
-
-        Each branch data is organized as:
-            [root node information] +
-            [recursive branch data from 'left'] +
-            [... from 'forward'] +
-            [... from 'right'] +
-            [... from 'back']
-
-        Each node's information is composed of 11 features:
-
-        #1: if the own target lies on the explored branch, the current distance from the agent in number of cells is stored.
-
-        #2: if another agent's target is detected, the distance in number of cells from the agent's current location
-            is stored
-
-        #3: if another agent is detected, the distance in number of cells from the current agent position is stored.
-
-        #4: possible conflict detected
-            tot_dist = another agent is predicted to pass along this cell at the same time as the agent; we store the
-            distance in number of cells from the current agent position
-
-            0 = no other agent reserves the same cell at a similar time
-
-        #5: if a switch that is not usable by the agent is detected, we store the distance.
-
-        #6: this feature stores the distance in number of cells to the next branching point (current node)
-
-        #7: minimum distance from node to the agent's target given the direction of the agent if this path is chosen
-
-        #8: agent in the same direction
-            n = number of agents present same direction
-                (possible future use: number of other agents in the same direction in this branch)
-            0 = no agent present same direction
-
-        #9: agent in the opposite direction
-            n = number of agents present other direction than myself (so conflict)
-                (possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
-            0 = no agent present other direction than myself
-
-        #10: malfunctioning/blocking agents
-            n = number of time steps the observed agent remains blocked
-
-        #11: slowest observed speed of an agent in the same direction
-            1 if no agent is observed
-
-            min_fractional speed otherwise
-
-
-
-
-
-        Missing/padding nodes are filled in with -inf (truncated).
-        Missing values in present node are filled in with +inf (truncated).
-
-
-        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target, own malfunction, own speed]
-        In case the target node is reached, the values are [0, 0, 0, 0, 0].
-        """
-
-        # Update local lookup table for all agents' positions
-        self.location_has_agent = {tuple(agent.position): 1 for agent in self.env.agents}
-        self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
-        self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents}
-        self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in
-                                               self.env.agents}
-
-        if handle > len(self.env.agents):
-            print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
-        agent = self.env.agents[handle]  # TODO: handle being treated as index
-        possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
-        num_transitions = np.count_nonzero(possible_transitions)
-
-        # Root node - current position
-        # Here information about the agent itself is stored
-        observation = [0, 0, 0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0,
-                       agent.malfunction_data['malfunction'], agent.speed_data['speed']]
-
-        visited = set()
-
-        # Start from the current orientation, and see which transitions are available;
-        # organize them as [left, forward, right, back], relative to the current orientation
-        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
-        orientation = agent.direction
-
-        if num_transitions == 1:
-            orientation = np.argmax(possible_transitions)
-
-        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
-            if possible_transitions[branch_direction]:
-                new_cell = get_new_position(agent.position, branch_direction)
-                branch_observation, branch_visited = \
-                    self._explore_branch(handle, new_cell, branch_direction, 1, 1)
-                observation = observation + branch_observation
-                visited = visited.union(branch_visited)
-            else:
-                # add cells filled with infinity if no transition is possible
-                observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
-        self.env.dev_obs_dict[handle] = visited
-
-        return observation
-
-    def _num_cells_to_fill_in(self, remaining_depth):
-        """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim."""
-        num_observations = 0
-        pow4 = 1
-        for i in range(remaining_depth):
-            num_observations += pow4
-            pow4 *= 4
-        return num_observations * self.observation_dim
-
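
For reference, the total length of a fully padded tree observation follows directly from this helper; below is a minimal standalone sketch (illustrative only, not part of the deleted module), using the 11 features per node documented in `get()`:

```python
def tree_obs_length(max_depth, observation_dim=11):
    """Total length of a fully padded tree observation:
    sum_{i=0..max_depth} 4**i nodes, observation_dim features per node."""
    num_nodes = sum(4 ** i for i in range(max_depth + 1))
    return num_nodes * observation_dim

# For example, max_depth=2 gives (1 + 4 + 16) * 11 = 231 values per agent.
assert tree_obs_length(2) == 231
```
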
-    def _explore_branch(self, handle, position, direction, tot_dist, depth):
-        """
-        Utility function to compute tree-based observations.
-        We walk along the branch and collect the information documented in the get() function.
-        If there is a branching point a new node is created and each possible branch is explored.
-        """
-
-        # [Recursive branch opened]
-        if depth >= self.max_depth + 1:
-            return [], []
-
-        # Continue along direction until next switch or
-        # until no transitions are possible along the current direction (i.e., dead-ends)
-        # We treat dead-ends as nodes, instead of going back, to avoid loops
-        exploring = True
-        last_is_switch = False
-        last_is_dead_end = False
-        last_is_terminal = False  # wrong cell OR cycle; either way, we don't want the agent to land here
-        last_is_target = False
-
-        visited = set()
-        agent = self.env.agents[handle]
-        time_per_cell = np.reciprocal(agent.speed_data["speed"])
-        own_target_encountered = np.inf
-        other_agent_encountered = np.inf
-        other_target_encountered = np.inf
-        potential_conflict = np.inf
-        unusable_switch = np.inf
-        other_agent_same_direction = 0
-        other_agent_opposite_direction = 0
-        malfunctioning_agent = 0.
-        min_fractional_speed = 1.
-        num_steps = 1
-        while exploring:
-            # #############################
-            # #############################
-            # Modify here to compute any useful data required to build the end node's features. This code is called
-            # for each cell visited between the previous branching node and the next switch / target / dead-end.
-            if position in self.location_has_agent:
-                if tot_dist < other_agent_encountered:
-                    other_agent_encountered = tot_dist
-
-                # Check if any of the observed agents is malfunctioning, store agent with longest duration left
-                if self.location_has_agent_malfunction[position] > malfunctioning_agent:
-                    malfunctioning_agent = self.location_has_agent_malfunction[position]
-
-                if self.location_has_agent_direction[position] == direction:
-                    # Accumulate the number of agents on the branch heading in the same direction
-                    other_agent_same_direction += 1
-
-                    # Check fractional speed of agents
-                    current_fractional_speed = self.location_has_agent_speed[position]
-                    if current_fractional_speed < min_fractional_speed:
-                        min_fractional_speed = current_fractional_speed
-
-                if self.location_has_agent_direction[position] != direction:
-                    # Accumulate the number of agents on the branch heading in the other direction
-                    other_agent_opposite_direction += 1
-
-            # Check number of possible transitions for agent and total number of transitions in cell (type)
-            cell_transitions = self.env.rail.get_transitions(*position, direction)
-            transition_bit = bin(self.env.rail.get_full_transitions(*position))
-            total_transitions = transition_bit.count("1")
-            crossing_found = False
-            if int(transition_bit, 2) == int('1000010000100001', 2):
-                crossing_found = True
-
-            # Register possible future conflict
-            predicted_time = int(tot_dist * time_per_cell)
-            if self.predictor and predicted_time < self.max_prediction_depth:
-                int_position = coordinate_to_position(self.env.width, [position])
-                if tot_dist < self.max_prediction_depth:
-
-                    pre_step = max(0, predicted_time - 1)
-                    post_step = min(self.max_prediction_depth - 1, predicted_time + 1)
-
-                    # Look for conflicting paths at the predicted time step
-                    if int_position in np.delete(self.predicted_pos[predicted_time], handle, 0):
-                        conflicting_agent = np.where(self.predicted_pos[predicted_time] == int_position)
-                        for ca in conflicting_agent[0]:
-                            if direction != self.predicted_dir[predicted_time][ca] and cell_transitions[
-                                self._reverse_dir(
-                                    self.predicted_dir[predicted_time][ca])] == 1 and tot_dist < potential_conflict:
-                                potential_conflict = tot_dist
-                            if self.env.dones[ca] and tot_dist < potential_conflict:
-                                potential_conflict = tot_dist
-
-                    # Look for conflicting paths one time step earlier (pre_step)
-                    elif int_position in np.delete(self.predicted_pos[pre_step], handle, 0):
-                        conflicting_agent = np.where(self.predicted_pos[pre_step] == int_position)
-                        for ca in conflicting_agent[0]:
-                            if direction != self.predicted_dir[pre_step][ca] \
-                                    and cell_transitions[self._reverse_dir(self.predicted_dir[pre_step][ca])] == 1 \
-                                    and tot_dist < potential_conflict:  # noqa: E125
-                                potential_conflict = tot_dist
-                            if self.env.dones[ca] and tot_dist < potential_conflict:
-                                potential_conflict = tot_dist
-
-                    # Look for conflicting paths one time step later (post_step)
-                    elif int_position in np.delete(self.predicted_pos[post_step], handle, 0):
-                        conflicting_agent = np.where(self.predicted_pos[post_step] == int_position)
-                        for ca in conflicting_agent[0]:
-                            if direction != self.predicted_dir[post_step][ca] and cell_transitions[self._reverse_dir(
-                                    self.predicted_dir[post_step][ca])] == 1 \
-                                    and tot_dist < potential_conflict:  # noqa: E125
-                                potential_conflict = tot_dist
-                            if self.env.dones[ca] and tot_dist < potential_conflict:
-                                potential_conflict = tot_dist
-
-            if position in self.location_has_target and position != agent.target:
-                if tot_dist < other_target_encountered:
-                    other_target_encountered = tot_dist
-
-            if position == agent.target and tot_dist < own_target_encountered:
-                own_target_encountered = tot_dist
-
-            # #############################
-            # #############################
-            if (position[0], position[1], direction) in visited:
-                last_is_terminal = True
-                break
-            visited.add((position[0], position[1], direction))
-
-            # If the target node is encountered, pick that as node. Also, no further branching is possible.
-            if np.array_equal(position, self.env.agents[handle].target):
-                last_is_target = True
-                break
-
-            # Check if crossing is found --> Not an unusable switch
-            if crossing_found:
-                # Treat the crossing as a straight rail cell
-                total_transitions = 2
-            num_transitions = np.count_nonzero(cell_transitions)
-
-            exploring = False
-
-            # Detect Switches that can only be used by other agents.
-            if total_transitions > 2 > num_transitions and tot_dist < unusable_switch:
-                unusable_switch = tot_dist
-
-            if num_transitions == 1:
-                # Check if dead-end, or if we can go forward along direction
-                nbits = total_transitions
-                if nbits == 1:
-                    # Dead-end!
-                    last_is_dead_end = True
-
-                if not last_is_dead_end:
-                    # Keep walking through the tree along `direction`
-                    exploring = True
-                    # convert one-hot encoding to 0,1,2,3
-                    direction = np.argmax(cell_transitions)
-                    position = get_new_position(position, direction)
-                    num_steps += 1
-                    tot_dist += 1
-            elif num_transitions > 0:
-                # Switch detected
-                last_is_switch = True
-                break
-
-            elif num_transitions == 0:
-                # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
-                print("WRONG CELL TYPE detected in tree-search (0 transitions possible) at cell", position[0],
-                      position[1], direction)
-                last_is_terminal = True
-                break
-
-        # `position` is either a terminal node or a switch
-
-        # #############################
-        # #############################
-        # Modify here to append new / different features for each visited cell!
-
-        if last_is_target:
-            observation = [own_target_encountered,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           potential_conflict,
-                           unusable_switch,
-                           tot_dist,
-                           0,
-                           other_agent_same_direction,
-                           other_agent_opposite_direction,
-                           malfunctioning_agent,
-                           min_fractional_speed
-                           ]
-
-        elif last_is_terminal:
-            observation = [own_target_encountered,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           potential_conflict,
-                           unusable_switch,
-                           np.inf,
-                           self.distance_map[handle, position[0], position[1], direction],
-                           other_agent_same_direction,
-                           other_agent_opposite_direction,
-                           malfunctioning_agent,
-                           min_fractional_speed
-                           ]
-
-        else:
-            observation = [own_target_encountered,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           potential_conflict,
-                           unusable_switch,
-                           tot_dist,
-                           self.distance_map[handle, position[0], position[1], direction],
-                           other_agent_same_direction,
-                           other_agent_opposite_direction,
-                           malfunctioning_agent,
-                           min_fractional_speed
-                           ]
-        # #############################
-        # #############################
-        # Start from the current orientation, and see which transitions are available;
-        # organize them as [left, forward, right, back], relative to the current orientation
-        # Get the possible transitions
-        possible_transitions = self.env.rail.get_transitions(*position, direction)
-        for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
-            if last_is_dead_end and self.env.rail.get_transition((*position, direction),
-                                                                 (branch_direction + 2) % 4):
-                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
-                # it back
-                new_cell = get_new_position(position, (branch_direction + 2) % 4)
-                branch_observation, branch_visited = self._explore_branch(handle,
-                                                                          new_cell,
-                                                                          (branch_direction + 2) % 4,
-                                                                          tot_dist + 1,
-                                                                          depth + 1)
-                observation = observation + branch_observation
-                if len(branch_visited) != 0:
-                    visited = visited.union(branch_visited)
-            elif last_is_switch and possible_transitions[branch_direction]:
-                new_cell = get_new_position(position, branch_direction)
-                branch_observation, branch_visited = self._explore_branch(handle,
-                                                                          new_cell,
-                                                                          branch_direction,
-                                                                          tot_dist + 1,
-                                                                          depth + 1)
-                observation = observation + branch_observation
-                if len(branch_visited) != 0:
-                    visited = visited.union(branch_visited)
-            else:
-                # no exploring possible, add cells filled with -infinity
-                observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth)
-
-        return observation, visited
-
-    def util_print_obs_subtree(self, tree):
-        """
-        Utility function to pretty-print tree observations returned by this object.
-        """
-        pp = pprint.PrettyPrinter(indent=4)
-        pp.pprint(self.unfold_observation_tree(tree))
-
-    def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
-        """
-        Utility function to unfold a flat tree observation into a nested dict (used for pretty-printing).
-        """
-        if len(tree) < self.observation_dim:
-            return
-
-        depth = 0
-        tmp = len(tree) / self.observation_dim - 1
-        pow4 = 4
-        while tmp > 0:
-            tmp -= pow4
-            depth += 1
-            pow4 *= 4
-
-        unfolded = {}
-        unfolded[''] = tree[0:self.observation_dim]
-        child_size = (len(tree) - self.observation_dim) // 4
-        for child in range(4):
-            child_tree = tree[(self.observation_dim + child * child_size):
-                              (self.observation_dim + (child + 1) * child_size)]
-            observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1)
-            if observation_tree is not None:
-                if actions_for_display:
-                    label = self.tree_explorted_actions_char[child]
-                else:
-                    label = self.tree_explored_actions[child]
-                unfolded[label] = observation_tree
-        return unfolded
-
-    def _set_env(self, env):
-        self.env = env
-        if self.predictor:
-            self.predictor._set_env(self.env)
-
-    def _reverse_dir(self, direction):
-        return int((direction + 2) % 4)
-
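
The potential-conflict feature above is the densest part of `_explore_branch`: a conflict is registered when another agent is predicted to occupy the same cell at roughly the same time step. A stripped-down standalone sketch of that test (illustrative only; the names and the simplification are assumptions, and the real code additionally checks the other agent's heading against the cell's transitions via `_reverse_dir`):

```python
import numpy as np

def is_potential_conflict(handle, predicted_pos, cell_id, predicted_time, max_prediction_depth):
    """Simplified potential-conflict test: another agent is predicted to
    occupy `cell_id` at predicted_time - 1, predicted_time or predicted_time + 1.

    predicted_pos maps a time step to a 1-D array with one integer cell id
    per agent (the observing agent included at index `handle`).
    """
    if predicted_time >= max_prediction_depth:
        return False
    steps = (max(0, predicted_time - 1), predicted_time,
             min(max_prediction_depth - 1, predicted_time + 1))
    for step in steps:
        other_agents = np.delete(predicted_pos[step], handle, 0)
        if cell_id in other_agents:
            return True
    return False
```
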
-
-class GlobalObsForRailEnv(ObservationBuilder):
-    """
-    Gives a global observation of the entire rail environment.
-    The observation is composed of the following elements:
-
-        - transition map array with dimensions (env.height, env.width, 16),
-          assuming 16 bits encoding of transitions.
-
-        - A 3D array (map_height, map_width, 2) containing, in the first channel, the position of the given
-          agent's target and, in the second channel, the positions of the other agents' targets.
-
-        - A 3D array (map_height, map_width, 8) with the first 4 channels containing the one hot encoding
-          of the given agent's direction at its own position, and the last 4 channels containing the one hot
-          encoding of the other agents' directions at their position coordinates.
-    """
-
-    def __init__(self):
-        self.observation_space = ()
-        super(GlobalObsForRailEnv, self).__init__()
-
-    def _set_env(self, env):
-        super()._set_env(env)
-
-        self.observation_space = [4, self.env.height, self.env.width]
-
-    def reset(self):
-        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
-        for i in range(self.rail_obs.shape[0]):
-            for j in range(self.rail_obs.shape[1]):
-                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
-                bitlist = [0] * (16 - len(bitlist)) + bitlist
-                self.rail_obs[i, j] = np.array(bitlist)
-
-    def get(self, handle):
-        obs_targets = np.zeros((self.env.height, self.env.width, 2))
-        obs_agents_state = np.zeros((self.env.height, self.env.width, 8))
-        agents = self.env.agents
-        agent = agents[handle]
-
-        direction = np.zeros(4)
-        direction[agent.direction] = 1
-        agent_pos = agents[handle].position
-        obs_agents_state[agent_pos][:4] = direction
-        obs_targets[agent.target][0] += 1
-
-        for i in range(len(agents)):
-            if i != handle:  # TODO: handle used as index...?
-                agent2 = agents[i]
-                obs_agents_state[agent2.position][4 + agent2.direction] = 1
-                obs_targets[agent2.target][1] += 1
-
-        direction = self._get_one_hot_for_agent_direction(agent)
-
-        return self.rail_obs, obs_agents_state, obs_targets, direction
-
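
The four arrays returned by `get()` are typically concatenated into one flat vector before being fed to a dense network. A minimal sketch of that preprocessing (illustrative only, not part of this repository), with the shapes and channel layout used above:

```python
import numpy as np

def flatten_global_obs(global_obs):
    """Flatten the tuple returned by GlobalObsForRailEnv.get() into one vector.

    global_obs = (rail_obs, obs_agents_state, obs_targets, direction), with
    shapes (h, w, 16), (h, w, 8), (h, w, 2) and (4,). In obs_agents_state the
    first 4 channels one-hot encode the observing agent's direction at its own
    position; the last 4 channels one-hot encode the other agents' directions.
    """
    rail_obs, obs_agents_state, obs_targets, direction = global_obs
    return np.concatenate([rail_obs.ravel(),
                           obs_agents_state.ravel(),
                           obs_targets.ravel(),
                           np.asarray(direction).ravel()])
```
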
-
-class GlobalObsForRailEnvDirectionDependent(ObservationBuilder):
-    """
-    Gives a global observation of the entire rail environment.
-    The observation is composed of the following elements:
-
-        - transition map array with dimensions (env.height, env.width, 16),
-          assuming 16 bits encoding of transitions, flipped in the direction of the agent
-          (the agent is always heading north on the flipped view).
-
-        - A 3D array (map_height, map_width, 2) containing the position of the given agent's target and
-          the positions of the other agents' targets, also flipped depending on the agent's direction.
-
-        - A 3D array (map_height, map_width, 5) with the first channel containing the position of the given agent,
-          and the remaining 4 channels one hot encoding the other agents' directions (relative to it) at their positions.
-
-        - A 4 elements array with one hot encoding of the direction.
-    """
-
-    def __init__(self):
-        self.observation_space = ()
-        super(GlobalObsForRailEnvDirectionDependent, self).__init__()
-
-    def _set_env(self, env):
-        super()._set_env(env)
-
-        self.observation_space = [4, self.env.height, self.env.width]
-
-    def reset(self):
-        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
-        for i in range(self.rail_obs.shape[0]):
-            for j in range(self.rail_obs.shape[1]):
-                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
-                bitlist = [0] * (16 - len(bitlist)) + bitlist
-                self.rail_obs[i, j] = np.array(bitlist)
-
-    def get(self, handle):
-        obs_targets = np.zeros((self.env.height, self.env.width, 2))
-        obs_agents_state = np.zeros((self.env.height, self.env.width, 5))
-        agents = self.env.agents
-        agent = agents[handle]
-        direction = agent.direction
-
-        idx = np.tile(np.arange(16), 2)
-
-        rail_obs = self.rail_obs[:, :, idx[direction * 4: direction * 4 + 16]]
-
-        if direction == 1:
-            rail_obs = np.flip(rail_obs, axis=1)
-        elif direction == 2:
-            rail_obs = np.flip(rail_obs)
-        elif direction == 3:
-            rail_obs = np.flip(rail_obs, axis=0)
-
-        agent_pos = agents[handle].position
-        obs_agents_state[agent_pos][0] = 1
-        obs_targets[agent.target][0] += 1
-
-        idx = np.tile(np.arange(4), 2)
-        for i in range(len(agents)):
-            if i != handle:  # TODO: handle used as index...?
-                agent2 = agents[i]
-                obs_agents_state[agent2.position][1 + idx[4 + (agent2.direction - direction)]] = 1
-                obs_targets[agent2.target][1] += 1
-
-        direction = self._get_one_hot_for_agent_direction(agent)
-
-        return rail_obs, obs_agents_state, obs_targets, direction
-
-
-class LocalObsForRailEnv(ObservationBuilder):
-    """
-    Gives a local observation of the rail environment around the agent.
-    The observation is composed of the following elements:
-
-        - transition map array of the local environment around the given agent,
-          with dimensions (view_height, 2*view_width+1, 16),
-          assuming a 16-bit encoding of transitions.
-
-        - A 3D array (view_height, 2*view_width+1, 2) containing, if they are within the agent's field of view,
-          the agent's own target position (first channel) and the other agents' targets (second channel).
-
-        - A 3D array (view_height, 2*view_width+1, 4) containing the one hot encoding of the directions
-          of the other agents at their position coordinates, if they are within the agent's field of view.
-
-        - A 4-element array with the one hot encoding of the agent's direction.
-
-    Use the parameters view_width and view_height to define the rectangular view of the agent.
-    The center parameter moves the agent along the height axis of this rectangle. If it is 0, the agent only
-    observes cells in front of it.
-    """
-
-    def __init__(self, view_width, view_height, center):
-
-        super(LocalObsForRailEnv, self).__init__()
-        self.view_width = view_width
-        self.view_height = view_height
-        self.center = center
-        self.max_padding = max(self.view_width, self.view_height - self.center)
-
-    def reset(self):
-        # We build the transition map with a view_radius empty cells expansion on each side.
-        # This helps to collect the local transition map view when the agent is close to a border.
-        self.max_padding = max(self.view_width, self.view_height)
-        self.rail_obs = np.zeros((self.env.height,
-                                  self.env.width, 16))
-        for i in range(self.env.height):
-            for j in range(self.env.width):
-                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
-                bitlist = [0] * (16 - len(bitlist)) + bitlist
-                self.rail_obs[i, j] = np.array(bitlist)
-
-    def get(self, handle):
-        agents = self.env.agents
-        agent = agents[handle]
-
-        # Correct agents position for padding
-        # agent_rel_pos[0] = agent.position[0] + self.max_padding
-        # agent_rel_pos[1] = agent.position[1] + self.max_padding
-
-        # Collect visible cells as set to be plotted
-        visited, rel_coords = self.field_of_view(agent.position, agent.direction)
-        local_rail_obs = None
-
-        # Add the visible cells to the observed cells
-        self.env.dev_obs_dict[handle] = set(visited)
-
-        # Locate observed agents and their corresponding targets
-        local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16))
-        obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
-        obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
-        _idx = 0
-        for pos in visited:
-            curr_rel_coord = rel_coords[_idx]
-            local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :]
-            if pos == agent.target:
-                obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1
-            else:
-                for tmp_agent in agents:
-                    if pos == tmp_agent.target:
-                        obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1
-            if pos != agent.position:
-                for tmp_agent in agents:
-                    if pos == tmp_agent.position:
-                        obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[
-                            tmp_agent.direction]
-
-            _idx += 1
-
-        direction = np.identity(4)[agent.direction]
-        return local_rail_obs, obs_map_state, obs_other_agents_state, direction
-
-    def get_many(self, handles=None):
-        """
-        Called whenever an observation has to be computed for the `env' environment, for each agent with handle
-        in the `handles' list.
-        """
-
-        observations = {}
-        for h in handles:
-            observations[h] = self.get(h)
-        return observations
-
-    def field_of_view(self, position, direction, state=None):
-        # Compute the local field of view for an agent in the environment
-        data_collection = False
-        if state is not None:
-            temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16))
-            data_collection = True
-        if direction == 0:
-            origin = (position[0] + self.center, position[1] - self.view_width)
-        elif direction == 1:
-            origin = (position[0] - self.view_width, position[1] - self.center)
-        elif direction == 2:
-            origin = (position[0] - self.center, position[1] + self.view_width)
-        else:
-            origin = (position[0] + self.view_width, position[1] + self.center)
-        visible = list()
-        rel_coords = list()
-        for h in range(self.view_height):
-            for w in range(2 * self.view_width + 1):
-                if direction == 0:
-                    if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
-                        visible.append((origin[0] - h, origin[1] + w))
-                        rel_coords.append((h, w))
-                    # if data_collection:
-                    #    temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
-                elif direction == 1:
-                    if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
-                        visible.append((origin[0] + w, origin[1] + h))
-                        rel_coords.append((h, w))
-                    # if data_collection:
-                    #    temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
-                elif direction == 2:
-                    if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width:
-                        visible.append((origin[0] + h, origin[1] - w))
-                        rel_coords.append((h, w))
-                    # if data_collection:
-                    #    temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
-                else:
-                    if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width:
-                        visible.append((origin[0] - w, origin[1] - h))
-                        rel_coords.append((h, w))
-                    # if data_collection:
-                    #    temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
-        if data_collection:
-            return temp_visible_data
-        else:
-            return visible, rel_coords
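
The four branches of `field_of_view` differ only in how the relative (h, w) view coordinates are rotated into grid coordinates. For reference, the direction-0 case reduces to the closed form below; a standalone sketch for illustration, not part of the deleted module (the other headings are the corresponding 90-degree rotations):

```python
def view_cell_to_grid_direction_0(position, h, w, center, view_width, env_height, env_width):
    """Absolute grid coordinate of the relative view cell (h, w) for an agent
    heading in direction 0, mirroring field_of_view above:
    origin = (row + center, col - view_width); rows then decrease with h and
    columns increase with w. Returns None for cells outside the grid,
    which field_of_view simply skips."""
    row, col = position
    cell = (row + center - h, col - view_width + w)
    if 0 <= cell[0] < env_height and 0 <= cell[1] < env_width:
        return cell
    return None
```
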
diff --git a/torch_training/predictors/predictions.py b/torch_training/predictors/predictions.py
deleted file mode 100644
index 8306b726572e7a7e168e71755a760e010a6b1eb5..0000000000000000000000000000000000000000
--- a/torch_training/predictors/predictions.py
+++ /dev/null
@@ -1,179 +0,0 @@
-"""
-Collection of environment-specific PredictionBuilder.
-"""
-
-import numpy as np
-
-from flatland.core.env_prediction_builder import PredictionBuilder
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.rail_env import RailEnvActions
-
-
-class DummyPredictorForRailEnv(PredictionBuilder):
-    """
-    DummyPredictorForRailEnv object.
-
-    This object returns predictions for agents in the RailEnv environment.
-    The prediction acts as if no other agent is in the environment and always takes the forward action.
-    """
-
-    def get(self, custom_args=None, handle=None):
-        """
-        Called whenever get_many in the observation builder is called.
-
-        Parameters
-        ----------
-        custom_args: dict
-            Not used in this dummy implementation.
-        handle : int (optional)
-            Handle of the agent for which to compute the observation vector.
-
-        Returns
-        -------
-        dict
-            Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
-            - time_offset
-            - position axis 0
-            - position axis 1
-            - direction
-            - action taken to come here
-            The prediction at 0 is the current position, direction etc.
-
-        """
-        agents = self.env.agents
-        if handle is not None:
-            agents = [self.env.agents[handle]]
-
-        prediction_dict = {}
-
-        for agent in agents:
-            action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT]
-            _agent_initial_position = agent.position
-            _agent_initial_direction = agent.direction
-            prediction = np.zeros(shape=(self.max_depth + 1, 5))
-            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
-            for index in range(1, self.max_depth + 1):
-                action_done = False
-                # if we're at the target, stop moving...
-                if agent.position == agent.target:
-                    prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
-
-                    continue
-                for action in action_priorities:
-                    cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
-                        self.env._check_action_on_agent(action, agent)
-                    if all([new_cell_isValid, transition_isValid]):
-                        # move and change direction to face the new_direction that was
-                        # performed
-                        agent.position = new_position
-                        agent.direction = new_direction
-                        prediction[index] = [index, *new_position, new_direction, action]
-                        action_done = True
-                        break
-                if not action_done:
-                    raise Exception("Cannot move further. Something is wrong")
-            prediction_dict[agent.handle] = prediction
-            agent.position = _agent_initial_position
-            agent.direction = _agent_initial_direction
-        return prediction_dict
-
-
-class ShortestPathPredictorForRailEnv(PredictionBuilder):
-    """
-    ShortestPathPredictorForRailEnv object.
-
-    This object returns shortest-path predictions for agents in the RailEnv environment.
-    The prediction acts as if no other agent is in the environment and always takes the forward action.
-    """
-
-    def __init__(self, max_depth=20):
-        # Default prediction depth of 20 time steps
-        self.max_depth = max_depth
-
-    def get(self, custom_args=None, handle=None):
-        """
-        Called whenever get_many in the observation builder is called.
-        Requires distance_map to extract the shortest path.
-
-        Parameters
-        ----------
-        custom_args: dict
-            - distance_map : dict
-        handle : int, optional
-            Handle of the agent for which to compute the observation vector.
-
-        Returns
-        -------
-        dict
-            Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
-            - time_offset
-            - position axis 0
-            - position axis 1
-            - direction
-            - action taken to come here
-            The prediction at 0 is the current position, direction etc.
-        """
-        agents = self.env.agents
-        if handle is not None:
-            agents = [self.env.agents[handle]]
-        assert custom_args is not None
-        distance_map = custom_args.get('distance_map')
-        assert distance_map is not None
-
-        prediction_dict = {}
-        for agent in agents:
-            _agent_initial_position = agent.position
-            _agent_initial_direction = agent.direction
-            agent_speed = agent.speed_data["speed"]
-            times_per_cell = int(np.reciprocal(agent_speed))
-            prediction = np.zeros(shape=(self.max_depth + 1, 5))
-            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
-            new_direction = _agent_initial_direction
-            new_position = _agent_initial_position
-            visited = set()
-            for index in range(1, self.max_depth + 1):
-                # if we're at the target, stop moving...
-                if agent.position == agent.target:
-                    prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
-                    visited.add((agent.position[0], agent.position[1], agent.direction))
-                    continue
-                if not agent.moving:
-                    prediction[index] = [index, *agent.position, agent.direction, RailEnvActions.STOP_MOVING]
-                    visited.add((agent.position[0], agent.position[1], agent.direction))
-                    continue
-                # Take shortest possible path
-                cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
-
-                if np.sum(cell_transitions) == 1 and index % times_per_cell == 0:
-                    new_direction = np.argmax(cell_transitions)
-                    new_position = get_new_position(agent.position, new_direction)
-                elif np.sum(cell_transitions) > 1 and index % times_per_cell == 0:
-                    min_dist = np.inf
-                    no_dist_found = True
-                    for direction in range(4):
-                        if cell_transitions[direction] == 1:
-                            neighbour_cell = get_new_position(agent.position, direction)
-                            target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
-                            if target_dist < min_dist or no_dist_found:
-                                min_dist = target_dist
-                                new_direction = direction
-                                no_dist_found = False
-                    new_position = get_new_position(agent.position, new_direction)
-                elif index % times_per_cell == 0:
-                    raise Exception("No transition possible {}".format(cell_transitions))
-
-                # update the agent's position and direction
-                agent.position = new_position
-                agent.direction = new_direction
-
-                # prediction is ready
-                prediction[index] = [index, *new_position, new_direction, 0]
-                visited.add((new_position[0], new_position[1], new_direction))
-            self.env.dev_pred_dict[agent.handle] = visited
-            prediction_dict[agent.handle] = prediction
-
-            # cleanup: reset initial position
-            agent.position = _agent_initial_position
-            agent.direction = _agent_initial_direction
-
-        return prediction_dict
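
Both predictors return, per agent handle, a `(max_depth + 1, 5)` array of `[time_offset, row, col, direction, action]` rows. Downstream, the tree observation builder regroups these per time step into integer cell ids before running its conflict check. A hedged sketch of that regrouping, assuming only the documented output format (the function name and the cell-id encoding are illustrative; the original code uses flatland's `coordinate_to_position` helper instead):

```python
import numpy as np

def predictions_to_positions(prediction_dict, env_width, max_depth):
    """Regroup {handle: (max_depth + 1, 5) array} into one integer cell id per
    agent and time step, the layout expected by the conflict check.

    Cell ids are encoded here as row * env_width + col purely for illustration;
    any injective encoding of (row, col) works for the membership test."""
    handles = sorted(prediction_dict.keys())
    predicted_pos = {}
    for t in range(max_depth + 1):
        cells = [int(prediction_dict[h][t][1]) * env_width + int(prediction_dict[h][t][2])
                 for h in handles]
        predicted_pos[t] = np.array(cells)
    return predicted_pos
```
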
diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py
index 589d6109641c540c103f030e7eb139c35df25298..2649a2367367e17e39328ca8c28cc9c2f1fc0172 100644
--- a/torch_training/render_agent_behavior.py
+++ b/torch_training/render_agent_behavior.py
@@ -101,8 +101,8 @@ dones_list = []
 action_prob = [0] * action_size
 agent_obs = [None] * env.get_num_agents()
 agent_next_obs = [None] * env.get_num_agents()
-agent = Agent(state_size, action_size, "FC", 0)
-with path(torch_training.Nets, "navigator_checkpoint10700.pth") as file_in:
+agent = Agent(state_size, action_size)
+with path(torch_training.Nets, "avoid_checkpoint500.pth") as file_in:
     agent.qnetwork_local.load_state_dict(torch.load(file_in))
 
 record_images = False
diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index f69929f65accc53101ba28d8904cdf76b7e1cfca..bd221ae4b912b89c1d5bb242676d4f75819cfd90 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -11,7 +11,7 @@ sys.path.append(str(base_dir))
 import matplotlib.pyplot as plt
 import numpy as np
 import torch
-from dueling_double_dqn import Agent
+from torch_training.dueling_double_dqn import Agent
 
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -117,7 +117,7 @@ def main(argv):
     cummulated_reward = np.zeros(env.get_num_agents())
 
     # Now we load a Double dueling DQN agent
-    agent = Agent(state_size, action_size, "FC", 0)
+    agent = Agent(state_size, action_size)
 
     for trials in range(1, n_trials + 1):