diff --git a/reinforcement_learning/__init__.py b/reinforcement_learning/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..95c343ec6e87a200e5c2fa5565596b8852155e13
--- /dev/null
+++ b/reinforcement_learning/dddqn_policy.py
@@ -0,0 +1,212 @@
+import copy
+import os
+import pickle
+import random
+from collections import namedtuple, deque
+from collections.abc import Iterable
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+
+from reinforcement_learning.model import DuelingQNetwork
+from reinforcement_learning.policy import Policy
+
+
+class DDDQNPolicy(Policy):
+    """Dueling Double DQN policy"""
+
+    def __init__(self, state_size, action_size, parameters, evaluation_mode=False):
+        self.evaluation_mode = evaluation_mode
+
+        self.state_size = state_size
+        self.action_size = action_size
+        self.double_dqn = True
+        self.hidsize = 1
+
+        if not evaluation_mode:
+            self.hidsize = parameters.hidden_size
+            self.buffer_size = parameters.buffer_size
+            self.batch_size = parameters.batch_size
+            self.update_every = parameters.update_every
+            self.learning_rate = parameters.learning_rate
+            self.tau = parameters.tau
+            self.gamma = parameters.gamma
+            self.buffer_min_size = parameters.buffer_min_size
+
+        # Device
+        if parameters.use_gpu and torch.cuda.is_available():
+            self.device = torch.device("cuda:0")
+            # print("🐇 Using GPU")
+        else:
+            self.device = torch.device("cpu")
+            # print("🐢 Using CPU")
+
+        # Q-Network
+        self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to(
+            self.device)
+
+        if not evaluation_mode:
+            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+            self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.learning_rate)
+            self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, self.device)
+
+            self.t_step = 0
+            self.loss = 0.0
+
+    def act(self, handle, state, eps=0.):
+        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
+        self.qnetwork_local.eval()
+        with torch.no_grad():
+            action_values = self.qnetwork_local(state)
+        self.qnetwork_local.train()
+
+        # Epsilon-greedy action selection
+        if random.random() > eps:
+            return np.argmax(action_values.cpu().data.numpy())
+        else:
+            return random.choice(np.arange(self.action_size))
+
+    def step(self, handle, state, action, reward, next_state, done):
+        assert not self.evaluation_mode, "Policy has been initialized for evaluation only."
+
+        # Save experience in replay memory
+        self.memory.add(state, action, reward, next_state, done)
+
+        # Learn every UPDATE_EVERY time steps.
+        self.t_step = (self.t_step + 1) % self.update_every
+        if self.t_step == 0:
+            # If enough samples are available in memory, get random subset and learn
+            if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
+                self._learn()
+
+    def _clip_gradient(self, model, clip):
+        """Computes a gradient clipping coefficient based on gradient norm."""
+        totalnorm = 0
+        for p in model.parameters():
+            if p.grad is not None:
+                modulenorm = p.grad.data.norm()
+                totalnorm += modulenorm.item() ** 2
+        totalnorm = np.sqrt(totalnorm)
+        coeff = min(1, clip / (totalnorm + 1e-6))
+
+        for p in model.parameters():
+            if p.grad is not None:
+                p.grad.mul_(coeff)
+
+    def _learn(self):
+        experiences = self.memory.sample()
+        states, actions, rewards, next_states, dones = experiences
+
+        # Get expected Q values from local model
+        q_expected = self.qnetwork_local(states).gather(1, actions)
+
+        if self.double_dqn:
+            # Double DQN
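+            # Select the greedy action with the online network but evaluate it
+            # with the target network, which reduces Q-value overestimation
+            # (van Hasselt et al., 2016).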
+            q_best_action = self.qnetwork_local(next_states).max(1)[1]
+            q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1))
+        else:
+            # DQN
+            q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1)
+
+        # Compute Q targets for current states
+        q_targets = rewards + (self.gamma * q_targets_next * (1 - dones))
+
+        # Compute loss
+        self.loss = F.mse_loss(q_expected, q_targets)
+
+        # Minimize the loss
+        self.optimizer.zero_grad()
+        self.loss.backward()
+        # for param in self.qnetwork_local.parameters():
+        #   param.grad.data.clamp_(-1.0, 1.0)
+        self._clip_gradient(self.qnetwork_local, 1.0)
+
+        self.optimizer.step()
+
+        # Update target network
+        self._soft_update(self.qnetwork_local, self.qnetwork_target, self.tau)
+
+    def _soft_update(self, local_model, target_model, tau):
+        # Soft update model parameters.
+        # θ_target = τ*θ_local + (1 - τ)*θ_target
+        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
+            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
+
+    def save(self, filename):
+        torch.save(self.qnetwork_local.state_dict(), filename + ".local")
+        torch.save(self.qnetwork_target.state_dict(), filename + ".target")
+
+    def load(self, filename):
+        print("load policy from file", filename)
+        if os.path.exists(filename + ".local"):
+            print(' >> ', filename + ".local")
+            self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
+        if os.path.exists(filename + ".target"):
+            print(' >> ', filename + ".target")
+            self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+
+    def save_replay_buffer(self, filename):
+        memory = self.memory.memory
+        with open(filename, 'wb') as f:
+            pickle.dump(list(memory)[-500000:], f)
+
+    def load_replay_buffer(self, filename):
+        with open(filename, 'rb') as f:
+            self.memory.memory = pickle.load(f)
+
+    def test(self):
+        self.act(0, np.array([[0] * self.state_size]))
+        self._learn()
+
+
+Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
+
+
+class ReplayBuffer:
+    """Fixed-size buffer to store experience tuples."""
+
+    def __init__(self, action_size, buffer_size, batch_size, device):
+        """Initialize a ReplayBuffer object.
+
+        Params
+        ======
+            action_size (int): dimension of each action
+            buffer_size (int): maximum size of buffer
+            batch_size (int): size of each training batch
+        """
+        self.action_size = action_size
+        self.memory = deque(maxlen=buffer_size)
+        self.batch_size = batch_size
+        self.device = device
+
+    def add(self, state, action, reward, next_state, done):
+        """Add a new experience to memory."""
+        e = Experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)
+        self.memory.append(e)
+
+    def sample(self):
+        """Randomly sample a batch of experiences from memory."""
+        experiences = random.sample(self.memory, k=self.batch_size)
+
+        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
+            .float().to(self.device)
+        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
+            .long().to(self.device)
+        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
+            .float().to(self.device)
+        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
+            .float().to(self.device)
+        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
+            .float().to(self.device)
+
+        return states, actions, rewards, next_states, dones
+
+    def __len__(self):
+        """Return the current size of internal memory."""
+        return len(self.memory)
+
+    def __v_stack_impr(self, states):
+        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
+        np_states = np.reshape(np.array(states), (len(states), sub_dim))
+        return np_states
diff --git a/reinforcement_learning/evaluate_agent.py b/reinforcement_learning/evaluate_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..2adb14377803d681d44b17b9a1135203e0af59e7
--- /dev/null
+++ b/reinforcement_learning/evaluate_agent.py
@@ -0,0 +1,376 @@
+import math
+import multiprocessing
+import os
+import sys
+from argparse import ArgumentParser, Namespace
+from multiprocessing import Pool
+from pathlib import Path
+from pprint import pprint
+
+import numpy as np
+import torch
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.utils.rendertools import RenderTool
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from utils.deadlock_check import check_if_all_blocked
+from utils.timer import Timer
+from utils.observation_utils import normalize_observation
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+
+
+def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render, allow_skipping, allow_caching):
+    # Evaluation is faster on CPU (except if you use a really huge policy network)
+    parameters = {
+        'use_gpu': False
+    }
+
+    policy = DDDQNPolicy(state_size, action_size, Namespace(**parameters), evaluation_mode=True)
+    policy.qnetwork_local = torch.load(checkpoint)
+
+    env_params = Namespace(**env_params)
+
+    # Environment parameters
+    n_agents = env_params.n_agents
+    x_dim = env_params.x_dim
+    y_dim = env_params.y_dim
+    n_cities = env_params.n_cities
+    max_rails_between_cities = env_params.max_rails_between_cities
+    max_rails_in_city = env_params.max_rails_in_city
+
+    # Malfunction and speed profiles
+    # TODO pass these parameters properly from main!
+    malfunction_parameters = MalfunctionParameters(
+        malfunction_rate=1. / 2000,  # Rate of malfunctions
+        min_duration=20,  # Minimal duration
+        max_duration=50  # Max duration
+    )
+
+    # Only fast trains in Round 1
+    speed_profiles = {
+        1.: 1.0,  # Fast passenger train
+        1. / 2.: 0.0,  # Fast freight train
+        1. / 3.: 0.0,  # Slow commuter train
+        1. / 4.: 0.0  # Slow freight train
+    }
+
+    # Observation parameters
+    observation_tree_depth = env_params.observation_tree_depth
+    observation_radius = env_params.observation_radius
+    observation_max_path_depth = env_params.observation_max_path_depth
+
+    # Observation builder
+    predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
+    tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+
+    # Setup the environment
+    env = RailEnv(
+        width=x_dim, height=y_dim,
+        rail_generator=sparse_rail_generator(
+            max_num_cities=n_cities,
+            grid_mode=False,
+            max_rails_between_cities=max_rails_between_cities,
+            max_rails_in_city=max_rails_in_city,
+        ),
+        schedule_generator=sparse_schedule_generator(speed_profiles),
+        number_of_agents=n_agents,
+        malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters),
+        obs_builder_object=tree_observation
+    )
+
+    if render:
+        env_renderer = RenderTool(env, gl="PGL")
+
+    action_dict = dict()
+    scores = []
+    completions = []
+    nb_steps = []
+    inference_times = []
+    preproc_times = []
+    agent_times = []
+    step_times = []
+
+    for episode_idx in range(n_eval_episodes):
+        seed += 1
+
+        inference_timer = Timer()
+        preproc_timer = Timer()
+        agent_timer = Timer()
+        step_timer = Timer()
+
+        step_timer.start()
+        obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True, random_seed=seed)
+        step_timer.end()
+
+        agent_obs = [None] * env.get_num_agents()
+        score = 0.0
+
+        if render:
+            env_renderer.set_new_rail()
+
+        final_step = 0
+        skipped = 0
+
+        # 'done' is read before the first env.step() call when all agents are
+        # blocked right away, so initialize it explicitly.
+        done = {agent: False for agent in env.get_agent_handles()}
+        done['__all__'] = False
+
+        nb_hit = 0
+        agent_last_obs = {}
+        agent_last_action = {}
+
+        for step in range(max_steps - 1):
+            if allow_skipping and check_if_all_blocked(env):
+                # FIXME why -1? bug where all agents are "done" after max_steps!
+                skipped = max_steps - step - 1
+                final_step = max_steps - 2
+                n_unfinished_agents = sum(not done[idx] for idx in env.get_agent_handles())
+                score -= skipped * n_unfinished_agents
+                break
+
+            agent_timer.start()
+            for agent in env.get_agent_handles():
+                if obs[agent] and info['action_required'][agent]:
+                    if agent in agent_last_obs and np.all(agent_last_obs[agent] == obs[agent]):
+                        nb_hit += 1
+                        action = agent_last_action[agent]
+
+                    else:
+                        preproc_timer.start()
+                        norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
+                        preproc_timer.end()
+
+                        inference_timer.start()
+                        action = policy.act(agent, norm_obs, eps=0.0)
+                        inference_timer.end()
+
+                    action_dict.update({agent: action})
+
+                    if allow_caching:
+                        agent_last_obs[agent] = obs[agent]
+                        agent_last_action[agent] = action
+            agent_timer.end()
+
+            step_timer.start()
+            obs, all_rewards, done, info = env.step(action_dict)
+            step_timer.end()
+
+            if render:
+                env_renderer.render_env(
+                    show=True,
+                    frames=False,
+                    show_observations=False,
+                    show_predictions=False
+                )
+
+                if step % 100 == 0:
+                    print("{}/{}".format(step, max_steps - 1))
+
+            for agent in env.get_agent_handles():
+                score += all_rewards[agent]
+
+            final_step = step
+
+            if done['__all__']:
+                break
+
+        normalized_score = score / (max_steps * env.get_num_agents())
+        scores.append(normalized_score)
+
+        tasks_finished = sum(done[idx] for idx in env.get_agent_handles())
+        completion = tasks_finished / max(1, env.get_num_agents())
+        completions.append(completion)
+
+        nb_steps.append(final_step)
+
+        inference_times.append(inference_timer.get())
+        preproc_times.append(preproc_timer.get())
+        agent_times.append(agent_timer.get())
+        step_times.append(step_timer.get())
+
+        skipped_text = ""
+        if skipped > 0:
+            skipped_text = "\tâš¡ Skipped {}".format(skipped)
+
+        hit_text = ""
+        if nb_hit > 0:
+            hit_text = "\tâš¡ Hit {} ({:.1f}%)".format(nb_hit, (100 * nb_hit) / (n_agents * final_step))
+
+        print(
+            "☑️  Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} "
+            "\t🍭 Seed: {}"
+            "\t🚉 Env: {:.3f}s  "
+            "\t🤖 Agent: {:.3f}s (per step: {:.3f}s) \t[preproc: {:.3f}s \tinfer: {:.3f}s]"
+            "{}{}".format(
+                normalized_score,
+                completion * 100.0,
+                final_step,
+                seed,
+                step_timer.get(),
+                agent_timer.get(),
+                agent_timer.get() / final_step,
+                preproc_timer.get(),
+                inference_timer.get(),
+                skipped_text,
+                hit_text
+            )
+        )
+
+    return scores, completions, nb_steps, agent_times, step_times
+
+
+def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping, allow_caching):
+    nb_threads = 1
+    eval_per_thread = n_evaluation_episodes
+
+    if not render:
+        nb_threads = multiprocessing.cpu_count()
+        eval_per_thread = max(1, math.ceil(n_evaluation_episodes / nb_threads))
+
+    total_nb_eval = eval_per_thread * nb_threads
+    print("Will evaluate policy {} over {} episodes on {} threads.".format(file, total_nb_eval, nb_threads))
+
+    if total_nb_eval != n_evaluation_episodes:
+        print("(Rounding up from {} to fill all cores)".format(n_evaluation_episodes))
+
+    # Observation parameters need to match the ones used during training!
+
+    # small_v0
+    small_v0_params = {
+        # sample configuration
+        "n_agents": 5,
+        "x_dim": 25,
+        "y_dim": 25,
+        "n_cities": 4,
+        "max_rails_between_cities": 2,
+        "max_rails_in_city": 3,
+
+        # observations
+        "observation_tree_depth": 2,
+        "observation_radius": 10,
+        "observation_max_path_depth": 20
+    }
+
+    # Test_0
+    test0_params = {
+        # sample configuration
+        "n_agents": 5,
+        "x_dim": 25,
+        "y_dim": 25,
+        "n_cities": 2,
+        "max_rails_between_cities": 2,
+        "max_rails_in_city": 3,
+
+        # observations
+        "observation_tree_depth": 2,
+        "observation_radius": 10,
+        "observation_max_path_depth": 20
+    }
+
+    # Test_1
+    test1_params = {
+        # environment
+        "n_agents": 10,
+        "x_dim": 30,
+        "y_dim": 30,
+        "n_cities": 2,
+        "max_rails_between_cities": 2,
+        "max_rails_in_city": 3,
+
+        # observations
+        "observation_tree_depth": 2,
+        "observation_radius": 10,
+        "observation_max_path_depth": 10
+    }
+
+    # Test_5
+    test5_params = {
+        # environment
+        "n_agents": 80,
+        "x_dim": 35,
+        "y_dim": 35,
+        "n_cities": 5,
+        "max_rails_between_cities": 2,
+        "max_rails_in_city": 4,
+
+        # observations
+        "observation_tree_depth": 2,
+        "observation_radius": 10,
+        "observation_max_path_depth": 20
+    }
+
+    params = small_v0_params
+    env_params = Namespace(**params)
+
+    print("Environment parameters:")
+    pprint(params)
+
+    # Calculate space dimensions and max steps
+    max_steps = int(4 * 2 * (env_params.x_dim + env_params.y_dim + (env_params.n_agents / env_params.n_cities)))
+    action_size = 5
+    tree_observation = TreeObsForRailEnv(max_depth=env_params.observation_tree_depth)
+    tree_depth = env_params.observation_tree_depth
+    num_features_per_node = tree_observation.observation_dim
+    n_nodes = sum([np.power(4, i) for i in range(tree_depth + 1)])
+    state_size = num_features_per_node * n_nodes
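+    # e.g. with a tree depth of 2 there are 4^0 + 4^1 + 4^2 = 21 nodes, so the
+    # network input has 21 * num_features_per_node features.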
+
+    results = []
+    if render:
+        results.append(eval_policy(params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping, allow_caching))
+
+    else:
+        with Pool() as p:
+            results = p.starmap(eval_policy,
+                                [(params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, allow_skipping, allow_caching)
+                                 for seed in
+                                 range(total_nb_eval)])
+
+    scores = []
+    completions = []
+    nb_steps = []
+    times = []
+    step_times = []
+    for s, c, n, t, st in results:
+        scores.append(s)
+        completions.append(c)
+        nb_steps.append(n)
+        times.append(t)
+        step_times.append(st)
+
+    print("-" * 200)
+
+    print("✅ Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} \tAgent total: {:.3f}s (per step: {:.3f}s)".format(
+        np.mean(scores),
+        np.mean(completions) * 100.0,
+        np.mean(nb_steps),
+        np.mean(times),
+        np.mean(times) / np.mean(nb_steps)
+    ))
+
+    print("⏲️  Agent sum: {:.3f}s \tEnv sum: {:.3f}s \tTotal sum: {:.3f}s".format(
+        np.sum(times),
+        np.sum(step_times),
+        np.sum(times) + np.sum(step_times)
+    ))
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("-f", "--file", help="checkpoint to load", required=True, type=str)
+    parser.add_argument("-n", "--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
+
+    # TODO
+    # parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, type=int)
+
+    parser.add_argument("--use_gpu", dest="use_gpu", help="use GPU if available", action='store_true')
+    parser.add_argument("--render", help="render a single episode", action='store_true')
+    parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked", action='store_true')
+    parser.add_argument("--allow_caching", help="caches the last observation-action pair", action='store_true')
+    args = parser.parse_args()
+
+    os.environ["OMP_NUM_THREADS"] = str(1)
+    evaluate_agents(file=args.file, n_evaluation_episodes=args.n_evaluation_episodes, use_gpu=args.use_gpu, render=args.render,
+                    allow_skipping=args.allow_skipping, allow_caching=args.allow_caching)
diff --git a/reinforcement_learning/model.py b/reinforcement_learning/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..223f2f7707d973a1a9821d1ab44d8e2ef63b438f
--- /dev/null
+++ b/reinforcement_learning/model.py
@@ -0,0 +1,31 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DuelingQNetwork(nn.Module):
+    """Dueling Q-network (https://arxiv.org/abs/1511.06581)"""
+
+    def __init__(self, state_size, action_size, hidsize1=64, hidsize2=64):
+        super(DuelingQNetwork, self).__init__()
+
+        # value network
+        self.fc1_val = nn.Linear(state_size, hidsize1)
+        self.fc2_val = nn.Linear(hidsize1, hidsize2)
+        self.fc4_val = nn.Linear(hidsize2, 1)
+
+        # advantage network
+        self.fc1_adv = nn.Linear(state_size, hidsize1)
+        self.fc2_adv = nn.Linear(hidsize1, hidsize2)
+        self.fc4_adv = nn.Linear(hidsize2, action_size)
+
+    def forward(self, x):
+        val = F.relu(self.fc1_val(x))
+        val = F.relu(self.fc2_val(val))
+        val = self.fc4_val(val)
+
+        # advantage calculation
+        adv = F.relu(self.fc1_adv(x))
+        adv = F.relu(self.fc2_adv(adv))
+        adv = self.fc4_adv(adv)
+
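+        # Aggregate the two streams: Q(s, a) = V(s) + A(s, a) - mean(A).
+        # Note that adv.mean() averages over the whole tensor here (batch and
+        # actions together) rather than per sample as in eq. 9 of the paper.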
+        return val + adv - adv.mean()
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a80c783d5d12d38aacc43952838d07a434d7d1d
--- /dev/null
+++ b/reinforcement_learning/multi_agent_training.py
@@ -0,0 +1,614 @@
+import os
+import random
+import sys
+from argparse import ArgumentParser, Namespace
+from collections import deque
+from datetime import datetime
+from pathlib import Path
+from pprint import pprint
+
+import numpy as np
+import psutil
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+from torch.utils.tensorboard import SummaryWriter
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from utils.deadlock_check import check_if_all_blocked
+from utils.timer import Timer
+from utils.observation_utils import normalize_observation
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+from reinforcement_learning.ppo.ppo_agent import PPOAgent
+
+from utils.extra import Extra, ExtraPolicy
+from reinforcement_learning.multi_policy import MultiPolicy
+
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
+
+try:
+    import wandb
+
+    wandb.init(sync_tensorboard=True)
+except ImportError:
+    print("Install wandb to log to Weights & Biases")
+
+"""
+This file shows how to train multiple agents using a reinforcement learning approach.
+
+Documentation: https://flatland.aicrowd.com/getting-started/rl/multi-agent.html
+Results: https://app.wandb.ai/masterscrat/flatland-examples-reinforcement_learning/reports/Flatland-Examples--VmlldzoxNDI2MTA
+"""
+
+
+def create_rail_env(env_params, tree_observation, close_following):
+    n_agents = env_params.n_agents
+    x_dim = env_params.x_dim
+    y_dim = env_params.y_dim
+    n_cities = env_params.n_cities
+    max_rails_between_cities = env_params.max_rails_between_cities
+    max_rails_in_city = env_params.max_rails_in_city
+    seed = env_params.seed
+
+    # Break agents from time to time
+    malfunction_parameters = MalfunctionParameters(
+        malfunction_rate=env_params.malfunction_rate,
+        min_duration=20,
+        max_duration=50
+    )
+
+    return RailEnv(
+        width=x_dim, height=y_dim,
+        rail_generator=sparse_rail_generator(
+            max_num_cities=n_cities,
+            grid_mode=False,
+            max_rails_between_cities=max_rails_between_cities,
+            max_rails_in_city=max_rails_in_city
+        ),
+        schedule_generator=sparse_schedule_generator(),
+        number_of_agents=n_agents,
+        malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters),
+        obs_builder_object=tree_observation,
+        random_seed=seed,
+        close_following=close_following
+    )
+
+
+def train_agent(train_params, train_env_params, eval_env_params, obs_params, close_following):
+    # Environment parameters
+    n_agents = train_env_params.n_agents
+    x_dim = train_env_params.x_dim
+    y_dim = train_env_params.y_dim
+    n_cities = train_env_params.n_cities
+    max_rails_between_cities = train_env_params.max_rails_between_cities
+    max_rails_in_city = train_env_params.max_rails_in_city
+    seed = train_env_params.seed
+
+    # Unique ID for this training
+    now = datetime.now()
+    training_id = now.strftime('%y%m%d%H%M%S')
+
+    # Observation parameters
+    observation_tree_depth = obs_params.observation_tree_depth
+    observation_radius = obs_params.observation_radius
+    observation_max_path_depth = obs_params.observation_max_path_depth
+
+    # Training parameters
+    eps_start = train_params.eps_start
+    eps_end = train_params.eps_end
+    eps_decay = train_params.eps_decay
+    n_episodes = train_params.n_episodes
+    checkpoint_interval = train_params.checkpoint_interval
+    render_interval = 1  # checkpoint_interval
+    n_eval_episodes = train_params.n_evaluation_episodes
+    restore_replay_buffer = train_params.restore_replay_buffer
+    save_replay_buffer = train_params.save_replay_buffer
+
+    # Set the seeds
+    random.seed(seed)
+    np.random.seed(seed)
+
+    # Observation builder
+    predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
+    if not train_params.use_extra_observation:
+        print("Create TreeObsForRailEnv")
+
+        def check_is_observation_valid(observation):
+            return observation
+
+        def get_normalized_observation(observation, tree_depth: int, observation_radius=0):
+            return normalize_observation(observation, tree_depth, observation_radius)
+
+        tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+        tree_observation.check_is_observation_valid = check_is_observation_valid
+        tree_observation.get_normalized_observation = get_normalized_observation
+    else:
+        print("Create Extra-Observation")
+
+        def check_is_observation_valid(observation):
+            return True
+
+        def get_normalized_observation(observation, tree_depth: int, observation_radius=0):
+            return observation
+
+        tree_observation = Extra(max_depth=observation_tree_depth)
+        tree_observation.check_is_observation_valid = check_is_observation_valid
+        tree_observation.get_normalized_observation = get_normalized_observation
+
+    # Setup the environments
+    train_env = create_rail_env(train_env_params, tree_observation, close_following)
+    train_env.reset(regenerate_schedule=True, regenerate_rail=True)
+    eval_env = create_rail_env(eval_env_params, tree_observation, close_following)
+    eval_env.reset(regenerate_schedule=True, regenerate_rail=True)
+
+    if not train_params.use_extra_observation:
+        # Calculate the state size given the depth of the tree observation and the number of features
+        n_features_per_node = train_env.obs_builder.observation_dim
+        n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
+        state_size = n_features_per_node * n_nodes
+    else:
+        # Calculate the state size given the depth of the tree observation and the number of features
+        state_size = tree_observation.observation_dim
+
+    # Setup renderer
+    if train_params.render:
+        env_renderer = RenderTool(train_env, gl="PGL", agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS)
+
+    # The action space of flatland is 5 discrete actions
+    action_size = 5
+
+    # Max number of steps per episode
+    # This is the official formula used during evaluations
+    # See details in flatland.envs.schedule_generators.sparse_schedule_generator
+    # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
+    max_steps = train_env._max_episode_steps
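+    # e.g. for the Test_0 configuration (25x25 grid, 5 agents, 2 cities) the
+    # formula above gives int(4 * 2 * (25 + 25 + 5 / 2)) = 420 steps.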
+
+    action_count = [0] * action_size
+    action_dict = dict()
+    agent_obs = [None] * n_agents
+    agent_prev_obs = [None] * n_agents
+    agent_prev_action = [2] * n_agents
+    update_values = [False] * n_agents
+
+    # Smoothed values used as target for hyperparameter tuning
+    smoothed_normalized_score = -1.0
+    smoothed_eval_normalized_score = -1.0
+    smoothed_completion = 0.0
+    smoothed_eval_completion = 0.0
+
+    scores_window = deque(maxlen=checkpoint_interval)  # todo smooth when rendering instead
+    completion_window = deque(maxlen=checkpoint_interval)
+
+    # Double Dueling DQN policy
+    policy = DDDQNPolicy(state_size, action_size, train_params)
+    if False:
+        policy = ExtraPolicy(state_size, action_size)
+    if False:
+        policy = PPOAgent(state_size, action_size, n_agents)
+    if False:
+        policy = MultiPolicy(state_size, action_size, n_agents, train_env)
+    if True:
+        policy = DeadLockAvoidanceAgent(train_env, state_size, action_size)
+
+    # Load existing policy
+    if train_params.load_policy is not None:
+        policy.load(train_params.load_policy)
+
+    # Loads existing replay buffer
+    if restore_replay_buffer:
+        try:
+            policy.load_replay_buffer(restore_replay_buffer)
+            policy.test()
+        except RuntimeError as e:
+            print("\n🛑 Could't load replay buffer, were the experiences generated using the same tree depth?")
+            print(e)
+            exit(1)
+
+    try:
+        print(
+            "\n💾 Replay buffer status: {}/{} experiences".format(len(policy.memory.memory), train_params.buffer_size))
+    except Exception:
+        print("\n💾 No replay buffer available")
+
+    hdd = psutil.disk_usage('/')
+    if save_replay_buffer and (hdd.free / (2 ** 30)) < 500.0:
+        print(
+            "⚠️  Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f}gb left.".format(
+                hdd.free / (2 ** 30)))
+
+    # TensorBoard writer
+    writer = SummaryWriter()
+    writer.add_hparams(vars(train_params), {})  # FIXME
+    writer.add_hparams(vars(train_env_params), {})
+    writer.add_hparams(vars(obs_params), {})
+
+    training_timer = Timer()
+    training_timer.start()
+
+    print(
+        "\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes. Training id '{}'.\n".format(
+            train_env.get_num_agents(),
+            x_dim, y_dim,
+            n_episodes,
+            n_eval_episodes,
+            checkpoint_interval,
+            training_id
+        ))
+
+    for episode_idx in range(n_episodes + 1):
+        step_timer = Timer()
+        reset_timer = Timer()
+        learn_timer = Timer()
+        preproc_timer = Timer()
+        inference_timer = Timer()
+
+        # Reset environment
+        reset_timer.start()
+        obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
+        policy.reset()
+        reset_timer.end()
+
+        if train_params.render:
+            env_renderer.set_new_rail()
+
+        score = 0
+        nb_steps = 0
+        actions_taken = []
+
+        # Build initial agent-specific observations
+        for agent in train_env.get_agent_handles():
+            if tree_observation.check_is_observation_valid(obs[agent]):
+                agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], observation_tree_depth,
+                                                                               observation_radius=observation_radius)
+                agent_prev_obs[agent] = agent_obs[agent].copy()
+
+        # Run episode
+        for step in range(max_steps - 1):
+            inference_timer.start()
+            policy.start_step()
+            for agent in train_env.get_agent_handles():
+                if info['action_required'][agent]:
+                    update_values[agent] = True
+                    action = policy.act(agent, agent_obs[agent], eps=eps_start)
+
+                    action_count[action] += 1
+                    actions_taken.append(action)
+                else:
+                    # An action is not required if the train hasn't joined the railway network,
+                    # if it already reached its target, or if it is currently malfunctioning.
+                    update_values[agent] = False
+                    action = 0
+                action_dict.update({agent: action})
+            policy.end_step()
+            inference_timer.end()
+
+            # Environment step
+            step_timer.start()
+            next_obs, all_rewards, done, info = train_env.step(action_dict)
+            step_timer.end()
+
+            # Render an episode at some interval
+            if train_params.render and episode_idx % render_interval == 0:
+                env_renderer.render_env(
+                    show=True,
+                    frames=False,
+                    show_observations=False,
+                    show_predictions=False
+                )
+
+            # Update replay buffer and train agent
+            for agent in train_env.get_agent_handles():
+                if update_values[agent] or done['__all__']:
+                    # Only learn from timesteps where something happened
+                    learn_timer.start()
+                    policy.step(agent,
+                                agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                                agent_obs[agent],
+                                done[agent] and done['__all__'])
+                    learn_timer.end()
+
+                    agent_prev_obs[agent] = agent_obs[agent].copy()
+                    agent_prev_action[agent] = action_dict[agent]
+
+                # Preprocess the new observations
+                if tree_observation.check_is_observation_valid(next_obs[agent]):
+                    preproc_timer.start()
+                    agent_obs[agent] = tree_observation.get_normalized_observation(next_obs[agent],
+                                                                                   observation_tree_depth,
+                                                                                   observation_radius=observation_radius)
+                    preproc_timer.end()
+
+                score += all_rewards[agent]
+
+            nb_steps = step
+
+            if done['__all__']:
+                break
+
+            if check_if_all_blocked(train_env):
+                break
+        # Epsilon decay
+        eps_start = max(eps_end, eps_decay * eps_start)
+
+        # Collect information about training
+        tasks_finished = sum(done[idx] for idx in train_env.get_agent_handles())
+        completion = tasks_finished / max(1, train_env.get_num_agents())
+        normalized_score = score / (max_steps * train_env.get_num_agents())
+        action_probs = action_count / max(1, np.sum(action_count))
+
+        scores_window.append(normalized_score)
+        completion_window.append(completion)
+        smoothed_normalized_score = np.mean(scores_window)
+        smoothed_completion = np.mean(completion_window)
+
+        # Print logs
+        if episode_idx % checkpoint_interval == 0:
+            policy.save('./checkpoints/' + training_id + '-' + str(episode_idx) + '.pth')
+
+            if save_replay_buffer:
+                policy.save_replay_buffer('./replay_buffers/' + training_id + '-' + str(episode_idx) + '.pkl')
+
+            if train_params.render and False:
+                env_renderer.close_window()
+
+            # reset action count
+            action_count = [0] * action_size
+
+        print(
+            '\r🚂 Episode {:7}'
+            '\t 🏆 Score: {:7.3f}'
+            ' Avg: {:7.3f}'
+            '\t 💯 Done: {:6.2f}%'
+            ' Avg: {:6.2f}%'
+            '\t 🎲 Epsilon: {:.3f} '
+            '\t 🔀 Action Probs: {}'.format(
+                episode_idx,
+                normalized_score,
+                smoothed_normalized_score,
+                100 * completion,
+                100 * smoothed_completion,
+                eps_start,
+                format_action_prob(action_probs)
+            ), end=" ")
+
+        # Evaluate policy and log results at some interval
+        if episode_idx % checkpoint_interval == 0 and n_eval_episodes > 0 and episode_idx > 0:
+            scores, completions, nb_steps_eval = eval_policy(eval_env,
+                                                             tree_observation,
+                                                             policy,
+                                                             train_params,
+                                                             obs_params)
+
+            writer.add_scalar("evaluation/scores_min", np.min(scores), episode_idx)
+            writer.add_scalar("evaluation/scores_max", np.max(scores), episode_idx)
+            writer.add_scalar("evaluation/scores_mean", np.mean(scores), episode_idx)
+            writer.add_scalar("evaluation/scores_std", np.std(scores), episode_idx)
+            writer.add_histogram("evaluation/scores", np.array(scores), episode_idx)
+            writer.add_scalar("evaluation/completions_min", np.min(completions), episode_idx)
+            writer.add_scalar("evaluation/completions_max", np.max(completions), episode_idx)
+            writer.add_scalar("evaluation/completions_mean", np.mean(completions), episode_idx)
+            writer.add_scalar("evaluation/completions_std", np.std(completions), episode_idx)
+            writer.add_histogram("evaluation/completions", np.array(completions), episode_idx)
+            writer.add_scalar("evaluation/nb_steps_min", np.min(nb_steps_eval), episode_idx)
+            writer.add_scalar("evaluation/nb_steps_max", np.max(nb_steps_eval), episode_idx)
+            writer.add_scalar("evaluation/nb_steps_mean", np.mean(nb_steps_eval), episode_idx)
+            writer.add_scalar("evaluation/nb_steps_std", np.std(nb_steps_eval), episode_idx)
+            writer.add_histogram("evaluation/nb_steps", np.array(nb_steps_eval), episode_idx)
+
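+            # Exponentially smoothed evaluation metrics:
+            #   smoothed = 0.9 * smoothed + 0.1 * current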
+            smoothing = 0.9
+            smoothed_eval_normalized_score = smoothed_eval_normalized_score * smoothing + np.mean(scores) * (
+                    1.0 - smoothing)
+            smoothed_eval_completion = smoothed_eval_completion * smoothing + np.mean(completions) * (1.0 - smoothing)
+            writer.add_scalar("evaluation/smoothed_score", smoothed_eval_normalized_score, episode_idx)
+            writer.add_scalar("evaluation/smoothed_completion", smoothed_eval_completion, episode_idx)
+
+        # Save logs to tensorboard
+        writer.add_scalar("training/score", normalized_score, episode_idx)
+        writer.add_scalar("training/smoothed_score", smoothed_normalized_score, episode_idx)
+        writer.add_scalar("training/completion", np.mean(completion), episode_idx)
+        writer.add_scalar("training/smoothed_completion", np.mean(smoothed_completion), episode_idx)
+        writer.add_scalar("training/nb_steps", nb_steps, episode_idx)
+        writer.add_histogram("actions/distribution", np.array(actions_taken), episode_idx)
+        writer.add_scalar("actions/nothing", action_probs[RailEnvActions.DO_NOTHING], episode_idx)
+        writer.add_scalar("actions/left", action_probs[RailEnvActions.MOVE_LEFT], episode_idx)
+        writer.add_scalar("actions/forward", action_probs[RailEnvActions.MOVE_FORWARD], episode_idx)
+        writer.add_scalar("actions/right", action_probs[RailEnvActions.MOVE_RIGHT], episode_idx)
+        writer.add_scalar("actions/stop", action_probs[RailEnvActions.STOP_MOVING], episode_idx)
+        writer.add_scalar("training/epsilon", eps_start, episode_idx)
+        writer.add_scalar("training/buffer_size", len(policy.memory), episode_idx)
+        writer.add_scalar("training/loss", policy.loss, episode_idx)
+        writer.add_scalar("timer/reset", reset_timer.get(), episode_idx)
+        writer.add_scalar("timer/step", step_timer.get(), episode_idx)
+        writer.add_scalar("timer/learn", learn_timer.get(), episode_idx)
+        writer.add_scalar("timer/preproc", preproc_timer.get(), episode_idx)
+        writer.add_scalar("timer/total", training_timer.get_current(), episode_idx)
+
+
+def format_action_prob(action_probs):
+    action_probs = np.round(action_probs, 3)
+    actions = ["↻", "←", "↑", "→", "◼"]
+
+    buffer = ""
+    for action, action_prob in zip(actions, action_probs):
+        buffer += action + " " + "{:.3f}".format(action_prob) + " "
+
+    return buffer
+
+
+def eval_policy(env, tree_observation, policy, train_params, obs_params):
+    n_eval_episodes = train_params.n_evaluation_episodes
+    max_steps = env._max_episode_steps
+    tree_depth = obs_params.observation_tree_depth
+    observation_radius = obs_params.observation_radius
+
+    action_dict = dict()
+    scores = []
+    completions = []
+    nb_steps = []
+
+    for episode_idx in range(n_eval_episodes):
+        agent_obs = [None] * env.get_num_agents()
+        score = 0.0
+
+        obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
+
+        final_step = 0
+
+        for step in range(max_steps - 1):
+            for agent in env.get_agent_handles():
+                if tree_observation.check_is_observation_valid(agent_obs[agent]):
+                    agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth,
+                                                                                   observation_radius=observation_radius)
+
+                action = 0
+                if info['action_required'][agent]:
+                    if tree_observation.check_is_observation_valid(agent_obs[agent]):
+                        action = policy.act(agent, agent_obs[agent], eps=0.0)
+                action_dict.update({agent: action})
+
+            obs, all_rewards, done, info = env.step(action_dict)
+
+            for agent in env.get_agent_handles():
+                score += all_rewards[agent]
+
+            final_step = step
+
+            if done['__all__']:
+                break
+
+            if check_if_all_blocked(env):
+                break
+
+        normalized_score = score / (max_steps * env.get_num_agents())
+        scores.append(normalized_score)
+
+        tasks_finished = sum(done[idx] for idx in env.get_agent_handles())
+        completion = tasks_finished / max(1, env.get_num_agents())
+        completions.append(completion)
+
+        nb_steps.append(final_step)
+
+    print("\t✅ Eval: score {:7.3f} done {:6.2f}%".format(np.mean(scores), np.mean(completions) * 100.0))
+
+    return scores, completions, nb_steps
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=200000, type=int)
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
+                        type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
+    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=200, type=int)
+    parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
+    parser.add_argument("--eps_end", help="min exploration", default=0.0001, type=float)
+    parser.add_argument("--eps_decay", help="exploration decay", default=0.999, type=float)
+    parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e5), type=int)
+    parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
+    parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
+    parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False,
+                        type=bool)
+    parser.add_argument("--batch_size", help="minibatch size", default=128, type=int)
+    parser.add_argument("--gamma", help="discount factor", default=0.99, type=float)
+    parser.add_argument("--tau", help="soft update of target parameters", default=1e-3, type=float)
+    parser.add_argument("--learning_rate", help="learning rate", default=0.5e-4, type=float)
+    parser.add_argument("--hidden_size", help="hidden size (2 fc layers)", default=128, type=int)
+    parser.add_argument("--update_every", help="how often to update the network", default=8, type=int)
+    parser.add_argument("--use_gpu", help="use GPU if available", default=False, type=bool)
+    parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
+    parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
+    parser.add_argument("--use_extra_observation", help="extra observation", default=True, type=bool)
+    parser.add_argument("--max_depth", help="max depth", default=-1, type=int)
+    parser.add_argument("--close_following", help="enable close following feature", default=True, type=bool)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2, type=int)
+    parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
+
+    training_params = parser.parse_args()
+    env_params = [
+        {
+            # Test_0
+            "n_agents": 5,
+            "x_dim": 25,
+            "y_dim": 25,
+            "n_cities": 2,
+            "max_rails_between_cities": 2,
+            "max_rails_in_city": 3,
+            "malfunction_rate": 1 / 50,
+            "seed": 0
+        },
+        {
+            # Test_1
+            "n_agents": 10,
+            "x_dim": 30,
+            "y_dim": 30,
+            "n_cities": 2,
+            "max_rails_between_cities": 2,
+            "max_rails_in_city": 3,
+            "malfunction_rate": 1 / 100,
+            "seed": 0
+        },
+        {
+            # Test_2
+            "n_agents": 20,
+            "x_dim": 30,
+            "y_dim": 30,
+            "n_cities": 3,
+            "max_rails_between_cities": 2,
+            "max_rails_in_city": 3,
+            "malfunction_rate": 1 / 200,
+            "seed": 0
+        },
+        {
+            # Test_3
+            "n_agents": 106,
+            "x_dim": 50,
+            "y_dim": 50,
+            "n_cities": 12,
+            "max_rails_between_cities": 2,
+            "max_rails_in_city": 4,
+            "malfunction_rate": 1 / 50000,
+            "seed": 0
+        },
+    ]
+
+    obs_params = {
+        "observation_tree_depth": training_params.max_depth,  # FIXME
+        "observation_radius": 10,
+        "observation_max_path_depth": 30
+    }
+
+    print("close_following: ", training_params.close_following)
+
+
+    def check_env_config(id):
+        if id >= len(env_params) or id < 0:
+            print("\n🛑 Invalid environment configuration, only Test_0 to Test_{} are supported.".format(
+                len(env_params) - 1))
+            exit(1)
+
+
+    check_env_config(training_params.training_env_config)
+    check_env_config(training_params.evaluation_env_config)
+
+    training_env_params = env_params[training_params.training_env_config]
+    evaluation_env_params = env_params[training_params.evaluation_env_config]
+
+    print("\nTraining parameters:")
+    pprint(vars(training_params))
+    print("\nTraining environment parameters (Test_{}):".format(training_params.training_env_config))
+    pprint(training_env_params)
+    print("\nEvaluation environment parameters (Test_{}):".format(training_params.evaluation_env_config))
+    pprint(evaluation_env_params)
+    print("\nObservation parameters:")
+    pprint(obs_params)
+
+    os.environ["OMP_NUM_THREADS"] = str(training_params.num_threads)
+    train_agent(training_params, Namespace(**training_env_params), Namespace(**evaluation_env_params),
+                Namespace(**obs_params), training_params.close_following)
diff --git a/reinforcement_learning/multi_policy.py b/reinforcement_learning/multi_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ce6d0f218870063fe546bc4b2674f355867a7e5
--- /dev/null
+++ b/reinforcement_learning/multi_policy.py
@@ -0,0 +1,70 @@
+import numpy as np
+from flatland.envs.rail_env import RailEnvActions
+
+from reinforcement_learning.policy import Policy
+from reinforcement_learning.ppo.ppo_agent import PPOAgent
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
+from utils.extra import ExtraPolicy
+
+
+class MultiPolicy(Policy):
+    def __init__(self, state_size, action_size, n_agents, env):
+        self.state_size = state_size
+        self.action_size = action_size
+        self.memory = []
+        self.loss = 0
+        self.dead_lock_avoidance_policy = DeadLockAvoidanceAgent(env, state_size, action_size)
+        self.extra_policy = ExtraPolicy(state_size, action_size)
+        self.ppo_policy = PPOAgent(state_size + action_size, action_size, n_agents, env)
+
+    def load(self, filename):
+        self.ppo_policy.load(filename)
+        self.extra_policy.load(filename)
+
+    def save(self, filename):
+        self.ppo_policy.save(filename)
+        self.extra_policy.save(filename)
+
+    def step(self, handle, state, action, reward, next_state, done):
+        action_extra_state = self.extra_policy.act(handle, state, 0.0)
+        action_extra_next_state = self.extra_policy.act(handle, next_state, 0.0)
+
+        extended_state = np.copy(state)
+        for action_itr in np.arange(self.action_size):
+            extended_state = np.append(extended_state, [int(action_extra_state == action_itr)])
+        extended_next_state = np.copy(next_state)
+        for action_itr in np.arange(self.action_size):
+            extended_next_state = np.append(extended_next_state, [int(action_extra_next_state == action_itr)])
+
+        self.extra_policy.step(handle, state, action, reward, next_state, done)
+        self.ppo_policy.step(handle, extended_state, action, reward, extended_next_state, done)
+
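+    # Action selection is hierarchical: the deadlock-avoidance policy can force
+    # a STOP_MOVING action; otherwise the PPO policy picks an action from the
+    # state extended with a one-hot encoding of the Extra policy's action.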
+    def act(self, handle, state, eps=0.):
+        dead_lock_avoidance_action = self.dead_lock_avoidance_policy.act(handle, state, 0.0)
+        if dead_lock_avoidance_action == RailEnvActions.STOP_MOVING:
+            return RailEnvActions.STOP_MOVING
+        action_extra_state = self.extra_policy.act(handle, state, 0.0)
+        extended_state = np.copy(state)
+        for action_itr in np.arange(self.action_size):
+            extended_state = np.append(extended_state, [int(action_extra_state == action_itr)])
+        action_ppo = self.ppo_policy.act(handle, extended_state, eps)
+        self.loss = self.ppo_policy.loss
+        return action_ppo
+
+    def reset(self):
+        self.ppo_policy.reset()
+        self.extra_policy.reset()
+
+    def test(self):
+        self.ppo_policy.test()
+        self.extra_policy.test()
+
+    def start_step(self):
+        self.dead_lock_avoidance_policy.start_step()
+        self.extra_policy.start_step()
+        self.ppo_policy.start_step()
+
+    def end_step(self):
+        self.dead_lock_avoidance_policy.end_step()
+        self.extra_policy.end_step()
+        self.ppo_policy.end_step()
diff --git a/reinforcement_learning/ordered_policy.py b/reinforcement_learning/ordered_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dc55ee13e489a526b1346a01bd8737652c77e9f
--- /dev/null
+++ b/reinforcement_learning/ordered_policy.py
@@ -0,0 +1,34 @@
+import sys
+from pathlib import Path
+
+import numpy as np
+
+from reinforcement_learning.policy import Policy
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from utils.observation_utils import split_tree_into_feature_groups, min_gt
+
+
+class OrderedPolicy(Policy):
+    def __init__(self):
+        self.action_size = 5
+
+    def act(self, state, eps=0.):
+        _, distance, _ = split_tree_into_feature_groups(state, 1)
+        distance = distance[1:]
+        min_dist = min_gt(distance, 0)
+        min_direction = np.where(distance == min_dist)
+        if len(min_direction[0]) > 1:
+            return min_direction[0][-1] + 1
+        return min_direction[0][0] + 1
+
+    def step(self, state, action, reward, next_state, done):
+        return
+
+    def save(self, filename):
+        return
+
+    def load(self, filename):
+        return
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c77845e46540df8c2a0beff658728b1604cbd89
--- /dev/null
+++ b/reinforcement_learning/policy.py
@@ -0,0 +1,27 @@
+class Policy:
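+    """Minimal interface shared by all policies in this package.
+
+    Subclasses are expected to implement act() and step(); the remaining hooks
+    (save/load, replay-buffer handling, per-step callbacks) are optional no-ops.
+    """
+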
+    def step(self, handle, state, action, reward, next_state, done):
+        raise NotImplementedError
+
+    def act(self, handle, state, eps=0.):
+        raise NotImplementedError
+
+    def save(self, filename):
+        pass
+
+    def load(self, filename):
+        pass
+
+    def test(self):
+        pass
+
+    def save_replay_buffer(self, filename):
+        pass
+
+    def reset(self):
+        pass
+
+    def start_step(self):
+        pass
+
+    def end_step(self):
+        pass
diff --git a/reinforcement_learning/ppo/model.py b/reinforcement_learning/ppo/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..51b86ff16691c03f6a754405352bb4cf48e4b914
--- /dev/null
+++ b/reinforcement_learning/ppo/model.py
@@ -0,0 +1,20 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PolicyNetwork(nn.Module):
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32):
+        super().__init__()
+        self.fc1 = nn.Linear(state_size, hidsize1)
+        self.fc2 = nn.Linear(hidsize1, hidsize2)
+        # self.fc3 = nn.Linear(hidsize2, hidsize3)
+        self.output = nn.Linear(hidsize2, action_size)
+        self.softmax = nn.Softmax(dim=1)
+        self.bn0 = nn.BatchNorm1d(state_size, affine=False)
+
+    def forward(self, inputs):
+        x = self.bn0(inputs.float())
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        # x = F.relu(self.fc3(x))
+        return self.softmax(self.output(x))
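
A quick, illustrative forward pass (sizes are placeholders, not taken from the repo) showing that PolicyNetwork outputs a per-row probability distribution over actions:

import torch

from reinforcement_learning.ppo.model import PolicyNetwork

net = PolicyNetwork(state_size=20, action_size=5)  # placeholder sizes
net.eval()  # BatchNorm1d needs eval mode (or a batch larger than 1) for a single sample
with torch.no_grad():
    probs = net(torch.randn(1, 20))
print(probs.shape)         # torch.Size([1, 5])
print(float(probs.sum()))  # ~1.0, since the head is a softmax over actions
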
diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ba5a68bb30171e6279742e50dbdf1753846346e
--- /dev/null
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -0,0 +1,141 @@
+import os
+import random
+
+import numpy as np
+import torch
+from torch.distributions.categorical import Categorical
+
+from reinforcement_learning.policy import Policy
+from reinforcement_learning.ppo.model import PolicyNetwork
+from reinforcement_learning.ppo.replay_memory import Episode, ReplayBuffer
+
+BUFFER_SIZE = 32_000
+BATCH_SIZE = 4096
+GAMMA = 0.8
+LR = 0.5e-4
+CLIP_FACTOR = .005
+UPDATE_EVERY = 30
+
+device = torch.device("cpu")  # or: torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+class PPOAgent(Policy):
+    def __init__(self, state_size, action_size, num_agents, env):
+        self.action_size = action_size
+        self.state_size = state_size
+        self.num_agents = num_agents
+        self.policy = PolicyNetwork(state_size, action_size).to(device)
+        self.old_policy = PolicyNetwork(state_size, action_size).to(device)
+        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR)
+        self.episodes = [Episode() for _ in range(num_agents)]
+        self.memory = ReplayBuffer(BUFFER_SIZE)
+        self.t_step = 0
+        self.loss = 0
+        self.env = env
+
+    def reset(self):
+        self.finished = [False] * len(self.episodes)
+        self.tot_reward = [0] * self.num_agents
+
+    # Decide on an action to take in the environment
+    def act(self, handle, state, eps=None):
+        # eps is accepted for interface compatibility; exploration comes from
+        # sampling the categorical distribution produced by the policy network.
+        self.policy.eval()
+        with torch.no_grad():
+            output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
+            return Categorical(output).sample().item()
+
+    # Record the results of the agent's action and update the model
+    def step(self, handle, state, action, reward, next_state, done):
+        if not self.finished[handle]:
+            # Push experience into Episode memory
+            self.tot_reward[handle] += reward
+            if done == 1:
+                reward = 1  # self.tot_reward[handle]
+            else:
+                reward = 0
+
+            self.episodes[handle].push(state, action, reward, next_state, done)
+
+            # When we finish the episode, discount rewards and push the experience into replay memory
+            if done:
+                self.episodes[handle].discount_rewards(GAMMA)
+                self.memory.push_episode(self.episodes[handle])
+                self.episodes[handle].reset()
+                self.finished[handle] = True
+
+        # Perform a gradient update every UPDATE_EVERY time steps
+        self.t_step = (self.t_step + 1) % UPDATE_EVERY
+        if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4:
+            self._learn(*self.memory.sample(BATCH_SIZE, device))
+
+    def _clip_gradient(self, model, clip):
+        """Clamp every parameter gradient to the range [-clip, clip]."""
+        for p in model.parameters():
+            if p.grad is not None:
+                p.grad.data.clamp_(-clip, clip)
+
+    def _clip_gradient_norm(self, model, clip):
+        """Alternative: scale all gradients by a coefficient based on the global gradient norm."""
+        totalnorm = 0
+        for p in model.parameters():
+            if p.grad is not None:
+                modulenorm = p.grad.data.norm()
+                totalnorm += modulenorm ** 2
+        totalnorm = np.sqrt(totalnorm)
+        coeff = min(1, clip / (totalnorm + 1e-6))
+
+        for p in model.parameters():
+            if p.grad is not None:
+                p.grad.mul_(coeff)
+
+    def _learn(self, states, actions, rewards, next_state, done):
+        self.policy.train()
+
+        responsible_outputs = torch.gather(self.policy(states), 1, actions)
+        old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach()
+
+        # rewards = rewards - rewards.mean()
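+        # PPO clipped surrogate: ratio = pi_new(a|s) / pi_old(a|s); the objective
+        # -E[min(ratio * R, clip(ratio, 1 - CLIP_FACTOR, 1 + CLIP_FACTOR) * R)]
+        # uses the discounted episode return R directly as the advantage signal.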
+        ratio = responsible_outputs / (old_responsible_outputs + 1e-5)
+        clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR)
+        loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean()
+        self.loss = loss
+
+        # Compute loss and perform a gradient step
+        self.old_policy.load_state_dict(self.policy.state_dict())
+        self.optimizer.zero_grad()
+        loss.backward()
+        # self._clip_gradient(self.policy, 1.0)
+        self.optimizer.step()
+
+    # Checkpointing methods
+    def save(self, filename):
+        # print("Saving model from checkpoint:", filename)
+        torch.save(self.policy.state_dict(), filename + ".policy")
+        torch.save(self.optimizer.state_dict(), filename + ".optimizer")
+
+    def load(self, filename):
+        print("load policy from file", filename)
+        if os.path.exists(filename + ".policy"):
+            print(' >> ', filename + ".policy")
+            try:
+                self.policy.load_state_dict(torch.load(filename + ".policy"))
+            except Exception:
+                print(" >> failed!")
+        if os.path.exists(filename + ".optimizer"):
+            print(' >> ', filename + ".optimizer")
+            try:
+                self.optimizer.load_state_dict(torch.load(filename + ".optimizer"))
+            except Exception:
+                print(" >> failed!")
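
A self-contained numeric sketch (toy tensors, not taken from the repo) of the clipped surrogate that _learn minimises; it shows how the clamp caps how far a single update can push the new policy away from the old one:

import torch

CLIP_FACTOR = 0.005
new_probs = torch.tensor([[0.30], [0.10]])  # pi_new(a|s) for the sampled actions
old_probs = torch.tensor([[0.25], [0.20]])  # pi_old(a|s) for the same actions
returns = torch.tensor([[1.0], [0.0]])      # discounted episode returns

ratio = new_probs / (old_probs + 1e-5)
clamped = torch.clamp(ratio, 1.0 - CLIP_FACTOR, 1.0 + CLIP_FACTOR)
loss = -torch.min(ratio * returns, clamped * returns).mean()
print(loss)  # tensor(-0.5025): the first sample is capped near ratio 1.005 despite a raw ratio of ~1.2
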
diff --git a/reinforcement_learning/ppo/replay_memory.py b/reinforcement_learning/ppo/replay_memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e6619b40169597d7a4b379f4ce2c9ddccd4cd9b
--- /dev/null
+++ b/reinforcement_learning/ppo/replay_memory.py
@@ -0,0 +1,53 @@
+import torch
+import random
+import numpy as np
+from collections import namedtuple, deque
+from collections.abc import Iterable
+
+
+Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done"))
+
+
+class Episode:
+    def __init__(self):
+        # Use an instance attribute so episodes never share their memory list.
+        self.memory = []
+
+    def reset(self):
+        self.memory = []
+
+    def push(self, *args):
+        self.memory.append(tuple(args))
+
+    def discount_rewards(self, gamma):
+        running_add = 0.
+        for i, (state, action, reward, *rest) in list(enumerate(self.memory))[::-1]:
+            running_add = running_add * gamma + reward
+            self.memory[i] = (state, action, running_add, *rest)
+
+
+class ReplayBuffer:
+    def __init__(self, buffer_size):
+        self.memory = deque(maxlen=buffer_size)
+
+    def push(self, state, action, reward, next_state, done):
+        self.memory.append(Transition(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done))
+
+    def push_episode(self, episode):
+        for step in episode.memory:
+            self.push(*step)
+
+    def sample(self, batch_size, device):
+        experiences = random.sample(self.memory, k=batch_size)
+
+        states      = torch.from_numpy(self.stack([e.state      for e in experiences])).float().to(device)
+        actions     = torch.from_numpy(self.stack([e.action     for e in experiences])).long().to(device)
+        rewards     = torch.from_numpy(self.stack([e.reward     for e in experiences])).float().to(device)
+        next_states = torch.from_numpy(self.stack([e.next_state for e in experiences])).float().to(device)
+        dones       = torch.from_numpy(self.stack([e.done       for e in experiences]).astype(np.uint8)).float().to(device)
+
+        return states, actions, rewards, next_states, dones
+
+    def stack(self, states):
+        sub_dims = states[0].shape[1:] if isinstance(states[0], Iterable) else [1]
+        return np.reshape(np.array(states), (len(states), *sub_dims))
+
+    def __len__(self):
+        return len(self.memory)
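
A small, self-contained example (toy values) of how Episode.discount_rewards rewrites the stored rewards in place and how a finished episode is moved into the ReplayBuffer:

episode = Episode()
episode.reset()
episode.push([0.0], 2, 0.0, [1.0], 0)
episode.push([1.0], 2, 0.0, [2.0], 0)
episode.push([2.0], 2, 1.0, [3.0], 1)
episode.discount_rewards(0.5)
print([step[2] for step in episode.memory])  # [0.25, 0.5, 1.0]

buffer = ReplayBuffer(buffer_size=100)
buffer.push_episode(episode)
print(len(buffer))  # 3
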
diff --git a/reinforcement_learning/sequential_agent_training.py b/reinforcement_learning/sequential_agent_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca19d1fcbbb4e3508a16b847d4b4cfcefc6aad98
--- /dev/null
+++ b/reinforcement_learning/sequential_agent_training.py
@@ -0,0 +1,78 @@
+import sys
+import numpy as np
+
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
+from flatland.utils.rendertools import RenderTool
+from pathlib import Path
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from reinforcement_learning.ordered_policy import OrderedPolicy
+
+np.random.seed(2)
+
+x_dim = 20  # np.random.randint(8, 20)
+y_dim = 20  # np.random.randint(8, 20)
+n_agents = 10  # np.random.randint(3, 8)
+n_goals = n_agents + np.random.randint(0, 3)
+min_dist = int(0.75 * min(x_dim, y_dim))
+
+env = RailEnv(width=x_dim,
+              height=y_dim,
+              rail_generator=complex_rail_generator(
+                  nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
+                  max_dist=99999,
+                  seed=0
+              ),
+              schedule_generator=complex_schedule_generator(),
+              obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()),
+              number_of_agents=n_agents)
+env.reset(True, True)
+
+tree_depth = 1
+observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
+env_renderer = RenderTool(env, gl="PGL", )
+handle = env.get_agent_handles()
+n_episodes = 1
+max_steps = 100 * (env.height + env.width)
+record_images = False
+policy = OrderedPolicy()
+action_dict = dict()
+
+for trials in range(1, n_episodes + 1):
+
+    # Reset environment
+    obs, info = env.reset(True, True)
+    done = env.dones
+    env_renderer.reset()
+    frame_step = 0
+
+    # Run episode
+    for step in range(max_steps):
+        env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
+
+        if record_images:
+            env_renderer.gl.save_image("./Images/flatland_frame_{:04d}.bmp".format(frame_step))
+            frame_step += 1
+
+        # Action
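+        # Sequential strategy: agents are released essentially one at a time; the
+        # currently selected agent follows its shortest path while all other agents
+        # receive action 4 (STOP_MOVING).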
+        acting_agent = 0
+        for a in range(env.get_num_agents()):
+            if done[a]:
+                acting_agent += 1
+            if a == acting_agent:
+                action = policy.act(obs[a])
+            else:
+                action = 4
+            action_dict.update({a: action})
+
+        # Environment step
+        obs, all_rewards, done, _ = env.step(action_dict)
+
+        if done['__all__']:
+            break
diff --git a/reinforcement_learning/single_agent_training.py b/reinforcement_learning/single_agent_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..79a88b25a8bc63011ef04208af53858dbb079d7d
--- /dev/null
+++ b/reinforcement_learning/single_agent_training.py
@@ -0,0 +1,203 @@
+import random
+import sys
+from argparse import ArgumentParser, Namespace
+from collections import deque
+from pathlib import Path
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from utils.observation_utils import normalize_observation
+from flatland.envs.observations import TreeObsForRailEnv
+
+"""
+This file shows how to train a single agent using a reinforcement learning approach.
+Documentation: https://flatland.aicrowd.com/getting-started/rl/single-agent.html
+
+This is a simple method used for demonstration purposes.
+multi_agent_training.py is a better starting point to train your own solution!
+"""
+
+
+def train_agent(n_episodes):
+    # Environment parameters
+    n_agents = 1
+    x_dim = 25
+    y_dim = 25
+    n_cities = 4
+    max_rails_between_cities = 2
+    max_rails_in_city = 3
+    seed = 42
+
+    # Observation parameters
+    observation_tree_depth = 2
+    observation_radius = 10
+
+    # Exploration parameters
+    eps_start = 1.0
+    eps_end = 0.01
+    eps_decay = 0.997  # for 2500ts
+
+    # Set the seeds
+    random.seed(seed)
+    np.random.seed(seed)
+
+    # Observation builder
+    tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth)
+
+    # Setup the environment
+    env = RailEnv(
+        width=x_dim,
+        height=y_dim,
+        rail_generator=sparse_rail_generator(
+            max_num_cities=n_cities,
+            seed=seed,
+            grid_mode=False,
+            max_rails_between_cities=max_rails_between_cities,
+            max_rails_in_city=max_rails_in_city
+        ),
+        schedule_generator=sparse_schedule_generator(),
+        number_of_agents=n_agents,
+        obs_builder_object=tree_observation
+    )
+
+    env.reset(True, True)
+
+    # Calculate the state size given the depth of the tree observation and the number of features
+    n_features_per_node = env.obs_builder.observation_dim
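+    # A tree observation of depth d has sum_{i=0}^{d} 4^i nodes (four branches per
+    # node), so depth 2 yields 1 + 4 + 16 = 21 nodes.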
+    n_nodes = 0
+    for i in range(observation_tree_depth + 1):
+        n_nodes += np.power(4, i)
+    state_size = n_features_per_node * n_nodes
+
+    # The action space of flatland is 5 discrete actions
+    action_size = 5
+
+    # Max number of steps per episode
+    # This is the official formula used during evaluations
+    max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
+
+    action_dict = dict()
+
+    # And some variables to keep track of the progress
+    scores_window = deque(maxlen=100)  # todo smooth when rendering instead
+    completion_window = deque(maxlen=100)
+    scores = []
+    completion = []
+    action_count = [0] * action_size
+    agent_obs = [None] * env.get_num_agents()
+    agent_prev_obs = [None] * env.get_num_agents()
+    agent_prev_action = [2] * env.get_num_agents()
+    update_values = False
+
+    # Training parameters
+    training_parameters = {
+        'buffer_size': int(1e5),
+        'batch_size': 32,
+        'update_every': 8,
+        'learning_rate': 0.5e-4,
+        'tau': 1e-3,
+        'gamma': 0.99,
+        'buffer_min_size': 0,
+        'hidden_size': 256,
+        'use_gpu': False
+    }
+
+    # Double Dueling DQN policy
+    policy = DDDQNPolicy(state_size, action_size, Namespace(**training_parameters))
+
+    for episode_idx in range(n_episodes):
+        score = 0
+
+        # Reset environment
+        obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
+
+        # Build agent specific observations
+        for agent in env.get_agent_handles():
+            if obs[agent]:
+                agent_obs[agent] = normalize_observation(obs[agent], observation_tree_depth, observation_radius=observation_radius)
+                agent_prev_obs[agent] = agent_obs[agent].copy()
+
+        # Run episode
+        for step in range(max_steps - 1):
+            for agent in env.get_agent_handles():
+                if info['action_required'][agent]:
+                    # If an action is required, we want to store the obs at that step as well as the action
+                    update_values = True
+                    action = policy.act(agent, agent_obs[agent], eps=eps_start)
+                    action_count[action] += 1
+                else:
+                    update_values = False
+                    action = 0
+                action_dict.update({agent: action})
+
+            # Environment step
+            next_obs, all_rewards, done, info = env.step(action_dict)
+
+            # Update replay buffer and train agent
+            for agent in range(env.get_num_agents()):
+                # Only update the values when we are done or when an action was taken and thus relevant information is present
+                if update_values or done[agent]:
+                    policy.step(agent, agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], agent_obs[agent], done[agent])
+
+                    agent_prev_obs[agent] = agent_obs[agent].copy()
+                    agent_prev_action[agent] = action_dict[agent]
+
+                if next_obs[agent]:
+                    agent_obs[agent] = normalize_observation(next_obs[agent], observation_tree_depth, observation_radius=observation_radius)
+
+                score += all_rewards[agent]
+
+            if done['__all__']:
+                break
+
+        # Epsilon decay
+        eps_start = max(eps_end, eps_decay * eps_start)
+
+        # Collect information about training
+        tasks_finished = np.sum([int(done[idx]) for idx in env.get_agent_handles()])
+        completion_window.append(tasks_finished / max(1, env.get_num_agents()))
+        scores_window.append(score / (max_steps * env.get_num_agents()))
+        completion.append((np.mean(completion_window)))
+        scores.append(np.mean(scores_window))
+        action_probs = action_count / np.sum(action_count)
+
+        if episode_idx % 100 == 0:
+            end = "\n"
+            torch.save(policy.qnetwork_local, './checkpoints/single-' + str(episode_idx) + '.pth')
+            action_count = [1] * action_size
+        else:
+            end = " "
+
+        print('\rTraining {} agents on {}x{}\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
+            env.get_num_agents(),
+            x_dim, y_dim,
+            episode_idx,
+            np.mean(scores_window),
+            100 * np.mean(completion_window),
+            eps_start,
+            action_probs
+        ), end=end)
+
+    # Plot overall training progress at the end
+    plt.plot(scores)
+    plt.show()
+
+    plt.plot(completion)
+    plt.show()
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("-n", "--n_episodes", dest="n_episodes", help="number of episodes to run", default=500, type=int)
+    args = parser.parse_args()
+
+    train_agent(args.n_episodes)
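
The script can be smoke-tested by calling train_agent directly (assuming flatland is installed and a ./checkpoints directory exists, since checkpoints are written every 100 episodes), or via the CLI shown in the __main__ block, e.g. python reinforcement_learning/single_agent_training.py -n 500:

from reinforcement_learning.single_agent_training import train_agent

train_agent(5)  # short run for a sanity check; the CLI default is 500 episodes
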
diff --git a/run.py b/run.py
index 5cd92d3c776b7f589fc6e33544470d5d3401f0b4..d920c06630b2ca9ad32bba6ac7a0228b7e45ea7f 100644
--- a/run.py
+++ b/run.py
@@ -1,15 +1,15 @@
 import time
 
 import numpy as np
+from flatland.core.env_observation_builder import DummyObservationBuilder
 from flatland.envs.agent_utils import RailAgentStatus
 from flatland.evaluators.client import FlatlandRemoteClient
 
-
 #####################################################################
 # Instantiate a Remote Client
 #####################################################################
+from src.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 from src.extra import Extra
-from src.simple.DeadLock_Avoidance import calculate_one_step_heuristics, calculate_one_step_package_implementation,calculate_one_step,calculate_one_step_primitive_implementation
 
 remote_client = FlatlandRemoteClient()
 
@@ -21,15 +21,17 @@ remote_client = FlatlandRemoteClient()
 # compute the necessary action for this step for all (or even some)
 # of the agents
 #####################################################################
-def my_controller_RL(extra: Extra, observation, info):
-    return extra.rl_agent_act(observation, info)
+# def my_controller_RL(extra: Extra, observation, info):
+#     return extra.rl_agent_act(observation, info)
 
-def my_controller(local_env, obs, number_of_agents):
-    _action, _ = calculate_one_step(extra.env)
-    # _action, _ = calculate_one_step_package_implementation(local_env)
-    # _action, _ = calculate_one_step_primitive_implementation(local_env)
-    # _action, _ = calculate_one_step_heuristics(local_env)
-    return _action
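+# The controller now delegates to a Policy object (here DeadLockAvoidanceAgent):
+# start_step() precomputes the shortest-distance maps, act() is queried once per
+# agent handle, and end_step() closes the step.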
+def my_controller(policy):
+    policy.start_step()
+    actions = {}
+    for handle in range(policy.env.get_num_agents()):
+        a = policy.act(handle, None, 0)
+        actions.update({handle: a})
+    policy.end_step()
+    return actions
 
 
 #####################################################################
@@ -39,7 +41,8 @@ def my_controller(local_env, obs, number_of_agents):
 # the example here : 
 # https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
 #####################################################################
-my_observation_builder = Extra(max_depth=1)
+# my_observation_builder = Extra(max_depth=1)
+my_observation_builder = DummyObservationBuilder()
 
 # Or if you want to use your own approach to build the observation from the env_step, 
 # please feel free to pass a DummyObservationBuilder() object as mentioned below,
@@ -99,7 +102,9 @@ while True:
     local_env = remote_client.env
     number_of_agents = len(local_env.agents)
 
-    # Now we enter into another infinite loop where we 
+    policy = DeadLockAvoidanceAgent(local_env, None, None)
+
+    # Now we enter into another infinite loop where we
     # compute the actions for all the individual steps in this episode
     # until the episode is `done`
     # 
@@ -131,7 +136,7 @@ while True:
         # Compute the action for this step by using the previously 
         # defined controller
         time_start = time.time()
-        action = my_controller(extra, observation, info)
+        action = my_controller(policy)
         time_taken = time.time() - time_start
         time_taken_by_controller.append(time_taken)
 
diff --git a/src/dead_lock_avoidance_agent.py b/src/dead_lock_avoidance_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..43f1b4ae15aebae1348e0eadc5f9cab243511a28
--- /dev/null
+++ b/src/dead_lock_avoidance_agent.py
@@ -0,0 +1,116 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+
+from reinforcement_learning.policy import Policy
+from utils.shortest_Distance_walker import ShortestDistanceWalker
+
+
+class MyWalker(ShortestDistanceWalker):
+    def __init__(self, env: RailEnv, agent_positions):
+        super().__init__(env)
+        self.shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
+                                                     self.env.height,
+                                                     self.env.width),
+                                                    dtype=int) - 1
+
+        self.agent_positions = agent_positions
+
+        self.agent_map = {}
+
+    def getData(self):
+        return self.shortest_distance_agent_map
+
+    def callback(self, handle, agent, position, direction, action):
+        opp_a = self.agent_positions[position]
+        if opp_a != -1 and opp_a != handle:
+            d = self.agent_map.get(handle, [])
+            d.append(opp_a)
+            if self.env.agents[opp_a].direction != direction:
+                self.agent_map.update({handle: d})
+        self.shortest_distance_agent_map[(handle, position[0], position[1])] = direction
+
+
+class DeadLockAvoidanceAgent(Policy):
+    def __init__(self, env: RailEnv, state_size, action_size):
+        self.env = env
+        self.action_size = action_size
+        self.state_size = state_size
+        self.memory = []
+        self.loss = 0
+        self.agent_can_move = {}
+        self.active_agent_cnt = 0
+
+    def step(self, handle, state, action, reward, next_state, done):
+        pass
+
+    def act(self, handle, state, eps=0.):
+        agent = self.env.agents[handle]
+        #if handle > self.env._elapsed_steps:
+        #    return RailEnvActions.STOP_MOVING
+        if agent.status == RailAgentStatus.ACTIVE:
+            self.active_agent_cnt += 1
+        #if agent.status > 20:
+        #    return RailEnvActions.STOP_MOVING
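+        # agent_can_move maps handle -> [row, col, direction, action], filled in by
+        # shortest_distance_mapper(); an agent without a safe next step is stopped.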
+        check = self.agent_can_move.get(handle, None)
+        if check is None:
+            # print(handle, RailEnvActions.STOP_MOVING)
+            return RailEnvActions.STOP_MOVING
+
+        return check[3]
+
+    def reset(self):
+        pass
+
+    def start_step(self):
+        self.active_agent_cnt = 0
+        self.shortest_distance_mapper()
+
+    def end_step(self):
+        # print("#A:", self.active_agent_cnt, "/", self.env.get_num_agents(),self.env._elapsed_steps)
+        pass
+
+    def get_actions(self):
+        pass
+
+    def shortest_distance_mapper(self):
+
+        # build map with agent positions (only active agents)
+        agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status <= RailAgentStatus.ACTIVE:
+                if agent.position is not None:
+                    agent_positions[agent.position] = handle
+
+        my_walker = MyWalker(self.env, agent_positions)
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status <= RailAgentStatus.ACTIVE:
+                my_walker.walk_to_target(handle)
+        self.shortest_distance_agent_map = my_walker.getData()
+
+        self.agent_can_move = {}
+        agent_positions_map = np.clip(agent_positions + 1, 0, 1)
+        for handle in range(self.env.get_num_agents()):
+            opp_agents = my_walker.agent_map.get(handle, [])
+            me = np.clip(self.shortest_distance_agent_map[handle] + 1, 0, 1)
+            next_step_ok = True
+            next_position, next_direction, action = my_walker.walk_one_step(handle)
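+            # Compare this agent's shortest-path footprint ("me") with each opposing
+            # agent's footprint ("opp"), discounting cells currently occupied by
+            # agents, to judge whether taking the next step risks a deadlock.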
+            for opp_a in opp_agents:
+                opp = np.clip(self.shortest_distance_agent_map[opp_a] + 1, 0, 1)
+                delta = np.clip(me - opp - agent_positions_map, 0, 1)
+                if (np.sum(delta) > 1):
+                    next_step_ok = False
+            if next_step_ok:
+                self.agent_can_move.update({handle: [next_position[0], next_position[1], next_direction, action]})
+
+        if False:
+            a = np.floor(np.sqrt(self.env.get_num_agents()))
+            b = np.ceil(self.env.get_num_agents() / a)
+            for handle in range(self.env.get_num_agents()):
+                plt.subplot(a, b, handle + 1)
+                plt.imshow(self.shortest_distance_agent_map[handle])
+            # plt.colorbar()
+            plt.show(block=False)
+            plt.pause(0.001)
diff --git a/src/shortest_Distance_walker.py b/src/shortest_Distance_walker.py
new file mode 100644
index 0000000000000000000000000000000000000000..d69ebcb0a98f732cc44849b10858b2e42a376a23
--- /dev/null
+++ b/src/shortest_Distance_walker.py
@@ -0,0 +1,69 @@
+import numpy as np
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
+
+
+class ShortestDistanceWalker:
+    def __init__(self, env: RailEnv):
+        self.env = env
+
+    def walk(self, handle, position, direction):
+        possible_transitions = self.env.rail.get_transitions(*position, direction)
+        num_transitions = fast_count_nonzero(possible_transitions)
+        if num_transitions == 1:
+            new_direction = fast_argmax(possible_transitions)
+            new_position = get_new_position(position, new_direction)
+            dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]
+            return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD
+        else:
+            min_distances = []
+            positions = []
+            directions = []
+            for new_direction in [(direction + i) % 4 for i in range(-1, 2)]:
+                if possible_transitions[new_direction]:
+                    new_position = get_new_position(position, new_direction)
+                    min_distances.append(
+                        self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction])
+                    positions.append(new_position)
+                    directions.append(new_direction)
+                else:
+                    min_distances.append(np.inf)
+                    positions.append(None)
+                    directions.append(None)
+
+        a = np.argmin(min_distances)
+        return positions[a], directions[a], min_distances[a], a + 1
+
+    def callback(self, handle, agent, position, direction, action):
+        pass
+
+    def walk_to_target(self, handle):
+        agent = self.env.agents[handle]
+        if agent.position is not None:
+            position = agent.position
+        else:
+            position = agent.initial_position
+        direction = agent.direction
+        while (position != agent.target):
+            position, direction, dist, action = self.walk(handle, position, direction)
+            if position is None:
+                break
+            self.callback(handle, agent, position, direction, action)
+
+    def callback_one_step(self, handle, agent, position, direction, action):
+        pass
+
+    def walk_one_step(self, handle):
+        agent = self.env.agents[handle]
+        if agent.position is not None:
+            position = agent.position
+        else:
+            position = agent.initial_position
+        direction = agent.direction
+        # Default to standing still, e.g. when the agent is already at its target
+        new_position, new_direction, action = position, direction, RailEnvActions.STOP_MOVING
+        if position != agent.target:
+            new_position, new_direction, dist, action = self.walk(handle, position, direction)
+            if new_position is None:
+                return position, direction, RailEnvActions.STOP_MOVING
+            self.callback_one_step(handle, agent, new_position, new_direction, action)
+        return new_position, new_direction, action
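
A minimal sketch of how the walker is meant to be subclassed (hypothetical class name; assumes a RailEnv called env has already been built and reset, as in the training scripts above). The callback fires once per cell visited along an agent's shortest path.

from src.shortest_Distance_walker import ShortestDistanceWalker  # module path as created by this diff


class PathRecorder(ShortestDistanceWalker):
    def __init__(self, env):
        super().__init__(env)
        self.visited = {}

    def callback(self, handle, agent, position, direction, action):
        # Record every cell the shortest-path walk passes through for this agent.
        self.visited.setdefault(handle, []).append(position)


recorder = PathRecorder(env)
for handle in range(env.get_num_agents()):
    recorder.walk_to_target(handle)
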
diff --git a/src/simple/ClassifyProblemInstance.py b/src/simple/ClassifyProblemInstance.py
deleted file mode 100644
index cabd67e58ca069c9e7c4fa46d0777c8cdaba1e98..0000000000000000000000000000000000000000
--- a/src/simple/ClassifyProblemInstance.py
+++ /dev/null
@@ -1,94 +0,0 @@
-from enum import IntEnum
-
-import numpy as np
-
-
-class ProblemInstanceClass(IntEnum):
-    SHORTEST_PATH_ONLY = 0
-    SHORTEST_PATH_ORDERING_PROBLEM = 1
-    REQUIRE_ALTERNATIVE_PATH = 2
-
-
-def check_is_only_shortest_path_problem(env, project_path_matrix):
-    x = project_path_matrix.copy()
-    x[x < 2] = 0
-    return np.sum(x) == 0
-
-
-def check_is_shortest_path_and_ordering_problem(env, project_path_matrix):
-    x = project_path_matrix.copy()
-    for a in range(env.get_num_agents()):
-        # loop over all path and project start position and target into the project_path_matrix
-        agent = env.agents[a]
-        if x[agent.position[0]][agent.position[1]] > 1:
-            return False
-        if x[agent.target[0]][agent.target[1]] > 1:
-            return False
-    return True
-
-
-def check_is_require_alternative_path(env, project_path_matrix):
-    paths = env.dev_pred_dict
-    for a in range(env.get_num_agents()):
-        agent = env.agents[a]
-        path = paths[a]
-        for path_loop in range(len(path)):
-            p = path[path_loop]
-            if p[0] == agent.target[0] and p[1] == agent.target[1]:
-                break
-            if project_path_matrix[p[0]][p[1]] > 1:
-                # potential overlapping path found
-                for opp_a in range(env.get_num_agents()):
-                    opp_agent = env.agents[opp_a]
-                    opp_path = paths[opp_a]
-                    if p[0] == opp_agent.position[0] and p[1] == opp_agent.position[1]:
-                        opp_path_loop = 0
-                        tmp_path_loop = path_loop
-                        while True:
-                            if tmp_path_loop > len(path) - 1:
-                                break
-                            opp_p = opp_path[opp_path_loop]
-                            tmp_p = path[tmp_path_loop + 1]
-                            if opp_p[0] == opp_agent.target[0] and opp_p[1] == opp_agent.target[1]:
-                                return True
-                            if not (opp_p[0] == tmp_p[0] and opp_p[1] == tmp_p[1]):
-                                break
-                            if tmp_p[0] == agent.target[0] and tmp_p[1] == agent.target[1]:
-                                break
-                            opp_path_loop += 1
-                            tmp_path_loop += 1
-
-    return False
-
-
-def classify_problem_instance(env):
-    # shortest path from ShortesPathPredictorForRailEnv
-    paths = env.dev_pred_dict
-
-    project_path_matrix = np.zeros(shape=(env.height, env.width))
-    for a in range(env.get_num_agents()):
-        # loop over all path and project start position and target into the project_path_matrix
-        agent = env.agents[a]
-        project_path_matrix[agent.position[0]][agent.position[1]] += 1.0
-        project_path_matrix[agent.target[0]][agent.target[1]] += 1.0
-
-        if not (agent.target[0] == agent.position[0] and agent.target[1] == agent.position[1]):
-            # project the whole path into
-            path = paths[a]
-            for path_loop in range(len(path)):
-                p = path[path_loop]
-                if p[0] == agent.target[0] and p[1] == agent.target[1]:
-                    break
-                else:
-                    project_path_matrix[p[0]][p[1]] += 1.0
-
-    return \
-        {
-            # analyse : SHORTEST_PATH_ONLY -> if conflict_mat does not contain any number > 1
-            "SHORTEST_PATH_ONLY": check_is_only_shortest_path_problem(env, project_path_matrix),
-            # analyse : SHORTEST_PATH_ORDERING_PROBLEM -> if agent_start and agent_target position does not contain any number > 1
-            "SHORTEST_PATH_ORDERING_PROBLEM": check_is_shortest_path_and_ordering_problem(env, project_path_matrix),
-            # analyse : REQUIRE_ALTERNATIVE_PATH -> if agent_start and agent_target position does not contain any number > 1
-            "REQUIRE_ALTERNATIVE_PATH": check_is_require_alternative_path(env, project_path_matrix)
-
-        }
diff --git a/src/simple/DeadLock_Avoidance.py b/src/simple/DeadLock_Avoidance.py
deleted file mode 100644
index 7e80a46878cdcece319aade2c97097e609f1f205..0000000000000000000000000000000000000000
--- a/src/simple/DeadLock_Avoidance.py
+++ /dev/null
@@ -1,574 +0,0 @@
-import math
-from typing import Dict, List, Optional, Tuple, Set
-from typing import NamedTuple
-
-import numpy as np
-from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.core.transition_map import GridTransitionMap
-from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.distance_map import DistanceMap
-from flatland.envs.rail_env import RailEnvNextAction, RailEnvActions
-from flatland.envs.rail_env_shortest_paths import get_shortest_paths
-from flatland.utils.ordered_set import OrderedSet
-
-WalkingElement = NamedTuple('WalkingElement',
-                            [('position', Tuple[int, int]), ('direction', int),
-                             ('next_action_element', RailEnvActions)])
-
-
-def get_valid_move_actions_(agent_direction: Grid4TransitionsEnum,
-                            agent_position: Tuple[int, int],
-                            rail: GridTransitionMap) -> Set[RailEnvNextAction]:
-    """
-    Get the valid move actions (forward, left, right) for an agent.
-
-    Parameters
-    ----------
-    agent_direction : Grid4TransitionsEnum
-    agent_position: Tuple[int,int]
-    rail : GridTransitionMap
-
-
-    Returns
-    -------
-    Set of `RailEnvNextAction` (tuples of (action,position,direction))
-        Possible move actions (forward,left,right) and the next position/direction they lead to.
-        It is not checked that the next cell is free.
-    """
-    valid_actions: Set[RailEnvNextAction] = OrderedSet()
-    possible_transitions = rail.get_transitions(*agent_position, agent_direction)
-    num_transitions = np.count_nonzero(possible_transitions)
-    # Start from the current orientation, and see which transitions are available;
-    # organize them as [left, forward, right], relative to the current orientation
-    # If only one transition is possible, the forward branch is aligned with it.
-    if rail.is_dead_end(agent_position):
-        action = RailEnvActions.MOVE_FORWARD
-        exit_direction = (agent_direction + 2) % 4
-        if possible_transitions[exit_direction]:
-            new_position = get_new_position(agent_position, exit_direction)
-            valid_actions.add(RailEnvNextAction(action, new_position, exit_direction))
-    elif num_transitions == 1:
-        action = RailEnvActions.MOVE_FORWARD
-        for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
-            if possible_transitions[new_direction]:
-                new_position = get_new_position(agent_position, new_direction)
-                valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
-    else:
-        for new_direction in [(agent_direction + i) % 4 for i in range(-1, 2)]:
-            if possible_transitions[new_direction]:
-                if new_direction == agent_direction:
-                    action = RailEnvActions.MOVE_FORWARD
-                elif new_direction == (agent_direction + 1) % 4:
-                    action = RailEnvActions.MOVE_RIGHT
-                elif new_direction == (agent_direction - 1) % 4:
-                    action = RailEnvActions.MOVE_LEFT
-                else:
-                    raise Exception("Illegal state")
-
-                new_position = get_new_position(agent_position, new_direction)
-                valid_actions.add(RailEnvNextAction(action, new_position, new_direction))
-    return valid_actions
-
-
-# N.B. get_shortest_paths is not part of distance_map since it refers to RailEnvActions (would lead to circularity!)
-def get_paths(distance_map: DistanceMap, max_depth: Optional[int] = None, agent_handle: Optional[int] = None) \
-        -> Dict[int, Optional[List[WalkingElement]]]:
-    """
-    Computes the shortest path for each agent to its target and the action to be taken to do so.
-    The paths are derived from a `DistanceMap`.
-
-    If there is no path (rail disconnected), the path is given as None.
-    The agent state (moving or not) and its speed are not taken into account
-
-    example:
-            agent_fixed_travel_paths = get_shortest_paths(env.distance_map, None, agent.handle)
-            path = agent_fixed_travel_paths[agent.handle]
-
-    Parameters
-    ----------
-    distance_map : reference to the distance_map
-    max_depth : max path length, if the shortest path is longer, it will be cutted
-    agent_handle : if set, the shortest for agent.handle will be returned , otherwise for all agents
-
-    Returns
-    -------
-        Dict[int, Optional[List[WalkingElement]]]
-
-    """
-    shortest_paths = dict()
-
-    def _shortest_path_for_agent(agent):
-        if agent.status == RailAgentStatus.READY_TO_DEPART:
-            position = agent.initial_position
-        elif agent.status == RailAgentStatus.ACTIVE:
-            position = agent.position
-        elif agent.status == RailAgentStatus.DONE:
-            position = agent.target
-        else:
-            shortest_paths[agent.handle] = None
-            return
-        direction = agent.direction
-        shortest_paths[agent.handle] = []
-        distance = math.inf
-        depth = 0
-        cnt = 0
-        while (position != agent.target and (max_depth is None or depth < max_depth)) and cnt < 1000:
-            cnt = cnt + 1
-            next_actions = get_valid_move_actions_(direction, position, distance_map.rail)
-            best_next_action = None
-
-            for next_action in next_actions:
-                next_action_distance = distance_map.get()[
-                    agent.handle, next_action.next_position[0], next_action.next_position[
-                        1], next_action.next_direction]
-                if next_action_distance < distance:
-                    best_next_action = next_action
-                    distance = next_action_distance
-
-            for next_action in next_actions:
-                if next_action.action == RailEnvActions.MOVE_LEFT:
-                    next_action_distance = distance_map.get()[
-                        agent.handle, next_action.next_position[0], next_action.next_position[
-                            1], next_action.next_direction]
-                    if abs(next_action_distance - distance) < 5:
-                        best_next_action = next_action
-                        distance = next_action_distance
-
-            shortest_paths[agent.handle].append(WalkingElement(position, direction, best_next_action))
-            depth += 1
-
-            # if there is no way to continue, the rail must be disconnected!
-            # (or distance map is incorrect)
-            if best_next_action is None:
-                shortest_paths[agent.handle] = None
-                return
-
-            position = best_next_action.next_position
-            direction = best_next_action.next_direction
-        if max_depth is None or depth < max_depth:
-            shortest_paths[agent.handle].append(
-                WalkingElement(position, direction,
-                               RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction)))
-
-    if agent_handle is not None:
-        _shortest_path_for_agent(distance_map.agents[agent_handle])
-    else:
-        for agent in distance_map.agents:
-            _shortest_path_for_agent(agent)
-
-    return shortest_paths
-
-
-def agent_fake_position(agent):
-    if agent.position is not None:
-        return (agent.position[0], agent.position[1], 0)
-    return (-agent.handle - 1, -1, None)
-
-
-def compare_position_equal(a, b):
-    if a is None and b is None:
-        return True
-    if a is None or b is None:
-        return False
-    return (a[0] == b[0] and a[1] == b[1])
-
-
-def calc_conflict_matrix_next_step(env, paths, do_move, agent_position_matrix, agent_target_matrix,
-                                   agent_next_position_matrix):
-    # look step forward
-    conflict_mat = np.zeros(shape=(env.get_num_agents(), env.get_num_agents())) - 1
-
-    # calculate weighted (priority)
-    priority = np.arange(env.get_num_agents()).astype(float)
-    unique_ordered_priority = np.argsort(priority).astype(int)
-
-    # build one-step away dead-lock matrix
-    for a in range(env.get_num_agents()):
-        agent = env.agents[a]
-        path = paths[a]
-        if path is None:
-            continue
-
-        conflict_mat[a][a] = unique_ordered_priority[a]
-        for path_loop in range(len(path)):
-            p_el = path[path_loop]
-            p = p_el.position
-            if compare_position_equal(agent.target, p):
-                break
-            else:
-                a_loop = 0
-                opp_a = (int)(agent_next_position_matrix[p[0]][p[1]][a_loop])
-
-                cnt = 0
-                while (opp_a > -1) and (cnt < 1000):
-                    cnt = cnt + 1
-                    opp_path = paths[opp_a]
-                    if opp_path is not None:
-                        opp_a_p1 = opp_path[0].next_action_element.next_position
-                        if path_loop < len(path) - 1:
-                            p1 = path[path_loop + 1].next_action_element.next_position
-                            if not compare_position_equal(opp_a_p1, p1):
-                                conflict_mat[a][opp_a] = unique_ordered_priority[opp_a]
-                                conflict_mat[opp_a][a] = unique_ordered_priority[a]
-                        a_loop += 1
-                        opp_a = (int)(agent_next_position_matrix[p[0]][p[1]][a_loop])
-
-    # update one-step away
-    for a in range(env.get_num_agents()):
-        if not do_move[a]:
-            conflict_mat[conflict_mat == unique_ordered_priority[a]] = -1
-
-    return conflict_mat
-
-
-def avoid_dead_lock(env, a, paths, conflict_matrix, agent_position_matrix, agent_target_matrix,
-                    agent_next_position_matrix):
-    # performance optimisation
-    if conflict_matrix is not None:
-        if np.argmax(conflict_matrix[a]) == a:
-            return True
-
-    # dead lock algorithm
-    agent = env.agents[a]
-    agent_position = agent_fake_position(agent)
-    if compare_position_equal(agent_position, agent.target):
-        return True
-
-    path = paths[a]
-    if path is None:
-        return True
-
-    max_path_step_allowed = np.inf
-    # iterate over agent a's travel path (fixed path)
-    for path_loop in range(len(path)):
-        p_el = path[path_loop]
-        p = p_el.position
-        if compare_position_equal(p, agent.target):
-            break
-
-        # iterate over all agents (opposite)
-        # for opp_a in range(env.get_num_agents()):
-        a_loop = 0
-        opp_a = 0
-        cnt = 0
-        while (a_loop < env.get_num_agents() and opp_a > -1) and cnt < 1000:
-            cnt = cnt + 1
-            if conflict_matrix is not None:
-                opp_a = (int)(agent_next_position_matrix[p[0]][p[1]][a_loop])
-                a_loop += 1
-            else:
-                opp_a = (int)(agent_position_matrix[p[0]][p[1]])
-                a_loop = env.get_num_agents()
-            if opp_a > -1:
-                if opp_a != a:
-                    opp_agent = env.agents[opp_a]
-                    opp_path = paths[opp_a]
-                    if opp_path is not None:
-                        opp_path_0 = opp_path[0]
-
-                        # find all position in the opp.-path which are equal to current position.
-                        # the method has to scan all path through
-                        all_path_idx_offset_array = [0]
-                        for opp_path_loop_itr in range(len(path)):
-                            opp_p_el = opp_path[opp_path_loop_itr]
-                            opp_p = opp_p_el.position
-                            if compare_position_equal(opp_p, opp_agent.target):
-                                break
-                            opp_agent_position = agent_fake_position(opp_agent)
-                            if compare_position_equal(opp_p, opp_agent_position):
-                                all_path_idx_offset_array.extend([opp_path_loop_itr])
-                            opp_p_next = opp_p_el.next_action_element.next_position
-                            if compare_position_equal(opp_p_next, opp_agent_position):
-                                all_path_idx_offset_array.extend([opp_path_loop_itr])
-
-                        for all_path_idx_offset_loop in range(len(all_path_idx_offset_array)):
-                            all_path_idx_offset = all_path_idx_offset_array[all_path_idx_offset_loop]
-                            opp_path_0_el = opp_path[all_path_idx_offset]
-                            opp_path_0 = opp_path_0_el.position
-                            # if check_in_details is set to -1: no dead-lock candidate found
-                            # if check_in_details is set to  0: dead-lock candidate are not yet visible (agents need one step to become visible)(case A)
-                            # if check_in_details is set to  1: dead-lock candidate are visible, thus we have to collect them (case B)
-                            check_in_detail = -1
-
-                            # check mode, if conflict_matrix is set, then we are looking ..
-                            if conflict_matrix is not None:
-                                # Case A
-                                if np.argmax(conflict_matrix[a]) != a:
-                                    # avoid (parallel issue)
-                                    if compare_position_equal(opp_path_0, p):
-                                        check_in_detail = 0
-                            else:
-                                # Case B
-                                # collect all dead-lock candidates and check
-                                opp_agent_position = agent_fake_position(opp_agent)
-                                if compare_position_equal(opp_agent_position, p):
-                                    check_in_detail = 1
-
-                            if check_in_detail > -1:
-                                # print("Conflict risk found. My [", a, "] path is occupied by [", opp_a, "]")
-                                opp_path_loop = all_path_idx_offset
-                                back_path_loop = path_loop - check_in_detail
-                                cnt = 0
-                                while (opp_path_loop < len(opp_path) and back_path_loop > -1) and cnt < 1000:
-                                    cnt = cnt + 1
-                                    # retrieve position information
-                                    opp_p_el = opp_path[opp_path_loop]
-                                    opp_p = opp_p_el.position
-                                    me_p_el = path[back_path_loop]
-                                    me_p = me_p_el.next_action_element.next_position
-
-                                    if not compare_position_equal(opp_p, me_p):
-                                        # Case 1: The opposite train travels in same direction as the current train (agent a)
-                                        # Case 2: The opposite train travels in opposite direction and the path divergent
-                                        break
-
-                                    # make one step backwards (agent a) and one step forward for opposite train (agent opp_a)
-                                    # train a can no travel further than given position, because no divergent paths, this will cause a dead-lock
-                                    max_path_step_allowed = min(max_path_step_allowed, back_path_loop)
-                                    opp_path_loop += 1
-                                    back_path_loop -= 1
-
-                                    # check whether at least one step is allowed
-                                    if max_path_step_allowed < 1:
-                                        return False
-
-                                if back_path_loop == -1:
-                                    # No divergent path found, it cause a deadlock
-                                    # print("conflict (stop): (", a, ",", opp_a, ")")
-                                    return False
-
-    # check whether at least one step is allowed
-    return max_path_step_allowed > 0
-
-
-def calculate_one_step(env):
-    # can agent move array
-    do_move = np.zeros(env.get_num_agents())
-    if True:
-        cnt = 0
-        cnt_done = 0
-        for a in range(env.get_num_agents()):
-            agent = env.agents[a]
-            if agent.status < RailAgentStatus.DONE:
-                cnt += 1
-                if cnt < 30:
-                    do_move[a] = True
-            else:
-                cnt_done += 1
-        print("\r{}/{}\t".format(cnt_done, env.get_num_agents()), end="")
-    else:
-        agent_fixed_travel_paths = get_paths(env.distance_map, 1)
-        # can agent move array
-        do_move = np.zeros(env.get_num_agents())
-        for a in range(env.get_num_agents()):
-            agent = env.agents[a]
-            if agent.position is not None and not compare_position_equal(agent.position, agent.target):
-                do_move[a] = True
-                break
-
-        if np.sum(do_move) == 0:
-            for a in range(env.get_num_agents()):
-                agent = env.agents[a]
-                if agent_fixed_travel_paths[a] is not None:
-                    if agent.position is None and compare_position_equal(agent.initial_position, agent.target):
-                        do_move[a] = True
-                        break
-                    elif not compare_position_equal(agent.initial_position, agent.target):
-                        do_move[a] = True
-                        break
-
-        initial_position = None
-        for a in range(env.get_num_agents()):
-            agent = env.agents[a]
-            if do_move[a]:
-                initial_position = agent.initial_position
-
-            if initial_position is not None:
-                if compare_position_equal(agent.initial_position, initial_position):
-                    do_move[a] = True
-
-    # copy of agents fixed travel path (current path to follow) : only once : quite expensive
-    # agent_fixed_travel_paths = get_shortest_paths(env.distance_map)
-    agent_fixed_travel_paths = dict()
-    for a in range(env.get_num_agents()):
-        agent = env.agents[a]
-        if do_move[a]:
-            agent_fixed_travel_paths[agent.handle] = get_paths(env.distance_map, None, agent.handle)[agent.handle]
-        else:
-            agent_fixed_travel_paths[agent.handle] = None
-
-    # copy position, target and next position into cache (matrices)
-    # (The cache idea increases the run-time performance)
-    agent_position_matrix = np.zeros(shape=(env.height, env.width)) - 1.0
-    agent_target_matrix = np.zeros(shape=(env.height, env.width)) - 1.0
-    agent_next_position_matrix = np.zeros(shape=(env.height, env.width, env.get_num_agents() + 1)) - 1.0
-    for a in range(env.get_num_agents()):
-        if do_move[a] == False:
-            continue
-        agent = env.agents[a]
-        agent_position = agent_fake_position(agent)
-        if agent_position[2] is None:
-            agent_position = agent.initial_position
-        agent_position_matrix[agent_position[0]][agent_position[1]] = a
-        agent_target_matrix[agent.target[0]][agent.target[1]] = a
-        if not compare_position_equal(agent.target, agent_position):
-            path = agent_fixed_travel_paths[a]
-            if path is not None:
-                p_el = path[0]
-                p = p_el.position
-                a_loop = 0
-                cnt = 0
-                while (agent_next_position_matrix[p[0]][p[1]][a_loop] > -1) and cnt < 1000:
-                    cnt = cnt + 1
-                    a_loop += 1
-                agent_next_position_matrix[p[0]][p[1]][a_loop] = a
-
-    # check which agents can move (see : avoid_dead_lock (case b))
-    for a in range(env.get_num_agents()):
-        agent = env.agents[a]
-        if not compare_position_equal(agent.position, agent.target) and do_move[a]:
-            do_move[a] = avoid_dead_lock(env, a, agent_fixed_travel_paths, None, agent_position_matrix,
-                                         agent_target_matrix,
-                                         agent_next_position_matrix)
-
-    # check which agents can move (see : avoid_dead_lock (case a))
-    # calculate possible candidate for hidden one-step away dead-lock candidates
-    conflict_matrix = calc_conflict_matrix_next_step(env, agent_fixed_travel_paths, do_move, agent_position_matrix,
-                                                     agent_target_matrix,
-                                                     agent_next_position_matrix)
-    for a in range(env.get_num_agents()):
-        agent = env.agents[a]
-        if not compare_position_equal(agent.position, agent.target):
-            if do_move[a]:
-                do_move[a] = avoid_dead_lock(env, a, agent_fixed_travel_paths, conflict_matrix, agent_position_matrix,
-                                             agent_target_matrix,
-                                             agent_next_position_matrix)
-
-    for a in range(env.get_num_agents()):
-        agent = env.agents[a]
-        if agent.position is not None and compare_position_equal(agent.position, agent.target):
-            do_move[a] = False
-
-    # main loop (calculate actions for all agents)
-    action_dict = {}
-    is_moving_cnt = 0
-    for a in range(env.get_num_agents()):
-        agent = env.agents[a]
-        action = RailEnvActions.MOVE_FORWARD
-
-        if do_move[a] and is_moving_cnt < 10:
-            is_moving_cnt += 1
-            # check for deadlock:
-            path = agent_fixed_travel_paths[a]
-            if path is not None:
-                action = path[0].next_action_element.action
-        else:
-            action = RailEnvActions.STOP_MOVING
-        action_dict[a] = action
-
-    return action_dict, do_move
-
-
-def calculate_one_step_heuristics(env):
-    # copy of agents fixed travel path (current path to follow)
-    agent_fixed_travel_paths = get_paths(env.distance_map, 1)
-
-    # main loop (calculate actions for all agents)
-    action_dict = {}
-    for a in range(env.get_num_agents()):
-        agent = env.agents[a]
-        action = RailEnvActions.MOVE_FORWARD
-
-        # check for deadlock:
-        path = agent_fixed_travel_paths[a]
-        if path is not None:
-            action = path[0].next_action_element.action
-        action_dict[a] = action
-
-    return action_dict, None
-
-
-def calculate_one_step_primitive_implementation(env):
-    # can agent move array
-    do_move = np.zeros(env.get_num_agents())
-    for a in range(env.get_num_agents()):
-        agent = env.agents[a]
-        if agent.status > RailAgentStatus.ACTIVE:
-            continue
-        if (agent.status == RailAgentStatus.ACTIVE):
-            do_move[a] = True
-            break
-        do_move[a] = True
-        break
-
-    # main loop (calculate actions for all agents)
-    action_dict = {}
-    for a in range(env.get_num_agents()):
-        agent = env.agents[a]
-        action = RailEnvActions.MOVE_FORWARD
-        if do_move[a]:
-            # check for deadlock:
-            # copy of agents fixed travel path (current path to follow)
-            agent_fixed_travel_paths = get_shortest_paths(env.distance_map, 1, agent.handle)
-            path = agent_fixed_travel_paths[agent.handle]
-            if path is not None:
-                print("\rAgent:{:4d}/{:<4d} ".format(a + 1, env.get_num_agents()), end=" ")
-                action = path[0].next_action_element.action
-        else:
-            action = RailEnvActions.STOP_MOVING
-        action_dict[a] = action
-
-    return action_dict, do_move
-
-
-def calculate_one_step_package_implementation(env):
-    # copy of agents fixed travel path (current path to follow)
-    # agent_fixed_travel_paths = get_shortest_paths(env.distance_map,1)
-    agent_fixed_travel_paths = get_paths(env.distance_map, 1)
-
-    # can agent move array
-    do_move = np.zeros(env.get_num_agents())
-    for a in range(env.get_num_agents()):
-        agent = env.agents[a]
-        if agent.position is not None and not compare_position_equal(agent.position, agent.target):
-            do_move[a] = True
-            break
-
-    if np.sum(do_move) == 0:
-        for a in range(env.get_num_agents()):
-            agent = env.agents[a]
-            if agent_fixed_travel_paths[a] is not None:
-                if agent.position is None and compare_position_equal(agent.initial_position, agent.target):
-                    do_move[a] = True
-                    break
-                elif not compare_position_equal(agent.initial_position, agent.target):
-                    do_move[a] = True
-                    break
-
-    initial_position = None
-    for a in range(env.get_num_agents()):
-        agent = env.agents[a]
-        if do_move[a]:
-            initial_position = agent.initial_position
-
-        if initial_position is not None:
-            if compare_position_equal(agent.initial_position, initial_position):
-                do_move[a] = True
-
-    # main loop (calculate actions for all agents)
-    action_dict = {}
-    for a in range(env.get_num_agents()):
-        agent = env.agents[a]
-        action = RailEnvActions.MOVE_FORWARD
-
-        if do_move[a]:
-            # check for deadlock:
-            path = agent_fixed_travel_paths[a]
-            if path is not None:
-                action = path[0].next_action_element.action
-        else:
-            action = RailEnvActions.STOP_MOVING
-        action_dict[a] = action
-
-    return action_dict, do_move
diff --git a/src/simple/ShortestPathPredictorForRailEnv.py b/src/simple/ShortestPathPredictorForRailEnv.py
deleted file mode 100644
index f820253fb34939c62a94fc77688326254bc0368e..0000000000000000000000000000000000000000
--- a/src/simple/ShortestPathPredictorForRailEnv.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import numpy as np
-
-from flatland.core.env_prediction_builder import PredictionBuilder
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.rail_env import RailEnvActions
-
-
-class AdrianShortestPathPredictorForRailEnv(PredictionBuilder):
-    """
-    ShortestPathPredictorForRailEnv object.
-
-    This object returns shortest-path predictions for agents in the RailEnv environment.
-    The prediction acts as if no other agent is in the environment and always takes the forward action.
-    """
-
-    def __init__(self, max_depth=20):
-        # Initialize with depth 20
-        self.max_depth = max_depth
-
-    def get(self, custom_args=None, handle=None):
-        """
-        Called whenever get_many in the observation build is called.
-        Requires distance_map to extract the shortest path.
-
-        Parameters
-        -------
-        custom_args: dict
-            - distance_map : dict
-        handle : int (optional)
-            Handle of the agent for which to compute the observation vector.
-
-        Returns
-        -------
-        np.array
-            Returns a dictionary indexed by the agent handle and for each agent a vector of (max_depth + 1)x5 elements:
-            - time_offset
-            - position axis 0
-            - position axis 1
-            - direction
-            - action taken to come here
-            The prediction at 0 is the current position, direction etc.
-        """
-
-        agents = self.env.agents
-        if handle:
-            agents = [self.env.agents[handle]]
-        assert custom_args is not None
-        distance_map = custom_args.get('distance_map')
-        assert distance_map is not None
-
-        prediction_dict = {}
-        for agent in agents:
-            _agent_initial_position = agent.position
-            _agent_initial_direction = agent.direction
-            prediction = np.zeros(shape=(self.max_depth + 1, 5))
-            prediction[0] = [0, *_agent_initial_position, _agent_initial_direction, 0]
-            visited = []
-            for index in range(1, self.max_depth + 1):
-                # if we're at the target, stop moving...
-                if agent.position == agent.target:
-                    prediction[index] = [index, *agent.target, agent.direction, RailEnvActions.STOP_MOVING]
-                    visited.append((agent.position[0], agent.position[1], agent.direction))
-                    continue
-                # Take shortest possible path
-                cell_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
-
-                new_position = None
-                new_direction = None
-                if np.sum(cell_transitions) == 1:
-                    new_direction = np.argmax(cell_transitions)
-                    new_position = get_new_position(agent.position, new_direction)
-                elif np.sum(cell_transitions) > 1:
-                    min_dist = np.inf
-                    no_dist_found = True
-                    for direction in range(4):
-                        if cell_transitions[direction] == 1:
-                            neighbour_cell = get_new_position(agent.position, direction)
-                            target_dist = distance_map[agent.handle, neighbour_cell[0], neighbour_cell[1], direction]
-                            if target_dist < min_dist or no_dist_found:
-                                min_dist = target_dist
-                                new_direction = direction
-                                no_dist_found = False
-                    new_position = get_new_position(agent.position, new_direction)
-                else:
-                    print("--------------------")
-                    print(agent.position, agent.direction, "valid:", self.env.rail.cell_neighbours_valid(
-                          agent.position),
-                          self.env.rail.get_full_transitions(agent.position[0],agent.position[1])
-                          )
-                    print("--------------------")
-                    raise Exception("No transition possible {}".format(cell_transitions))
-
-                # update the agent's position and direction
-                agent.position = new_position
-                agent.direction = new_direction
-
-                # prediction is ready
-                prediction[index] = [index, *new_position, new_direction, 0]
-                visited.append((new_position[0], new_position[1], new_direction))
-            self.env.dev_pred_dict[agent.handle] = visited
-            prediction_dict[agent.handle] = prediction
-
-            # cleanup: reset initial position
-            agent.position = _agent_initial_position
-            agent.direction = _agent_initial_direction
-
-        return prediction_dict
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..43f1b4ae15aebae1348e0eadc5f9cab243511a28
--- /dev/null
+++ b/utils/dead_lock_avoidance_agent.py
@@ -0,0 +1,116 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+
+from reinforcement_learning.policy import Policy
+from utils.shortest_Distance_walker import ShortestDistanceWalker
+
+
+class MyWalker(ShortestDistanceWalker):
+    def __init__(self, env: RailEnv, agent_positions):
+        super().__init__(env)
+        self.shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
+                                                     self.env.height,
+                                                     self.env.width),
+                                                    dtype=int) - 1
+
+        self.agent_positions = agent_positions
+
+        self.agent_map = {}
+
+    def getData(self):
+        return self.shortest_distance_agent_map
+
+    def callback(self, handle, agent, position, direction, action):
+        opp_a = self.agent_positions[position]
+        if opp_a != -1 and opp_a != handle:
+            d = self.agent_map.get(handle, [])
+            d.append(opp_a)
+            if self.env.agents[opp_a].direction != direction:
+                self.agent_map.update({handle: d})
+        self.shortest_distance_agent_map[(handle, position[0], position[1])] = direction
+
+
+class DeadLockAvoidanceAgent(Policy):
+    def __init__(self, env: RailEnv, state_size, action_size):
+        self.env = env
+        self.action_size = action_size
+        self.state_size = state_size
+        self.memory = []
+        self.loss = 0
+        self.agent_can_move = {}
+        self.active_agent_cnt = 0
+
+    def step(self, handle, state, action, reward, next_state, done):
+        pass
+
+    def act(self, handle, state, eps=0.):
+        agent = self.env.agents[handle]
+        #if handle > self.env._elapsed_steps:
+        #    return RailEnvActions.STOP_MOVING
+        if agent.status == RailAgentStatus.ACTIVE:
+            self.active_agent_cnt += 1
+        #if agent.status > 20:
+        #    return RailEnvActions.STOP_MOVING
+        check = self.agent_can_move.get(handle, None)
+        if check is None:
+            # print(handle, RailEnvActions.STOP_MOVING)
+            return RailEnvActions.STOP_MOVING
+
+        return check[3]
+
+    def reset(self):
+        pass
+
+    def start_step(self):
+        self.active_agent_cnt = 0
+        self.shortest_distance_mapper()
+
+    def end_step(self):
+        # print("#A:", self.active_agent_cnt, "/", self.env.get_num_agents(),self.env._elapsed_steps)
+        pass
+
+    def get_actions(self):
+        pass
+
+    def shortest_distance_mapper(self):
+
+        # build map with agent positions (only active agents)
+        agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status <= RailAgentStatus.ACTIVE:
+                if agent.position is not None:
+                    agent_positions[agent.position] = handle
+
+        my_walker = MyWalker(self.env, agent_positions)
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status <= RailAgentStatus.ACTIVE:
+                my_walker.walk_to_target(handle)
+        self.shortest_distance_agent_map = my_walker.getData()
+
+        self.agent_can_move = {}
+        agent_positions_map = np.clip(agent_positions + 1, 0, 1)
+        for handle in range(self.env.get_num_agents()):
+            opp_agents = my_walker.agent_map.get(handle, [])
+            me = np.clip(self.shortest_distance_agent_map[handle] + 1, 0, 1)
+            next_step_ok = True
+            next_position, next_direction, action = my_walker.walk_one_step(handle)
+            for opp_a in opp_agents:
+                opp = np.clip(self.shortest_distance_agent_map[opp_a] + 1, 0, 1)
+                delta = np.clip(me - opp - agent_positions_map, 0, 1)
+                if (np.sum(delta) > 1):
+                    next_step_ok = False
+            if next_step_ok:
+                self.agent_can_move.update({handle: [next_position[0], next_position[1], next_direction, action]})
+
+        if False:
+            a = np.floor(np.sqrt(self.env.get_num_agents()))
+            b = np.ceil(self.env.get_num_agents() / a)
+            for handle in range(self.env.get_num_agents()):
+                plt.subplot(a, b, handle + 1)
+                plt.imshow(self.shortest_distance_agent_map[handle])
+            # plt.colorbar()
+            plt.show(block=False)
+            plt.pause(0.001)
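+
+# Minimal usage sketch (illustrative only, not part of the original module): it assumes an
+# existing `env: RailEnv` and shows the call order this policy expects per simulation step --
+# start_step() refreshes the shortest-distance maps, act() is queried per agent, and
+# end_step() is called after env.step(). state_size/action_size appear unused by this policy.
+#
+#   policy = DeadLockAvoidanceAgent(env, state_size=0, action_size=5)
+#   policy.start_step()
+#   actions = {handle: policy.act(handle, state=None)
+#              for handle in range(env.get_num_agents())}
+#   obs, rewards, dones, info = env.step(actions)
+#   policy.end_step()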
diff --git a/utils/deadlock_check.py b/utils/deadlock_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..28c65fa6185fa9131fbc493a33fa26529e0290db
--- /dev/null
+++ b/utils/deadlock_check.py
@@ -0,0 +1,42 @@
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
+
+
+def check_if_all_blocked(env):
+    """
+    Checks whether all the agents are blocked (full deadlock situation).
+    In that case it is pointless to keep running inference as no agent will be able to move.
+    :param env: current environment
+    :return: True if no agent can move to a free neighbouring cell, False otherwise
+    """
+
+    # First build a map of agents in each position
+    location_has_agent = {}
+    for agent in env.agents:
+        if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and agent.position:
+            location_has_agent[tuple(agent.position)] = 1
+
+    # Looks for any agent that can still move
+    for handle in env.get_agent_handles():
+        agent = env.agents[handle]
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            agent_virtual_position = agent.target
+        else:
+            continue
+
+        possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
+        orientation = agent.direction
+
+        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
+            if possible_transitions[branch_direction]:
+                new_position = get_new_position(agent_virtual_position, branch_direction)
+
+                if new_position not in location_has_agent:
+                    return False
+
+    # No agent can move at all: full deadlock!
+    return True
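+
+# Hedged usage sketch (illustrative only, assumes an existing `env: RailEnv`, an `obs` dict and a
+# `policy` object): stop an evaluation rollout as soon as every agent is blocked.
+#
+#   while True:
+#       actions = {h: policy.act(h, obs[h]) for h in env.get_agent_handles()}
+#       obs, rewards, dones, info = env.step(actions)
+#       if dones['__all__'] or check_if_all_blocked(env):
+#           break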
diff --git a/utils/extra.py b/utils/extra.py
new file mode 100644
index 0000000000000000000000000000000000000000..1145521ad3a3a087a38a4fb514fa4b60b4a4ffc0
--- /dev/null
+++ b/utils/extra.py
@@ -0,0 +1,445 @@
+# import matplotlib.pyplot as plt
+import numpy as np
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.rail_env import RailEnvActions, RailAgentStatus, RailEnv
+
+from reinforcement_learning.policy import Policy
+from utils.shortest_Distance_walker import ShortestDistanceWalker
+
+
+class ExtraPolicy(Policy):
+    def __init__(self, state_size, action_size):
+        self.state_size = state_size
+        self.action_size = action_size
+        self.memory = []
+        self.loss = 0
+
+    def load(self, filename):
+        pass
+
+    def save(self, filename):
+        pass
+
+    def step(self, handle, state, action, reward, next_state, done):
+        pass
+
+    def act(self, handle, state, eps=0.):
+        a = 0
+        b = 4
+        action = RailEnvActions.STOP_MOVING
+        if state[2] == 1 and state[10 + a] == 0:
+            action = RailEnvActions.MOVE_LEFT
+        elif state[3] == 1 and state[11 + a] == 0:
+            action = RailEnvActions.MOVE_FORWARD
+        elif state[4] == 1 and state[12 + a] == 0:
+            action = RailEnvActions.MOVE_RIGHT
+        elif state[5] == 1 and state[13 + a] == 0:
+            action = RailEnvActions.MOVE_FORWARD
+
+        elif state[6] == 1 and state[10 + b] == 0:
+            action = RailEnvActions.MOVE_LEFT
+        elif state[7] == 1 and state[11 + b] == 0:
+            action = RailEnvActions.MOVE_FORWARD
+        elif state[8] == 1 and state[12 + b] == 0:
+            action = RailEnvActions.MOVE_RIGHT
+        elif state[9] == 1 and state[13 + b] == 0:
+            action = RailEnvActions.MOVE_FORWARD
+
+        return action
+
+    def test(self):
+        pass
+
+
+def fast_argmax(possible_transitions: (int, int, int, int)) -> int:
+    if possible_transitions[0] == 1:
+        return 0
+    if possible_transitions[1] == 1:
+        return 1
+    if possible_transitions[2] == 1:
+        return 2
+    return 3
+
+
+def fast_count_nonzero(possible_transitions: (int, int, int, int)):
+    return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
+
+
+class Extra(ObservationBuilder):
+
+    def __init__(self, max_depth):
+        super().__init__()
+        self.max_depth = max_depth
+        self.observation_dim = 62
+
+    def shortest_distance_mapper(self):
+
+        class MyWalker(ShortestDistanceWalker):
+            def __init__(self, env: RailEnv):
+                super().__init__(env)
+                self.shortest_distance_agent_counter = np.zeros((self.env.height, self.env.width), dtype=int)
+                self.shortest_distance_agent_direction_counter = np.zeros((self.env.height, self.env.width, 4),
+                                                                          dtype=int)
+
+            def getData(self):
+                return self.shortest_distance_agent_counter, self.shortest_distance_agent_direction_counter
+
+            def callback(self, handle, agent, position, direction, action):
+                self.shortest_distance_agent_counter[position] += 1
+                self.shortest_distance_agent_direction_counter[(position[0], position[1], direction)] += 1
+
+        my_walker = MyWalker(self.env)
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status <= RailAgentStatus.ACTIVE:
+                my_walker.walk_to_target(handle)
+
+        self.shortest_distance_agent_counter, self.shortest_distance_agent_direction_counter = my_walker.getData()
+
+        # plt.imshow(self.shortest_distance_agent_counter)
+        # plt.colorbar()
+        # plt.show()
+
+    def build_data(self):
+        if self.env is not None:
+            self.env.dev_obs_dict = {}
+        self.switches = {}
+        self.switches_neighbours = {}
+        self.debug_render_list = []
+        self.debug_render_path_list = []
+        if self.env is not None:
+            self.find_all_cell_where_agent_can_choose()
+            self.agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1
+            self.history_direction = np.zeros((self.env.height, self.env.width), dtype=int) - 1
+            self.history_same_direction_cnt = np.zeros((self.env.height, self.env.width), dtype=int)
+            self.history_time = np.zeros((self.env.height, self.env.width), dtype=int) - 1
+
+        self.shortest_distance_agent_counter = None
+        self.shortest_distance_agent_direction_counter = None
+
+    def find_all_cell_where_agent_can_choose(self):
+
+        switches = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                pos = (h, w)
+                for dir in range(4):
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    num_transitions = fast_count_nonzero(possible_transitions)
+                    if num_transitions > 1:
+                        if pos not in switches.keys():
+                            switches.update({pos: [dir]})
+                        else:
+                            switches[pos].append(dir)
+
+        switches_neighbours = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                # look one step forward
+                for dir in range(4):
+                    pos = (h, w)
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    for d in range(4):
+                        if possible_transitions[d] == 1:
+                            new_cell = get_new_position(pos, d)
+                            if new_cell in switches.keys() and pos not in switches.keys():
+                                if pos not in switches_neighbours.keys():
+                                    switches_neighbours.update({pos: [dir]})
+                                else:
+                                    switches_neighbours[pos].append(dir)
+
+        self.switches = switches
+        self.switches_neighbours = switches_neighbours
+
+    def check_agent_descision(self, position, direction):
+        switches = self.switches
+        switches_neighbours = self.switches_neighbours
+        agents_on_switch = False
+        agents_on_switch_all = False
+        agents_near_to_switch = False
+        agents_near_to_switch_all = False
+        if position in switches.keys():
+            agents_on_switch = direction in switches[position]
+            agents_on_switch_all = True
+
+        if position in switches_neighbours.keys():
+            new_cell = get_new_position(position, direction)
+            if new_cell in switches.keys():
+                if not direction in switches[new_cell]:
+                    agents_near_to_switch = direction in switches_neighbours[position]
+            else:
+                agents_near_to_switch = direction in switches_neighbours[position]
+
+            agents_near_to_switch_all = direction in switches_neighbours[position]
+
+        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
+
+    def required_agent_descision(self):
+        agents_can_choose = {}
+        agents_on_switch = {}
+        agents_on_switch_all = {}
+        agents_near_to_switch = {}
+        agents_near_to_switch_all = {}
+        for a in range(self.env.get_num_agents()):
+            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
+                self.check_agent_descision(
+                    self.env.agents[a].position,
+                    self.env.agents[a].direction)
+            agents_on_switch.update({a: ret_agents_on_switch})
+            agents_on_switch_all.update({a: ret_agents_on_switch_all})
+            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
+            agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
+
+            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
+
+            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
+
+        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
+
+    def debug_render(self, env_renderer):
+        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all = \
+            self.required_agent_descision()
+        self.env.dev_obs_dict = {}
+        for a in range(max(3, self.env.get_num_agents())):
+            self.env.dev_obs_dict.update({a: []})
+
+        selected_agent = None
+        if agents_can_choose[0]:
+            if self.env.agents[0].position is not None:
+                self.debug_render_list.append(self.env.agents[0].position)
+            else:
+                self.debug_render_list.append(self.env.agents[0].initial_position)
+
+        if self.env.agents[0].position is not None:
+            self.debug_render_path_list.append(self.env.agents[0].position)
+        else:
+            self.debug_render_path_list.append(self.env.agents[0].initial_position)
+
+        env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000")
+        env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600")
+        env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666")
+        env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")
+
+        self.env.dev_obs_dict[0] = self.debug_render_list
+        self.env.dev_obs_dict[1] = self.switches.keys()
+        self.env.dev_obs_dict[2] = self.switches_neighbours.keys()
+        self.env.dev_obs_dict[3] = self.debug_render_path_list
+
+    def reset(self):
+        self.build_data()
+        return
+
+    def fast_argmax(self, array):
+        if array[0] == 1:
+            return 0
+        if array[1] == 1:
+            return 1
+        if array[2] == 1:
+            return 2
+        return 3
+
+    def _explore(self, handle, new_position, new_direction, distance_map, depth):
+
+        may_has_opp_agent = 0
+        has_opp_agent = -1
+        has_other_target = 0
+        has_target = 0
+        visited = []
+
+        new_cell_dist = np.inf
+
+        # stop exploring (max_depth reached)
+        if depth > self.max_depth:
+            return has_opp_agent, may_has_opp_agent, has_other_target, has_target, visited, new_cell_dist
+
+        # max_explore_steps = 100
+        cnt = 0
+        while cnt < 100:
+            cnt += 1
+            has_other_target = int(new_position in self.agent_targets)
+            new_cell_dist = min(new_cell_dist, distance_map[handle,
+                                                            new_position[0], new_position[1],
+                                                            new_direction])
+
+            visited.append(new_position)
+            has_target = int(self.env.agents[handle].target == new_position)
+            opp_a = self.agent_positions[new_position]
+            if opp_a != -1 and opp_a != handle:
+                possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
+                if possible_transitions[self.env.agents[opp_a].direction] < 1:
+                    # opp agent found
+                    has_opp_agent = opp_a
+                    may_has_opp_agent = 1
+                    return has_opp_agent, may_has_opp_agent, has_other_target, has_target, visited, new_cell_dist
+
+            # check whether the agent is on or next to a switch (a decision cell)
+            agents_on_switch, \
+            agents_near_to_switch, \
+            agents_near_to_switch_all, \
+            agents_on_switch_all = \
+                self.check_agent_descision(new_position, new_direction)
+
+            if agents_near_to_switch:
+                return has_opp_agent, may_has_opp_agent, has_other_target, has_target, visited, new_cell_dist
+
+            possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
+            if fast_count_nonzero(possible_transitions) > 1:
+                may_has_opp_agent_loop = 1
+                for dir_loop in range(4):
+                    if possible_transitions[dir_loop] == 1:
+                        hoa, mhoa, hot, ht, v, min_cell_dist = self._explore(handle,
+                                                                             get_new_position(new_position,
+                                                                                              dir_loop),
+                                                                             dir_loop,
+                                                                             distance_map,
+                                                                             depth + 1)
+
+                        has_opp_agent = max(has_opp_agent, hoa)
+                        may_has_opp_agent_loop = min(may_has_opp_agent_loop, mhoa)
+                        has_other_target = max(has_other_target, hot)
+                        has_target = max(has_target, ht)
+                        visited.append(v)
+                        new_cell_dist = min(min_cell_dist, new_cell_dist)
+                return has_opp_agent, may_has_opp_agent_loop, has_other_target, has_target, visited, new_cell_dist
+            else:
+                new_direction = fast_argmax(possible_transitions)
+                new_position = get_new_position(new_position, new_direction)
+
+        return has_opp_agent, may_has_opp_agent, has_other_target, has_target, visited, new_cell_dist
+
+    def get(self, handle):
+
+        if (handle == 0):
+            self.updateSharedData()
+
+        # Observation layout (dir_loop = 0..3 corresponds to left / forward / right / behind
+        # relative to the agent's current orientation):
+        # observation[0]             : 1 if agent.status == RailAgentStatus.READY_TO_DEPART
+        # observation[1]             : 1 if agent.status == RailAgentStatus.ACTIVE
+        # observation[2 + dir_loop]  : 1 if the first cell of that branch is closer to the target
+        # observation[6 + dir_loop]  : 1 if the explored branch gets closer to the target
+        # observation[10 + dir_loop] : 1 if an agent with opposite direction was found on that branch
+        # observation[14 + dir_loop] : 1 if that branch may contain an opposing agent
+        # observation[18 + dir_loop] : 1 if another agent's target lies on that branch
+        # observation[22 + dir_loop] : 1 if the agent's own target lies on that branch
+        # observation[26 + dir_loop] : last agent seen on the next cell moved in the same direction
+        # observation[30 + dir_loop] : last agent seen on the next cell moved in the opposite direction
+        # observation[34 + dir_loop] : temporal distance since the next cell was last visited
+        # observation[38 + dir_loop] : flow density measure of the next cell
+        # observation[42 + dir_loop] : density of shortest paths crossing the next cell (same direction)
+        # observation[46 + dir_loop] : density of shortest paths crossing the next cell (any direction)
+        # observation[50 + dir_loop] : 1 if the branch has a finite distance to the target
+        # observation[54 + dir_loop] : 1 if the branch is a valid transition
+        # observation[58 + dir_loop] : 1 if the blocking agent's handle is greater than the own handle
+
+        observation = np.zeros(self.observation_dim)
+        visited = []
+        agent = self.env.agents[handle]
+
+        agent_done = False
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+            observation[0] = 1
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+            observation[1] = 1
+        else:
+            agent_virtual_position = (-1, -1)
+            agent_done = True
+
+        if not agent_done:
+            visited.append(agent_virtual_position)
+            distance_map = self.env.distance_map.get()
+            current_cell_dist = distance_map[handle,
+                                             agent_virtual_position[0], agent_virtual_position[1],
+                                             agent.direction]
+            possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
+            orientation = agent.direction
+            if fast_count_nonzero(possible_transitions) == 1:
+                orientation = fast_argmax(possible_transitions)
+
+            for dir_loop, branch_direction in enumerate([(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]):
+                if possible_transitions[branch_direction]:
+                    new_position = get_new_position(agent_virtual_position, branch_direction)
+                    new_cell_dist = distance_map[handle,
+                                                 new_position[0], new_position[1],
+                                                 branch_direction]
+
+                    has_opp_agent, \
+                    may_has_opp_agent, \
+                    has_other_target, \
+                    has_target, \
+                    v, \
+                    min_cell_dist = self._explore(handle,
+                                                  new_position,
+                                                  branch_direction,
+                                                  distance_map,
+                                                  0)
+                    if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
+                        observation[2 + dir_loop] = int(new_cell_dist < current_cell_dist)
+
+                    new_cell_dist = min(min_cell_dist, new_cell_dist)
+                    if not (np.math.isinf(new_cell_dist) and not np.math.isinf(current_cell_dist)):
+                        observation[6 + dir_loop] = int(new_cell_dist < current_cell_dist)
+
+                    visited.append(v)
+
+                    observation[10 + dir_loop] = int(has_opp_agent > -1)
+                    observation[14 + dir_loop] = may_has_opp_agent
+                    observation[18 + dir_loop] = has_other_target
+                    observation[22 + dir_loop] = has_target
+                    observation[26 + dir_loop] = self.getHistorySameDirection(new_position, branch_direction)
+                    observation[30 + dir_loop] = self.getHistoryOppositeDirection(new_position, branch_direction)
+                    observation[34 + dir_loop] = self.getTemporalDistance(new_position)
+                    observation[38 + dir_loop] = self.getFlowDensity(new_position)
+                    observation[42 + dir_loop] = self.getDensitySameDirection(new_position, branch_direction)
+                    observation[46 + dir_loop] = self.getDensity(new_position)
+                    observation[50 + dir_loop] = int(not np.math.isinf(new_cell_dist))
+                    observation[54 + dir_loop] = 1
+                    observation[58 + dir_loop] = int(has_opp_agent > handle)
+
+        self.env.dev_obs_dict.update({handle: visited})
+
+        return observation
+
+    def getDensitySameDirection(self, position, direction):
+        val = self.shortest_distance_agent_direction_counter[(position[0], position[1], direction)]
+        return val / self.env.get_num_agents()
+
+    def getDensity(self, position):
+        val = self.shortest_distance_agent_counter[position]
+        return val / self.env.get_num_agents()
+
+    def getHistorySameDirection(self, position, direction):
+        val = self.history_direction[position]
+        if val == -1:
+            return -1
+        if val == direction:
+            return 1
+        return 0
+
+    def getHistoryOppositeDirection(self, position, direction):
+        val = self.getHistorySameDirection(position, direction)
+        if val == -1:
+            return -1
+        return 1 - val
+
+    def getTemporalDistance(self, position):
+        if self.history_time[position] == -1:
+            return -1
+        val = self.env._elapsed_steps - self.history_time[position]
+        if val < 1:
+            return 0
+        return 1 + np.log(1 + val)
+
+    def getFlowDensity(self, position):
+        val = self.env._elapsed_steps - self.history_same_direction_cnt[position]
+        return 1 + np.log(1 + val)
+
+    def updateSharedData(self):
+        self.shortest_distance_mapper()
+        self.agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1
+        self.agent_targets = []
+        for a in np.arange(self.env.get_num_agents()):
+            if self.env.agents[a].status == RailAgentStatus.ACTIVE:
+                self.agent_targets.append(self.env.agents[a].target)
+                if self.env.agents[a].position is not None:
+                    self.agent_positions[self.env.agents[a].position] = a
+                    if self.history_direction[self.env.agents[a].position] == self.env.agents[a].direction:
+                        self.history_same_direction_cnt[self.env.agents[a].position] += 1
+                    else:
+                        self.history_same_direction_cnt[self.env.agents[a].position] = 0
+                    self.history_direction[self.env.agents[a].position] = self.env.agents[a].direction
+                    self.history_time[self.env.agents[a].position] = self.env._elapsed_steps
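+
+# Illustrative sketch (not part of the original module): wiring the `Extra` observation builder
+# into a RailEnv. The generator choices below are placeholders -- any rail/schedule generators
+# accepted by RailEnv would do.
+#
+#   from flatland.envs.rail_env import RailEnv
+#   from flatland.envs.rail_generators import sparse_rail_generator
+#
+#   obs_builder = Extra(max_depth=2)
+#   env = RailEnv(width=30, height=30,
+#                 rail_generator=sparse_rail_generator(),
+#                 number_of_agents=5,
+#                 obs_builder_object=obs_builder)
+#   obs, info = env.reset()
+#   assert len(obs[0]) == obs_builder.observation_dim  # 62 features per agent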
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dc767f2b76197fbb98d63f326d7720dbbbdc020
--- /dev/null
+++ b/utils/observation_utils.py
@@ -0,0 +1,124 @@
+import numpy as np
+from flatland.envs.observations import TreeObsForRailEnv
+
+def max_lt(seq, val):
+    """
+    Return greatest item in seq for which item < val applies.
+    0 is returned if seq is empty or no item in seq satisfies the condition.
+    """
+    max = 0
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
+            max = seq[idx]
+        idx -= 1
+    return max
+
+
+def min_gt(seq, val):
+    """
+    Return the smallest item in seq for which item >= val applies.
+    np.inf is returned if seq is empty or no item in seq satisfies the condition.
+    """
+    min = np.inf
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] >= val and seq[idx] < min:
+            min = seq[idx]
+        idx -= 1
+    return min
+
+
+def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False):
+    """
+    Normalize an observation and clip it to the range [clip_min, clip_max].
+    :param obs: observation that should be normalized
+    :param clip_min: minimum value the normalized observation is clipped to
+    :param clip_max: maximum value the normalized observation is clipped to
+    :param fixed_radius: if > 0, use this fixed value as the normalization factor
+    :param normalize_to_range: if True, shift the observation by its smallest positive entry before scaling
+    :return: normalized and clipped observation
+    """
+    if fixed_radius > 0:
+        max_obs = fixed_radius
+    else:
+        max_obs = max(1, max_lt(obs, 1000)) + 1
+
+    min_obs = 0  # min(max_obs, min_gt(obs, 0))
+    if normalize_to_range:
+        min_obs = min_gt(obs, 0)
+    if min_obs > max_obs:
+        min_obs = max_obs
+    if max_obs == min_obs:
+        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
+    norm = np.abs(max_obs - min_obs)
+    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
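+
+# Worked example (comments only): with the default arguments the scaling factor is
+# max(1, max_lt(obs, 1000)) + 1 and negative entries stay negative after clipping.
+#
+#   norm_obs_clip([5, 10, -1, 1000])
+#   -> max_obs = 10 + 1 = 11, min_obs = 0
+#   -> np.clip([5/11, 10/11, -1/11, 1000/11], -1, 1) ~= [0.45, 0.91, -0.09, 1.0]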
+
+
+def _split_node_into_feature_groups(node) -> (np.ndarray, np.ndarray, np.ndarray):
+    data = np.zeros(6)
+    distance = np.zeros(1)
+    agent_data = np.zeros(4)
+
+    data[0] = node.dist_own_target_encountered
+    data[1] = node.dist_other_target_encountered
+    data[2] = node.dist_other_agent_encountered
+    data[3] = node.dist_potential_conflict
+    data[4] = node.dist_unusable_switch
+    data[5] = node.dist_to_next_branch
+
+    distance[0] = node.dist_min_to_target
+
+    agent_data[0] = node.num_agents_same_direction
+    agent_data[1] = node.num_agents_opposite_direction
+    agent_data[2] = node.num_agents_malfunctioning
+    agent_data[3] = node.speed_min_fractional
+
+    return data, distance, agent_data
+
+
+def _split_subtree_into_feature_groups(node, current_tree_depth: int, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
+    if node == -np.inf:
+        remaining_depth = max_tree_depth - current_tree_depth
+        # reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
+        num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1))
+        return [-np.inf] * num_remaining_nodes * 6, [-np.inf] * num_remaining_nodes, [-np.inf] * num_remaining_nodes * 4
+
+    data, distance, agent_data = _split_node_into_feature_groups(node)
+
+    if not node.childs:
+        return data, distance, agent_data
+
+    for direction in TreeObsForRailEnv.tree_explored_actions_char:
+        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(node.childs[direction], current_tree_depth + 1, max_tree_depth)
+        data = np.concatenate((data, sub_data))
+        distance = np.concatenate((distance, sub_distance))
+        agent_data = np.concatenate((agent_data, sub_agent_data))
+
+    return data, distance, agent_data
+
+
+def split_tree_into_feature_groups(tree, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
+    """
+    This function splits the tree into three different arrays of values
+    """
+    data, distance, agent_data = _split_node_into_feature_groups(tree)
+
+    for direction in TreeObsForRailEnv.tree_explored_actions_char:
+        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(tree.childs[direction], 1, max_tree_depth)
+        data = np.concatenate((data, sub_data))
+        distance = np.concatenate((distance, sub_distance))
+        agent_data = np.concatenate((agent_data, sub_agent_data))
+
+    return data, distance, agent_data
+
+
+def normalize_observation(observation, tree_depth: int, observation_radius=0):
+    """
+    This function normalizes the observation used by the RL algorithm
+    """
+    data, distance, agent_data = split_tree_into_feature_groups(observation, tree_depth)
+
+    data = norm_obs_clip(data, fixed_radius=observation_radius)
+    distance = norm_obs_clip(distance, normalize_to_range=True)
+    agent_data = np.clip(agent_data, -1, 1)
+    normalized_obs = np.concatenate((np.concatenate((data, distance)), agent_data))
+    return normalized_obs
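+
+# Illustrative sketch (assumes a RailEnv built with TreeObsForRailEnv(max_depth=tree_depth) as its
+# observation builder): flatten and normalize one agent's tree observation before feeding it to a
+# network.
+#
+#   obs, info = env.reset()
+#   tree_depth = 2
+#   state = normalize_observation(obs[0], tree_depth, observation_radius=10)
+#   # `state` is a flat numpy vector, ready to be used as network input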
diff --git a/utils/shortest_Distance_walker.py b/utils/shortest_Distance_walker.py
new file mode 100644
index 0000000000000000000000000000000000000000..d69ebcb0a98f732cc44849b10858b2e42a376a23
--- /dev/null
+++ b/utils/shortest_Distance_walker.py
@@ -0,0 +1,69 @@
+import numpy as np
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
+
+
+class ShortestDistanceWalker:
+    def __init__(self, env: RailEnv):
+        self.env = env
+
+    def walk(self, handle, position, direction):
+        possible_transitions = self.env.rail.get_transitions(*position, direction)
+        num_transitions = fast_count_nonzero(possible_transitions)
+        if num_transitions == 1:
+            new_direction = fast_argmax(possible_transitions)
+            new_position = get_new_position(position, new_direction)
+            dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]
+            return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD
+        else:
+            min_distances = []
+            positions = []
+            directions = []
+            for new_direction in [(direction + i) % 4 for i in range(-1, 2)]:
+                if possible_transitions[new_direction]:
+                    new_position = get_new_position(position, new_direction)
+                    min_distances.append(
+                        self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction])
+                    positions.append(new_position)
+                    directions.append(new_direction)
+                else:
+                    min_distances.append(np.inf)
+                    positions.append(None)
+                    directions.append(None)
+
+        a = np.argmin(min_distances)
+        return positions[a], directions[a], min_distances[a], a + 1
+
+    def callback(self, handle, agent, position, direction, action):
+        pass
+
+    def walk_to_target(self, handle):
+        agent = self.env.agents[handle]
+        if agent.position is not None:
+            position = agent.position
+        else:
+            position = agent.initial_position
+        direction = agent.direction
+        while (position != agent.target):
+            position, direction, dist, action = self.walk(handle, position, direction)
+            if position is None:
+                break
+            self.callback(handle, agent, position, direction, action)
+
+    def callback_one_step(self, handle, agent, position, direction, action):
+        pass
+
+    def walk_one_step(self, handle):
+        agent = self.env.agents[handle]
+        if agent.position is not None:
+            position = agent.position
+        else:
+            position = agent.initial_position
+        direction = agent.direction
+        if position == agent.target:
+            # the agent already sits on its target: there is no further step to take
+            return position, direction, RailEnvActions.STOP_MOVING
+        new_position, new_direction, dist, action = self.walk(handle, position, direction)
+        if new_position is None:
+            return position, direction, RailEnvActions.STOP_MOVING
+        self.callback_one_step(handle, agent, new_position, new_direction, action)
+        return new_position, new_direction, action
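+
+# Minimal subclass sketch (illustrative, assumes an existing `env: RailEnv`): the callback is
+# invoked for every cell on an agent's shortest path, so a subclass can accumulate arbitrary
+# per-cell statistics.
+#
+#   class CellCounter(ShortestDistanceWalker):
+#       def __init__(self, env):
+#           super().__init__(env)
+#           self.cells = 0
+#
+#       def callback(self, handle, agent, position, direction, action):
+#           self.cells += 1
+#
+#   walker = CellCounter(env)
+#   walker.walk_to_target(handle=0)
+#   print(walker.cells)  # number of cells on agent 0's shortest path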
diff --git a/utils/timer.py b/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e397c79cc46c9f49e967365c3a2ad9bbf7cd5f6
--- /dev/null
+++ b/utils/timer.py
@@ -0,0 +1,33 @@
+from timeit import default_timer
+
+
+class Timer(object):
+    """
+    Utility to measure times.
+
+    TODO:
+    - add "lap" method to make it easier to measure average time (+std) when measuring the same thing multiple times.
+    """
+
+    def __init__(self):
+        self.total_time = 0.0
+        self.start_time = 0.0
+        self.end_time = 0.0
+
+    def start(self):
+        self.start_time = default_timer()
+
+    def end(self):
+        self.total_time += default_timer() - self.start_time
+
+    def get(self):
+        return self.total_time
+
+    def get_current(self):
+        return default_timer() - self.start_time
+
+    def reset(self):
+        self.__init__()
+
+    def __repr__(self):
+        # __repr__ must return a string, not the accumulated float
+        return "{:.3f}".format(self.get())
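+
+# Usage sketch (illustrative only): accumulate the time spent in repeated calls.
+#
+#   inference_timer = Timer()
+#   for _ in range(10):
+#       inference_timer.start()
+#       run_inference()             # placeholder for the code being timed
+#       inference_timer.end()
+#   print(inference_timer.get())    # total seconds across all 10 iterations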