diff --git a/apt.txt b/apt.txt
index c0e0ffb810cc7ee7dc785f2a04a27ebec944b3ce..d593bcc792fa1aa73b3ca7ab57a93e91dd738d8e 100644
--- a/apt.txt
+++ b/apt.txt
@@ -1,6 +1,6 @@
-curl
-git
-vim
-ssh
-gcc
+curl
+git
+vim
+ssh
+gcc
 build-essential
\ No newline at end of file
diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 1c323c366c7ff6b08baf1ff2d7d9ceee389bbd57..6218ab8c0dd6de2ddb9f3f238b532ba62d8ab0c6 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -1,201 +1,203 @@
-import copy
-import os
-import pickle
-import random
-from collections import namedtuple, deque, Iterable
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-import torch.optim as optim
-
-from reinforcement_learning.model import DuelingQNetwork
-from reinforcement_learning.policy import Policy
-
-
-class DDDQNPolicy(Policy):
-    """Dueling Double DQN policy"""
-
-    def __init__(self, state_size, action_size, parameters, evaluation_mode=False):
-        self.evaluation_mode = evaluation_mode
-
-        self.state_size = state_size
-        self.action_size = action_size
-        self.double_dqn = True
-        self.hidsize = 128
-
-        if not evaluation_mode:
-            self.hidsize = parameters.hidden_size
-            self.buffer_size = parameters.buffer_size
-            self.batch_size = parameters.batch_size
-            self.update_every = parameters.update_every
-            self.learning_rate = parameters.learning_rate
-            self.tau = parameters.tau
-            self.gamma = parameters.gamma
-            self.buffer_min_size = parameters.buffer_min_size
-
-            # Device
-        if parameters.use_gpu and torch.cuda.is_available():
-            self.device = torch.device("cuda:0")
-            # print("🐇 Using GPU")
-        else:
-            self.device = torch.device("cpu")
-            # print("🐢 Using CPU")
-
-        # Q-Network
-        self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to(
-            self.device)
-
-        if not evaluation_mode:
-            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
-            self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.learning_rate)
-            self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, self.device)
-
-            self.t_step = 0
-            self.loss = 0.0
-
-    def act(self, state, eps=0.):
-        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
-        self.qnetwork_local.eval()
-        with torch.no_grad():
-            action_values = self.qnetwork_local(state)
-        self.qnetwork_local.train()
-
-        # Epsilon-greedy action selection
-        if random.random() > eps:
-            return np.argmax(action_values.cpu().data.numpy())
-        else:
-            return random.choice(np.arange(self.action_size))
-
-    def step(self, handle, state, action, reward, next_state, done):
-        assert not self.evaluation_mode, "Policy has been initialized for evaluation only."
-
-        # Save experience in replay memory
-        self.memory.add(state, action, reward, next_state, done)
-
-        # Learn every UPDATE_EVERY time steps.
-        self.t_step = (self.t_step + 1) % self.update_every
-        if self.t_step == 0:
-            # If enough samples are available in memory, get random subset and learn
-            if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
-                self._learn()
-
-    def _learn(self):
-        experiences = self.memory.sample()
-        states, actions, rewards, next_states, dones = experiences
-
-        # Get expected Q values from local model
-        q_expected = self.qnetwork_local(states).gather(1, actions)
-
-        if self.double_dqn:
-            # Double DQN
-            q_best_action = self.qnetwork_local(next_states).max(1)[1]
-            q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1))
-        else:
-            # DQN
-            q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1)
-
-        # Compute Q targets for current states
-        q_targets = rewards + (self.gamma * q_targets_next * (1 - dones))
-
-        # Compute loss
-        self.loss = F.mse_loss(q_expected, q_targets)
-
-        # Minimize the loss
-        self.optimizer.zero_grad()
-        self.loss.backward()
-        self.optimizer.step()
-
-        # Update target network
-        self._soft_update(self.qnetwork_local, self.qnetwork_target, self.tau)
-
-    def _soft_update(self, local_model, target_model, tau):
-        # Soft update model parameters.
-        # θ_target = τ*θ_local + (1 - τ)*θ_target
-        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
-            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
-
-    def save(self, filename):
-        torch.save(self.qnetwork_local.state_dict(), filename + ".local")
-        torch.save(self.qnetwork_target.state_dict(), filename + ".target")
-
-    def load(self, filename):
-        try:
-            if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
-                self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
-                print("qnetwork_local loaded ('{}')".format(filename + ".local"))
-                if not self.evaluation_mode:
-                    self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
-                    print("qnetwork_target loaded ('{}' )".format(filename + ".target"))
-            else:
-                print(">> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local",
-                                                                                             filename + ".target"))
-        except Exception as exc:
-            print(exc)
-            print("Couldn't load policy from, using untrained policy! ('{}', '{}')".format(filename + ".local",
-                                                                                           filename + ".target"))
-
-    def save_replay_buffer(self, filename):
-        memory = self.memory.memory
-        with open(filename, 'wb') as f:
-            pickle.dump(list(memory)[-500000:], f)
-
-    def load_replay_buffer(self, filename):
-        with open(filename, 'rb') as f:
-            self.memory.memory = pickle.load(f)
-
-    def test(self):
-        self.act(np.array([[0] * self.state_size]))
-        self._learn()
-
-
-Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
-
-
-class ReplayBuffer:
-    """Fixed-size buffer to store experience tuples."""
-
-    def __init__(self, action_size, buffer_size, batch_size, device):
-        """Initialize a ReplayBuffer object.
-
-        Params
-        ======
-            action_size (int): dimension of each action
-            buffer_size (int): maximum size of buffer
-            batch_size (int): size of each training batch
-        """
-        self.action_size = action_size
-        self.memory = deque(maxlen=buffer_size)
-        self.batch_size = batch_size
-        self.device = device
-
-    def add(self, state, action, reward, next_state, done):
-        """Add a new experience to memory."""
-        e = Experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)
-        self.memory.append(e)
-
-    def sample(self):
-        """Randomly sample a batch of experiences from memory."""
-        experiences = random.sample(self.memory, k=self.batch_size)
-
-        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
-            .float().to(self.device)
-        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
-            .long().to(self.device)
-        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
-            .float().to(self.device)
-        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
-            .float().to(self.device)
-        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
-            .float().to(self.device)
-
-        return states, actions, rewards, next_states, dones
-
-    def __len__(self):
-        """Return the current size of internal memory."""
-        return len(self.memory)
-
-    def __v_stack_impr(self, states):
-        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
-        np_states = np.reshape(np.array(states), (len(states), sub_dim))
-        return np_states
+import copy
+import os
+import pickle
+import random
+from collections import namedtuple, deque
+from collections.abc import Iterable  # Iterable lives in collections.abc (it was removed from collections in Python 3.10)
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+
+from reinforcement_learning.model import DuelingQNetwork
+from reinforcement_learning.policy import Policy
+
+
+class DDDQNPolicy(Policy):
+    """Dueling Double DQN policy"""
+
+    def __init__(self, state_size, action_size, parameters, evaluation_mode=False):
+        self.evaluation_mode = evaluation_mode
+
+        self.state_size = state_size
+        self.action_size = action_size
+        self.double_dqn = True
+        self.hidsize = 128
+
+        if not evaluation_mode:
+            self.hidsize = parameters.hidden_size
+            self.buffer_size = parameters.buffer_size
+            self.batch_size = parameters.batch_size
+            self.update_every = parameters.update_every
+            self.learning_rate = parameters.learning_rate
+            self.tau = parameters.tau
+            self.gamma = parameters.gamma
+            self.buffer_min_size = parameters.buffer_min_size
+
+        # Device
+        if parameters.use_gpu and torch.cuda.is_available():
+            self.device = torch.device("cuda:0")
+            # print("🐇 Using GPU")
+        else:
+            self.device = torch.device("cpu")
+            # print("🐢 Using CPU")
+
+        # Q-Network
+        self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to(
+            self.device)
+
+        if not evaluation_mode:
+            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+            self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.learning_rate)
+            self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, self.device)
+
+            self.t_step = 0
+            self.loss = 0.0
+
+    def act(self, state, eps=0.):
+        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
+        self.qnetwork_local.eval()
+        with torch.no_grad():
+            action_values = self.qnetwork_local(state)
+        self.qnetwork_local.train()
+
+        # Epsilon-greedy action selection
+        if random.random() > eps:
+            return np.argmax(action_values.cpu().data.numpy())
+        else:
+            return random.choice(np.arange(self.action_size))
+
+    def step(self, handle, state, action, reward, next_state, done):
+        assert not self.evaluation_mode, "Policy has been initialized for evaluation only."
+
+        # Save experience in replay memory
+        self.memory.add(state, action, reward, next_state, done)
+
+        # Learn every UPDATE_EVERY time steps.
+        self.t_step = (self.t_step + 1) % self.update_every
+        if self.t_step == 0:
+            # If enough samples are available in memory, get random subset and learn
+            if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
+                self._learn()
+
+    def _learn(self):
+        experiences = self.memory.sample()
+        states, actions, rewards, next_states, dones = experiences
+
+        # Get expected Q values from local model
+        q_expected = self.qnetwork_local(states).gather(1, actions)
+
+        if self.double_dqn:
+            # Double DQN: select the greedy next action with the online network, evaluate it with the target network
+            q_best_action = self.qnetwork_local(next_states).max(1)[1]
+            q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1))
+        else:
+            # Standard DQN: take the max next-state value directly from the target network
+            q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1)
+
+        # Compute Q targets for current states
+        q_targets = rewards + (self.gamma * q_targets_next * (1 - dones))
+
+        # Compute loss
+        self.loss = F.mse_loss(q_expected, q_targets)
+
+        # Minimize the loss
+        self.optimizer.zero_grad()
+        self.loss.backward()
+        self.optimizer.step()
+
+        # Update target network
+        self._soft_update(self.qnetwork_local, self.qnetwork_target, self.tau)
+
+    def _soft_update(self, local_model, target_model, tau):
+        # Soft update model parameters.
+        # θ_target = τ*θ_local + (1 - τ)*θ_target
+        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
+            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
+
+    def save(self, filename):
+        torch.save(self.qnetwork_local.state_dict(), filename + ".local")
+        torch.save(self.qnetwork_target.state_dict(), filename + ".target")
+
+    def load(self, filename):
+        try:
+            if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
+                self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
+                print("qnetwork_local loaded ('{}')".format(filename + ".local"))
+                if not self.evaluation_mode:
+                    self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+                    print("qnetwork_target loaded ('{}')".format(filename + ".target"))
+            else:
+                print(">> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local",
+                                                                                             filename + ".target"))
+        except Exception as exc:
+            print(exc)
+            print("Couldn't load policy, using untrained policy! ('{}', '{}')".format(filename + ".local",
+                                                                                      filename + ".target"))
+
+    def save_replay_buffer(self, filename):
+        memory = self.memory.memory
+        with open(filename, 'wb') as f:
+            pickle.dump(list(memory)[-500000:], f)
+
+    def load_replay_buffer(self, filename):
+        with open(filename, 'rb') as f:
+            self.memory.memory = pickle.load(f)
+
+    def test(self):
+        self.act(np.array([[0] * self.state_size]))
+        self._learn()
+
+
+Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
+
+
+class ReplayBuffer:
+    """Fixed-size buffer to store experience tuples."""
+
+    def __init__(self, action_size, buffer_size, batch_size, device):
+        """Initialize a ReplayBuffer object.
+
+        Params
+        ======
+            action_size (int): dimension of each action
+            buffer_size (int): maximum size of buffer
+            batch_size (int): size of each training batch
+        """
+        self.action_size = action_size
+        self.memory = deque(maxlen=buffer_size)
+        self.batch_size = batch_size
+        self.device = device
+
+    def add(self, state, action, reward, next_state, done):
+        """Add a new experience to memory."""
+        e = Experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)
+        self.memory.append(e)
+
+    def sample(self):
+        """Randomly sample a batch of experiences from memory."""
+        experiences = random.sample(self.memory, k=self.batch_size)
+
+        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
+            .float().to(self.device)
+        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
+            .long().to(self.device)
+        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
+            .float().to(self.device)
+        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
+            .float().to(self.device)
+        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
+            .float().to(self.device)
+
+        return states, actions, rewards, next_states, dones
+
+    def __len__(self):
+        """Return the current size of internal memory."""
+        return len(self.memory)
+
+    def __v_stack_impr(self, states):
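+        # Stack a batch of experience fields into a (batch_size, sub_dim) array; scalar fields become column vectors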
+        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
+        np_states = np.reshape(np.array(states), (len(states), sub_dim))
+        return np_states
diff --git a/reinforcement_learning/evaluate_agent.py b/reinforcement_learning/evaluate_agent.py
index 80d42987c40bc1c590152b213ab8abaf6f9a91a6..64eb9433a9df00457d698e5873b31a16712de718 100644
--- a/reinforcement_learning/evaluate_agent.py
+++ b/reinforcement_learning/evaluate_agent.py
@@ -1,376 +1,377 @@
-import math
-import multiprocessing
-import os
-import sys
-from argparse import ArgumentParser, Namespace
-from multiprocessing import Pool
-from pathlib import Path
-from pprint import pprint
-
-import numpy as np
-import torch
-from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
-from flatland.envs.rail_env import RailEnv
-from flatland.envs.rail_generators import sparse_rail_generator
-from flatland.envs.schedule_generators import sparse_schedule_generator
-from flatland.utils.rendertools import RenderTool
-
-base_dir = Path(__file__).resolve().parent.parent
-sys.path.append(str(base_dir))
-
-from utils.deadlock_check import check_if_all_blocked
-from utils.timer import Timer
-from utils.observation_utils import normalize_observation
-from reinforcement_learning.dddqn_policy import DDDQNPolicy
-
-
-def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render, allow_skipping, allow_caching):
-    # Evaluation is faster on CPU (except if you use a really huge policy)
-    parameters = {
-        'use_gpu': False
-    }
-
-    policy = DDDQNPolicy(state_size, action_size, Namespace(**parameters), evaluation_mode=True)
-    policy.qnetwork_local = torch.load(checkpoint)
-
-    env_params = Namespace(**env_params)
-
-    # Environment parameters
-    n_agents = env_params.n_agents
-    x_dim = env_params.x_dim
-    y_dim = env_params.y_dim
-    n_cities = env_params.n_cities
-    max_rails_between_cities = env_params.max_rails_between_cities
-    max_rails_in_city = env_params.max_rails_in_city
-
-    # Malfunction and speed profiles
-    # TODO pass these parameters properly from main!
-    malfunction_parameters = MalfunctionParameters(
-        malfunction_rate=1. / 2000,  # Rate of malfunctions
-        min_duration=20,  # Minimal duration
-        max_duration=50  # Max duration
-    )
-
-    # Only fast trains in Round 1
-    speed_profiles = {
-        1.: 1.0,  # Fast passenger train
-        1. / 2.: 0.0,  # Fast freight train
-        1. / 3.: 0.0,  # Slow commuter train
-        1. / 4.: 0.0  # Slow freight train
-    }
-
-    # Observation parameters
-    observation_tree_depth = env_params.observation_tree_depth
-    observation_radius = env_params.observation_radius
-    observation_max_path_depth = env_params.observation_max_path_depth
-
-    # Observation builder
-    predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
-    tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
-
-    # Setup the environment
-    env = RailEnv(
-        width=x_dim, height=y_dim,
-        rail_generator=sparse_rail_generator(
-            max_num_cities=n_cities,
-            grid_mode=False,
-            max_rails_between_cities=max_rails_between_cities,
-            max_rails_in_city=max_rails_in_city,
-        ),
-        schedule_generator=sparse_schedule_generator(speed_profiles),
-        number_of_agents=n_agents,
-        malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters),
-        obs_builder_object=tree_observation
-    )
-
-    if render:
-        env_renderer = RenderTool(env, gl="PGL")
-
-    action_dict = dict()
-    scores = []
-    completions = []
-    nb_steps = []
-    inference_times = []
-    preproc_times = []
-    agent_times = []
-    step_times = []
-
-    for episode_idx in range(n_eval_episodes):
-        seed += 1
-
-        inference_timer = Timer()
-        preproc_timer = Timer()
-        agent_timer = Timer()
-        step_timer = Timer()
-
-        step_timer.start()
-        obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True, random_seed=seed)
-        step_timer.end()
-
-        agent_obs = [None] * env.get_num_agents()
-        score = 0.0
-
-        if render:
-            env_renderer.set_new_rail()
-
-        final_step = 0
-        skipped = 0
-
-        nb_hit = 0
-        agent_last_obs = {}
-        agent_last_action = {}
-
-        for step in range(max_steps - 1):
-            if allow_skipping and check_if_all_blocked(env):
-                # FIXME why -1? bug where all agents are "done" after max_steps!
-                skipped = max_steps - step - 1
-                final_step = max_steps - 2
-                n_unfinished_agents = sum(not done[idx] for idx in env.get_agent_handles())
-                score -= skipped * n_unfinished_agents
-                break
-
-            agent_timer.start()
-            for agent in env.get_agent_handles():
-                if obs[agent] and info['action_required'][agent]:
-                    if agent in agent_last_obs and np.all(agent_last_obs[agent] == obs[agent]):
-                        nb_hit += 1
-                        action = agent_last_action[agent]
-
-                    else:
-                        preproc_timer.start()
-                        norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
-                        preproc_timer.end()
-
-                        inference_timer.start()
-                        action = policy.act(norm_obs, eps=0.0)
-                        inference_timer.end()
-
-                    action_dict.update({agent: action})
-
-                    if allow_caching:
-                        agent_last_obs[agent] = obs[agent]
-                        agent_last_action[agent] = action
-            agent_timer.end()
-
-            step_timer.start()
-            obs, all_rewards, done, info = env.step(action_dict)
-            step_timer.end()
-
-            if render:
-                env_renderer.render_env(
-                    show=True,
-                    frames=False,
-                    show_observations=False,
-                    show_predictions=False
-                )
-
-                if step % 100 == 0:
-                    print("{}/{}".format(step, max_steps - 1))
-
-            for agent in env.get_agent_handles():
-                score += all_rewards[agent]
-
-            final_step = step
-
-            if done['__all__']:
-                break
-
-        normalized_score = score / (max_steps * env.get_num_agents())
-        scores.append(normalized_score)
-
-        tasks_finished = sum(done[idx] for idx in env.get_agent_handles())
-        completion = tasks_finished / max(1, env.get_num_agents())
-        completions.append(completion)
-
-        nb_steps.append(final_step)
-
-        inference_times.append(inference_timer.get())
-        preproc_times.append(preproc_timer.get())
-        agent_times.append(agent_timer.get())
-        step_times.append(step_timer.get())
-
-        skipped_text = ""
-        if skipped > 0:
-            skipped_text = "\t⚡ Skipped {}".format(skipped)
-
-        hit_text = ""
-        if nb_hit > 0:
-            hit_text = "\t⚡ Hit {} ({:.1f}%)".format(nb_hit, (100 * nb_hit) / (n_agents * final_step))
-
-        print(
-            "☑️  Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} "
-            "\t🍭 Seed: {}"
-            "\t🚉 Env: {:.3f}s  "
-            "\t🤖 Agent: {:.3f}s (per step: {:.3f}s) \t[preproc: {:.3f}s \tinfer: {:.3f}s]"
-            "{}{}".format(
-                normalized_score,
-                completion * 100.0,
-                final_step,
-                seed,
-                step_timer.get(),
-                agent_timer.get(),
-                agent_timer.get() / final_step,
-                preproc_timer.get(),
-                inference_timer.get(),
-                skipped_text,
-                hit_text
-            )
-        )
-
-    return scores, completions, nb_steps, agent_times, step_times
-
-
-def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping, allow_caching):
-    nb_threads = 1
-    eval_per_thread = n_evaluation_episodes
-
-    if not render:
-        nb_threads = multiprocessing.cpu_count()
-        eval_per_thread = max(1, math.ceil(n_evaluation_episodes / nb_threads))
-
-    total_nb_eval = eval_per_thread * nb_threads
-    print("Will evaluate policy {} over {} episodes on {} threads.".format(file, total_nb_eval, nb_threads))
-
-    if total_nb_eval != n_evaluation_episodes:
-        print("(Rounding up from {} to fill all cores)".format(n_evaluation_episodes))
-
-    # Observation parameters need to match the ones used during training!
-
-    # small_v0
-    small_v0_params = {
-        # sample configuration
-        "n_agents": 5,
-        "x_dim": 25,
-        "y_dim": 25,
-        "n_cities": 4,
-        "max_rails_between_cities": 2,
-        "max_rails_in_city": 3,
-
-        # observations
-        "observation_tree_depth": 2,
-        "observation_radius": 10,
-        "observation_max_path_depth": 20
-    }
-
-    # Test_0
-    test0_params = {
-        # sample configuration
-        "n_agents": 5,
-        "x_dim": 25,
-        "y_dim": 25,
-        "n_cities": 2,
-        "max_rails_between_cities": 2,
-        "max_rails_in_city": 3,
-
-        # observations
-        "observation_tree_depth": 2,
-        "observation_radius": 10,
-        "observation_max_path_depth": 20
-    }
-
-    # Test_1
-    test1_params = {
-        # environment
-        "n_agents": 10,
-        "x_dim": 30,
-        "y_dim": 30,
-        "n_cities": 2,
-        "max_rails_between_cities": 2,
-        "max_rails_in_city": 3,
-
-        # observations
-        "observation_tree_depth": 2,
-        "observation_radius": 10,
-        "observation_max_path_depth": 10
-    }
-
-    # Test_5
-    test5_params = {
-        # environment
-        "n_agents": 80,
-        "x_dim": 35,
-        "y_dim": 35,
-        "n_cities": 5,
-        "max_rails_between_cities": 2,
-        "max_rails_in_city": 4,
-
-        # observations
-        "observation_tree_depth": 2,
-        "observation_radius": 10,
-        "observation_max_path_depth": 20
-    }
-
-    params = small_v0_params
-    env_params = Namespace(**params)
-
-    print("Environment parameters:")
-    pprint(params)
-
-    # Calculate space dimensions and max steps
-    max_steps = int(4 * 2 * (env_params.x_dim + env_params.y_dim + (env_params.n_agents / env_params.n_cities)))
-    action_size = 5
-    tree_observation = TreeObsForRailEnv(max_depth=env_params.observation_tree_depth)
-    tree_depth = env_params.observation_tree_depth
-    num_features_per_node = tree_observation.observation_dim
-    n_nodes = sum([np.power(4, i) for i in range(tree_depth + 1)])
-    state_size = num_features_per_node * n_nodes
-
-    results = []
-    if render:
-        results.append(eval_policy(params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping, allow_caching))
-
-    else:
-        with Pool() as p:
-            results = p.starmap(eval_policy,
-                                [(params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, allow_skipping, allow_caching)
-                                 for seed in
-                                 range(total_nb_eval)])
-
-    scores = []
-    completions = []
-    nb_steps = []
-    times = []
-    step_times = []
-    for s, c, n, t, st in results:
-        scores.append(s)
-        completions.append(c)
-        nb_steps.append(n)
-        times.append(t)
-        step_times.append(st)
-
-    print("-" * 200)
-
-    print("✅ Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} \tAgent total: {:.3f}s (per step: {:.3f}s)".format(
-        np.mean(scores),
-        np.mean(completions) * 100.0,
-        np.mean(nb_steps),
-        np.mean(times),
-        np.mean(times) / np.mean(nb_steps)
-    ))
-
-    print("⏲️  Agent sum: {:.3f}s \tEnv sum: {:.3f}s \tTotal sum: {:.3f}s".format(
-        np.sum(times),
-        np.sum(step_times),
-        np.sum(times) + np.sum(step_times)
-    ))
-
-
-if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument("-f", "--file", help="checkpoint to load", required=True, type=str)
-    parser.add_argument("-n", "--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
-
-    # TODO
-    # parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, type=int)
-
-    parser.add_argument("--use_gpu", dest="use_gpu", help="use GPU if available", action='store_true')
-    parser.add_argument("--render", help="render a single episode", action='store_true')
-    parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked", action='store_true')
-    parser.add_argument("--allow_caching", help="caches the last observation-action pair", action='store_true')
-    args = parser.parse_args()
-
-    os.environ["OMP_NUM_THREADS"] = str(1)
-    evaluate_agents(file=args.file, n_evaluation_episodes=args.n_evaluation_episodes, use_gpu=args.use_gpu, render=args.render,
-                    allow_skipping=args.allow_skipping, allow_caching=args.allow_caching)
+import math
+import multiprocessing
+import os
+import sys
+from argparse import ArgumentParser, Namespace
+from multiprocessing import Pool
+from pathlib import Path
+from pprint import pprint
+
+import numpy as np
+import torch
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.utils.rendertools import RenderTool
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from utils.deadlock_check import check_if_all_blocked
+from utils.timer import Timer
+from utils.observation_utils import normalize_observation
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+
+
+def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render, allow_skipping, allow_caching):
+    # Evaluation is faster on CPU (except if you use a really huge policy)
+    parameters = {
+        'use_gpu': False
+    }
+
+    policy = DDDQNPolicy(state_size, action_size, Namespace(**parameters), evaluation_mode=True)
+    policy.qnetwork_local = torch.load(checkpoint)
+
+    env_params = Namespace(**env_params)
+
+    # Environment parameters
+    n_agents = env_params.n_agents
+    x_dim = env_params.x_dim
+    y_dim = env_params.y_dim
+    n_cities = env_params.n_cities
+    max_rails_between_cities = env_params.max_rails_between_cities
+    max_rails_in_city = env_params.max_rails_in_city
+
+    # Malfunction and speed profiles
+    # TODO pass these parameters properly from main!
+    malfunction_parameters = MalfunctionParameters(
+        malfunction_rate=1. / 2000,  # Rate of malfunctions
+        min_duration=20,  # Minimal duration
+        max_duration=50  # Max duration
+    )
+
+    # Only fast trains in Round 1
+    speed_profiles = {
+        1.: 1.0,  # Fast passenger train
+        1. / 2.: 0.0,  # Fast freight train
+        1. / 3.: 0.0,  # Slow commuter train
+        1. / 4.: 0.0  # Slow freight train
+    }
+
+    # Observation parameters
+    observation_tree_depth = env_params.observation_tree_depth
+    observation_radius = env_params.observation_radius
+    observation_max_path_depth = env_params.observation_max_path_depth
+
+    # Observation builder
+    predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
+    tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+
+    # Setup the environment
+    env = RailEnv(
+        width=x_dim, height=y_dim,
+        rail_generator=sparse_rail_generator(
+            max_num_cities=n_cities,
+            grid_mode=False,
+            max_rails_between_cities=max_rails_between_cities,
+            max_rails_in_city=max_rails_in_city,
+        ),
+        schedule_generator=sparse_schedule_generator(speed_profiles),
+        number_of_agents=n_agents,
+        malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters),
+        obs_builder_object=tree_observation
+    )
+
+    if render:
+        env_renderer = RenderTool(env, gl="PGL")
+
+    action_dict = dict()
+    scores = []
+    completions = []
+    nb_steps = []
+    inference_times = []
+    preproc_times = []
+    agent_times = []
+    step_times = []
+
+    for episode_idx in range(n_eval_episodes):
+        seed += 1
+
+        inference_timer = Timer()
+        preproc_timer = Timer()
+        agent_timer = Timer()
+        step_timer = Timer()
+
+        step_timer.start()
+        obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True, random_seed=seed)
+        step_timer.end()
+
+        agent_obs = [None] * env.get_num_agents()
+        score = 0.0
+
+        if render:
+            env_renderer.set_new_rail()
+
+        final_step = 0
+        skipped = 0
+
+        nb_hit = 0
+        agent_last_obs = {}
+        agent_last_action = {}
+
+        for step in range(max_steps - 1):
+            if allow_skipping and check_if_all_blocked(env):
+                # FIXME why -1? bug where all agents are "done" after max_steps!
+                skipped = max_steps - step - 1
+                final_step = max_steps - 2
+                n_unfinished_agents = sum(not done[idx] for idx in env.get_agent_handles())
+                score -= skipped * n_unfinished_agents
+                break
+
+            agent_timer.start()
+            for agent in env.get_agent_handles():
+                if obs[agent] and info['action_required'][agent]:
+                    if agent in agent_last_obs and np.all(agent_last_obs[agent] == obs[agent]):
+                        nb_hit += 1
+                        action = agent_last_action[agent]
+
+                    else:
+                        preproc_timer.start()
+                        norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
+                        preproc_timer.end()
+
+                        inference_timer.start()
+                        action = policy.act(norm_obs, eps=0.0)
+                        inference_timer.end()
+
+                    action_dict.update({agent: action})
+
+                    if allow_caching:
+                        agent_last_obs[agent] = obs[agent]
+                        agent_last_action[agent] = action
+            agent_timer.end()
+
+            step_timer.start()
+            obs, all_rewards, done, info = env.step(action_dict)
+            step_timer.end()
+
+            if render:
+                env_renderer.render_env(
+                    show=True,
+                    frames=False,
+                    show_observations=False,
+                    show_predictions=False
+                )
+
+                if step % 100 == 0:
+                    print("{}/{}".format(step, max_steps - 1))
+
+            for agent in env.get_agent_handles():
+                score += all_rewards[agent]
+
+            final_step = step
+
+            if done['__all__']:
+                break
+
+        normalized_score = score / (max_steps * env.get_num_agents())
+        scores.append(normalized_score)
+
+        tasks_finished = sum(done[idx] for idx in env.get_agent_handles())
+        completion = tasks_finished / max(1, env.get_num_agents())
+        completions.append(completion)
+
+        nb_steps.append(final_step)
+
+        inference_times.append(inference_timer.get())
+        preproc_times.append(preproc_timer.get())
+        agent_times.append(agent_timer.get())
+        step_times.append(step_timer.get())
+
+        skipped_text = ""
+        if skipped > 0:
+            skipped_text = "\t⚡ Skipped {}".format(skipped)
+
+        hit_text = ""
+        if nb_hit > 0:
+            hit_text = "\t⚡ Hit {} ({:.1f}%)".format(nb_hit, (100 * nb_hit) / (n_agents * final_step))
+
+        print(
+            "☑️  Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} "
+            "\t🍭 Seed: {}"
+            "\t🚉 Env: {:.3f}s  "
+            "\t🤖 Agent: {:.3f}s (per step: {:.3f}s) \t[preproc: {:.3f}s \tinfer: {:.3f}s]"
+            "{}{}".format(
+                normalized_score,
+                completion * 100.0,
+                final_step,
+                seed,
+                step_timer.get(),
+                agent_timer.get(),
+                agent_timer.get() / final_step,
+                preproc_timer.get(),
+                inference_timer.get(),
+                skipped_text,
+                hit_text
+            )
+        )
+
+    return scores, completions, nb_steps, agent_times, step_times
+
+
+def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping, allow_caching):
+    nb_threads = 1
+    eval_per_thread = n_evaluation_episodes
+
+    if not render:
+        nb_threads = multiprocessing.cpu_count()
+        eval_per_thread = max(1, math.ceil(n_evaluation_episodes / nb_threads))
+
+    total_nb_eval = eval_per_thread * nb_threads
+    print("Will evaluate policy {} over {} episodes on {} threads.".format(file, total_nb_eval, nb_threads))
+
+    if total_nb_eval != n_evaluation_episodes:
+        print("(Rounding up from {} to fill all cores)".format(n_evaluation_episodes))
+
+    # Observation parameters need to match the ones used during training!
+
+    # small_v0
+    small_v0_params = {
+        # sample configuration
+        "n_agents": 5,
+        "x_dim": 25,
+        "y_dim": 25,
+        "n_cities": 4,
+        "max_rails_between_cities": 2,
+        "max_rails_in_city": 3,
+
+        # observations
+        "observation_tree_depth": 2,
+        "observation_radius": 10,
+        "observation_max_path_depth": 20
+    }
+
+    # Test_0
+    test0_params = {
+        # sample configuration
+        "n_agents": 5,
+        "x_dim": 25,
+        "y_dim": 25,
+        "n_cities": 2,
+        "max_rails_between_cities": 2,
+        "max_rails_in_city": 3,
+
+        # observations
+        "observation_tree_depth": 2,
+        "observation_radius": 10,
+        "observation_max_path_depth": 20
+    }
+
+    # Test_1
+    test1_params = {
+        # environment
+        "n_agents": 10,
+        "x_dim": 30,
+        "y_dim": 30,
+        "n_cities": 2,
+        "max_rails_between_cities": 2,
+        "max_rails_in_city": 3,
+
+        # observations
+        "observation_tree_depth": 2,
+        "observation_radius": 10,
+        "observation_max_path_depth": 10
+    }
+
+    # Test_5
+    test5_params = {
+        # environment
+        "n_agents": 80,
+        "x_dim": 35,
+        "y_dim": 35,
+        "n_cities": 5,
+        "max_rails_between_cities": 2,
+        "max_rails_in_city": 4,
+
+        # observations
+        "observation_tree_depth": 2,
+        "observation_radius": 10,
+        "observation_max_path_depth": 20
+    }
+
+    params = small_v0_params
+    env_params = Namespace(**params)
+
+    print("Environment parameters:")
+    pprint(params)
+
+    # Calculate space dimensions and max steps
+    max_steps = int(4 * 2 * (env_params.x_dim + env_params.y_dim + (env_params.n_agents / env_params.n_cities)))
+    action_size = 5
+    tree_observation = TreeObsForRailEnv(max_depth=env_params.observation_tree_depth)
+    tree_depth = env_params.observation_tree_depth
+    num_features_per_node = tree_observation.observation_dim
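+    # The tree observation has sum(4**i for i in 0..depth) nodes, each contributing num_features_per_node features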
+    n_nodes = sum([np.power(4, i) for i in range(tree_depth + 1)])
+    state_size = num_features_per_node * n_nodes
+
+    results = []
+    if render:
+        results.append(eval_policy(params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping, allow_caching))
+
+    else:
+        with Pool() as p:
+            results = p.starmap(eval_policy,
+                                [(params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, allow_skipping, allow_caching)
+                                 for seed in
+                                 range(total_nb_eval)])
+
+    scores = []
+    completions = []
+    nb_steps = []
+    times = []
+    step_times = []
+    for s, c, n, t, st in results:
+        scores.append(s)
+        completions.append(c)
+        nb_steps.append(n)
+        times.append(t)
+        step_times.append(st)
+
+    print("-" * 200)
+
+    print("✅ Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} \tAgent total: {:.3f}s (per step: {:.3f}s)".format(
+        np.mean(scores),
+        np.mean(completions) * 100.0,
+        np.mean(nb_steps),
+        np.mean(times),
+        np.mean(times) / np.mean(nb_steps)
+    ))
+
+    print("⏲️  Agent sum: {:.3f}s \tEnv sum: {:.3f}s \tTotal sum: {:.3f}s".format(
+        np.sum(times),
+        np.sum(step_times),
+        np.sum(times) + np.sum(step_times)
+    ))
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("-f", "--file", help="checkpoint to load", required=True, type=str)
+    parser.add_argument("-n", "--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
+
+    # TODO
+    # parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, type=int)
+
+    parser.add_argument("--use_gpu", dest="use_gpu", help="use GPU if available", action='store_true')
+    parser.add_argument("--render", help="render a single episode", action='store_true')
+    parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked", action='store_true')
+    parser.add_argument("--allow_caching", help="caches the last observation-action pair", action='store_true')
+    args = parser.parse_args()
+
+    os.environ["OMP_NUM_THREADS"] = str(1)
+    evaluate_agents(file=args.file, n_evaluation_episodes=args.n_evaluation_episodes, use_gpu=args.use_gpu, render=args.render,
+                    allow_skipping=args.allow_skipping, allow_caching=args.allow_caching)
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
index b9d103961da0b2474eb0fb3cb1dd65058e2bc51c..74288e73bc182195e6b3f0778c4739ffbc487217 100755
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -1,608 +1,608 @@
-import os
-import random
-import sys
-from argparse import ArgumentParser, Namespace
-from collections import deque
-from datetime import datetime
-from pathlib import Path
-from pprint import pprint
-
-import numpy as np
-import psutil
-from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
-from flatland.envs.observations import TreeObsForRailEnv
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
-from flatland.envs.rail_env import RailEnv, RailEnvActions
-from flatland.envs.rail_generators import sparse_rail_generator
-from flatland.envs.schedule_generators import sparse_schedule_generator
-from flatland.utils.rendertools import RenderTool
-from torch.utils.tensorboard import SummaryWriter
-
-from reinforcement_learning.dddqn_policy import DDDQNPolicy
-from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
-
-base_dir = Path(__file__).resolve().parent.parent
-sys.path.append(str(base_dir))
-
-from utils.timer import Timer
-from utils.observation_utils import normalize_observation
-from utils.fast_tree_obs import FastTreeObs
-
-try:
-    import wandb
-
-    wandb.init(sync_tensorboard=True)
-except ImportError:
-    print("Install wandb to log to Weights & Biases")
-
-"""
-This file shows how to train multiple agents using a reinforcement learning approach.
-After training an agent, you can submit it straight away to the NeurIPS 2020 Flatland challenge!
-
-Agent documentation: https://flatland.aicrowd.com/getting-started/rl/multi-agent.html
-Submission documentation: https://flatland.aicrowd.com/getting-started/first-submission.html
-"""
-
-
-def create_rail_env(env_params, tree_observation):
-    n_agents = env_params.n_agents
-    x_dim = env_params.x_dim
-    y_dim = env_params.y_dim
-    n_cities = env_params.n_cities
-    max_rails_between_cities = env_params.max_rails_between_cities
-    max_rails_in_city = env_params.max_rails_in_city
-    seed = env_params.seed
-
-    # Break agents from time to time
-    malfunction_parameters = MalfunctionParameters(
-        malfunction_rate=env_params.malfunction_rate,
-        min_duration=20,
-        max_duration=50
-    )
-
-    return RailEnv(
-        width=x_dim, height=y_dim,
-        rail_generator=sparse_rail_generator(
-            max_num_cities=n_cities,
-            grid_mode=False,
-            max_rails_between_cities=max_rails_between_cities,
-            max_rails_in_city=max_rails_in_city
-        ),
-        schedule_generator=sparse_schedule_generator(),
-        number_of_agents=n_agents,
-        malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters),
-        obs_builder_object=tree_observation,
-        random_seed=seed
-    )
-
-
-def train_agent(train_params, train_env_params, eval_env_params, obs_params):
-    # Environment parameters
-    n_agents = train_env_params.n_agents
-    x_dim = train_env_params.x_dim
-    y_dim = train_env_params.y_dim
-    n_cities = train_env_params.n_cities
-    max_rails_between_cities = train_env_params.max_rails_between_cities
-    max_rails_in_city = train_env_params.max_rails_in_city
-    seed = train_env_params.seed
-
-    # Unique ID for this training
-    now = datetime.now()
-    training_id = now.strftime('%y%m%d%H%M%S')
-
-    # Observation parameters
-    observation_tree_depth = obs_params.observation_tree_depth
-    observation_radius = obs_params.observation_radius
-    observation_max_path_depth = obs_params.observation_max_path_depth
-
-    # Training parameters
-    eps_start = train_params.eps_start
-    eps_end = train_params.eps_end
-    eps_decay = train_params.eps_decay
-    n_episodes = train_params.n_episodes
-    checkpoint_interval = train_params.checkpoint_interval
-    n_eval_episodes = train_params.n_evaluation_episodes
-    restore_replay_buffer = train_params.restore_replay_buffer
-    save_replay_buffer = train_params.save_replay_buffer
-
-    # Set the seeds
-    random.seed(seed)
-    np.random.seed(seed)
-
-    # Observation builder
-    predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
-    if not train_params.use_fast_tree_observation:
-        print("\nUsing standard TreeObs")
-
-        def check_is_observation_valid(observation):
-            return observation
-
-        def get_normalized_observation(observation, tree_depth: int, observation_radius=0):
-            return normalize_observation(observation, tree_depth, observation_radius)
-
-        tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
-        tree_observation.check_is_observation_valid = check_is_observation_valid
-        tree_observation.get_normalized_observation = get_normalized_observation
-    else:
-        print("\nUsing FastTreeObs")
-
-        def check_is_observation_valid(observation):
-            return True
-
-        def get_normalized_observation(observation, tree_depth: int, observation_radius=0):
-            return observation
-
-        tree_observation = FastTreeObs(max_depth=observation_tree_depth)
-        tree_observation.check_is_observation_valid = check_is_observation_valid
-        tree_observation.get_normalized_observation = get_normalized_observation
-
-    # Setup the environments
-    train_env = create_rail_env(train_env_params, tree_observation)
-    train_env.reset(regenerate_schedule=True, regenerate_rail=True)
-    eval_env = create_rail_env(eval_env_params, tree_observation)
-    eval_env.reset(regenerate_schedule=True, regenerate_rail=True)
-
-    if not train_params.use_fast_tree_observation:
-        # Calculate the state size given the depth of the tree observation and the number of features
-        n_features_per_node = train_env.obs_builder.observation_dim
-        n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
-        state_size = n_features_per_node * n_nodes
-    else:
-        # Calculate the state size given the depth of the tree observation and the number of features
-        state_size = tree_observation.observation_dim
-
-    # Setup renderer
-    if train_params.render:
-        env_renderer = RenderTool(train_env, gl="PGL")
-
-    # The action space of flatland is 5 discrete actions
-    action_size = 5
-
-    action_count = [0] * action_size
-    action_dict = dict()
-    agent_obs = [None] * n_agents
-    agent_prev_obs = [None] * n_agents
-    agent_prev_action = [2] * n_agents
-    update_values = [False] * n_agents
-
-    # Smoothed values used as target for hyperparameter tuning
-    smoothed_eval_normalized_score = -1.0
-    smoothed_eval_completion = 0.0
-
-    scores_window = deque(maxlen=checkpoint_interval)  # todo smooth when rendering instead
-    completion_window = deque(maxlen=checkpoint_interval)
-
-    # Double Dueling DQN policy
-    policy = DDDQNPolicy(state_size, action_size, train_params)
-    # policy = PPOAgent(state_size, action_size, n_agents)
-    # Load existing policy
-    if train_params.load_policy != "":
-        policy.load(train_params.load_policy)
-
-    # Loads existing replay buffer
-    if restore_replay_buffer:
-        try:
-            policy.load_replay_buffer(restore_replay_buffer)
-            policy.test()
-        except RuntimeError as e:
-            print("\n🛑 Couldn't load replay buffer. Were the experiences generated using the same tree depth?")
-            print(e)
-            exit(1)
-
-    print("\n💾 Replay buffer status: {}/{} experiences".format(len(policy.memory.memory), train_params.buffer_size))
-
-    hdd = psutil.disk_usage('/')
-    if save_replay_buffer and (hdd.free / (2 ** 30)) < 500.0:
-        print(
-            "⚠️  Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f}gb left.".format(
-                hdd.free / (2 ** 30)))
-
-    # TensorBoard writer
-    writer = SummaryWriter()
-
-    training_timer = Timer()
-    training_timer.start()
-
-    print(
-        "\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes. Training id '{}'.\n".format(
-            train_env.get_num_agents(),
-            x_dim, y_dim,
-            n_episodes,
-            n_eval_episodes,
-            checkpoint_interval,
-            training_id
-        ))
-
-    for episode_idx in range(n_episodes + 1):
-        step_timer = Timer()
-        reset_timer = Timer()
-        learn_timer = Timer()
-        preproc_timer = Timer()
-        inference_timer = Timer()
-
-        # Reset environment
-        reset_timer.start()
-        train_env_params.n_agents = episode_idx % n_agents + 1
-        train_env = create_rail_env(train_env_params, tree_observation)
-        obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
-        policy.reset()
-
-        policy2 = DeadLockAvoidanceAgent(train_env)
-        policy2.reset()
-
-        reset_timer.end()
-
-        if train_params.render:
-            env_renderer.set_new_rail()
-
-        score = 0
-        nb_steps = 0
-        actions_taken = []
-
-        # Build initial agent-specific observations
-        for agent in train_env.get_agent_handles():
-            if tree_observation.check_is_observation_valid(obs[agent]):
-                agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], observation_tree_depth,
-                                                                               observation_radius=observation_radius)
-                agent_prev_obs[agent] = agent_obs[agent].copy()
-
-        # Max number of steps per episode
-        # This is the official formula used during evaluations
-        # See details in flatland.envs.schedule_generators.sparse_schedule_generator
-        # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
-        max_steps = train_env._max_episode_steps
-
-        # Run episode
-        agent_to_learn = 0
-        if train_env.get_num_agents() > 1:
-            agent_to_learn = np.random.choice(train_env.get_num_agents())
-        for step in range(max_steps - 1):
-            inference_timer.start()
-            policy.start_step()
-            policy2.start_step()
-            for agent in train_env.get_agent_handles():
-                if info['action_required'][agent]:
-                    update_values[agent] = True
-
-                    if agent == agent_to_learn:
-                        action = policy.act(agent_obs[agent], eps=eps_start)
-                    else:
-                        action = policy2.act([agent], eps=eps_start)
-                    action_count[action] += 1
-                    actions_taken.append(action)
-                else:
-                    # An action is not required if the train hasn't joined the railway network,
-                    # if it already reached its target, or if it is currently malfunctioning.
-                    update_values[agent] = False
-                    action = 0
-                action_dict.update({agent: action})
-            policy.end_step()
-            policy2.end_step()
-            inference_timer.end()
-
-            # Environment step
-            step_timer.start()
-            next_obs, all_rewards, done, info = train_env.step(action_dict)
-
-            for agent in train_env.get_agent_handles():
-                act = action_dict.get(agent, RailEnvActions.DO_NOTHING)
-                if agent_obs[agent][26] == 1:
-                    if act == RailEnvActions.STOP_MOVING:
-                        all_rewards[agent] *= 0.01
-                else:
-                    if act == RailEnvActions.MOVE_LEFT:
-                        all_rewards[agent] *= 0.9
-                    else:
-                        if agent_obs[agent][7] == 0 and agent_obs[agent][8] == 0:
-                            if act == RailEnvActions.MOVE_FORWARD:
-                                all_rewards[agent] *= 0.01
-                if done[agent]:
-                    all_rewards[agent] += 100.0
-
-            step_timer.end()
-
-            # Render an episode at some interval
-            if train_params.render and episode_idx % checkpoint_interval == 0:
-                env_renderer.render_env(
-                    show=True,
-                    frames=False,
-                    show_observations=False,
-                    show_predictions=False
-                )
-
-            # Update replay buffer and train agent
-            for agent in train_env.get_agent_handles():
-                if update_values[agent] or done['__all__']:
-                    # Only learn from timesteps where something happened
-                    learn_timer.start()
-                    if agent == agent_to_learn:
-                        policy.step(agent,
-                                    agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
-                                    agent_obs[agent],
-                                    done[agent])
-                    learn_timer.end()
-
-                    agent_prev_obs[agent] = agent_obs[agent].copy()
-                    agent_prev_action[agent] = action_dict[agent]
-
-                # Preprocess the new observations
-                if tree_observation.check_is_observation_valid(next_obs[agent]):
-                    preproc_timer.start()
-                    agent_obs[agent] = tree_observation.get_normalized_observation(next_obs[agent],
-                                                                                   observation_tree_depth,
-                                                                                   observation_radius=observation_radius)
-                    preproc_timer.end()
-
-                score += all_rewards[agent]
-
-            nb_steps = step
-
-            if done['__all__']:
-                break
-
-        # Epsilon decay
-        eps_start = max(eps_end, eps_decay * eps_start)
-
-        # Collect information about training
-        tasks_finished = sum(done[idx] for idx in train_env.get_agent_handles())
-        completion = tasks_finished / max(1, train_env.get_num_agents())
-        normalized_score = score / (max_steps * train_env.get_num_agents())
-        action_probs = action_count / max(1, np.sum(action_count))
-
-        scores_window.append(normalized_score)
-        completion_window.append(completion)
-        smoothed_normalized_score = np.mean(scores_window)
-        smoothed_completion = np.mean(completion_window)
-
-        # Print logs
-        if episode_idx % checkpoint_interval == 0:
-            policy.save('./checkpoints/' + training_id + '-' + str(episode_idx) + '.pth')
-
-            if save_replay_buffer:
-                policy.save_replay_buffer('./replay_buffers/' + training_id + '-' + str(episode_idx) + '.pkl')
-
-            if train_params.render:
-                env_renderer.close_window()
-
-            # reset action count
-            action_count = [0] * action_size
-
-        print(
-            '\r🚂 Episode {}'
-            '\t 🏆 Score: {:7.3f}'
-            ' Avg: {:7.3f}'
-            '\t 💯 Done: {:6.2f}%'
-            ' Avg: {:6.2f}%'
-            '\t 🎲 Epsilon: {:.3f} '
-            '\t 🔀 Action Probs: {}'.format(
-                episode_idx,
-                normalized_score,
-                smoothed_normalized_score,
-                100 * completion,
-                100 * smoothed_completion,
-                eps_start,
-                format_action_prob(action_probs)
-            ), end=" ")
-
-        # Evaluate policy and log results at some interval
-        if episode_idx % checkpoint_interval == 0 and n_eval_episodes > 0:
-            scores, completions, nb_steps_eval = eval_policy(eval_env,
-                                                             tree_observation,
-                                                             policy,
-                                                             train_params,
-                                                             obs_params)
-
-            writer.add_scalar("evaluation/scores_min", np.min(scores), episode_idx)
-            writer.add_scalar("evaluation/scores_max", np.max(scores), episode_idx)
-            writer.add_scalar("evaluation/scores_mean", np.mean(scores), episode_idx)
-            writer.add_scalar("evaluation/scores_std", np.std(scores), episode_idx)
-            writer.add_histogram("evaluation/scores", np.array(scores), episode_idx)
-            writer.add_scalar("evaluation/completions_min", np.min(completions), episode_idx)
-            writer.add_scalar("evaluation/completions_max", np.max(completions), episode_idx)
-            writer.add_scalar("evaluation/completions_mean", np.mean(completions), episode_idx)
-            writer.add_scalar("evaluation/completions_std", np.std(completions), episode_idx)
-            writer.add_histogram("evaluation/completions", np.array(completions), episode_idx)
-            writer.add_scalar("evaluation/nb_steps_min", np.min(nb_steps_eval), episode_idx)
-            writer.add_scalar("evaluation/nb_steps_max", np.max(nb_steps_eval), episode_idx)
-            writer.add_scalar("evaluation/nb_steps_mean", np.mean(nb_steps_eval), episode_idx)
-            writer.add_scalar("evaluation/nb_steps_std", np.std(nb_steps_eval), episode_idx)
-            writer.add_histogram("evaluation/nb_steps", np.array(nb_steps_eval), episode_idx)
-
-            smoothing = 0.9
-            smoothed_eval_normalized_score = smoothed_eval_normalized_score * smoothing + np.mean(scores) * (
-                    1.0 - smoothing)
-            smoothed_eval_completion = smoothed_eval_completion * smoothing + np.mean(completions) * (1.0 - smoothing)
-            writer.add_scalar("evaluation/smoothed_score", smoothed_eval_normalized_score, episode_idx)
-            writer.add_scalar("evaluation/smoothed_completion", smoothed_eval_completion, episode_idx)
-
-        # Save logs to tensorboard
-        writer.add_scalar("training/score", normalized_score, episode_idx)
-        writer.add_scalar("training/smoothed_score", smoothed_normalized_score, episode_idx)
-        writer.add_scalar("training/completion", np.mean(completion), episode_idx)
-        writer.add_scalar("training/smoothed_completion", np.mean(smoothed_completion), episode_idx)
-        writer.add_scalar("training/nb_steps", nb_steps, episode_idx)
-        writer.add_histogram("actions/distribution", np.array(actions_taken), episode_idx)
-        writer.add_scalar("actions/nothing", action_probs[RailEnvActions.DO_NOTHING], episode_idx)
-        writer.add_scalar("actions/left", action_probs[RailEnvActions.MOVE_LEFT], episode_idx)
-        writer.add_scalar("actions/forward", action_probs[RailEnvActions.MOVE_FORWARD], episode_idx)
-        writer.add_scalar("actions/right", action_probs[RailEnvActions.MOVE_RIGHT], episode_idx)
-        writer.add_scalar("actions/stop", action_probs[RailEnvActions.STOP_MOVING], episode_idx)
-        writer.add_scalar("training/epsilon", eps_start, episode_idx)
-        writer.add_scalar("training/buffer_size", len(policy.memory), episode_idx)
-        writer.add_scalar("training/loss", policy.loss, episode_idx)
-        writer.add_scalar("timer/reset", reset_timer.get(), episode_idx)
-        writer.add_scalar("timer/step", step_timer.get(), episode_idx)
-        writer.add_scalar("timer/learn", learn_timer.get(), episode_idx)
-        writer.add_scalar("timer/preproc", preproc_timer.get(), episode_idx)
-        writer.add_scalar("timer/total", training_timer.get_current(), episode_idx)
-
-
-def format_action_prob(action_probs):
-    action_probs = np.round(action_probs, 3)
-    actions = ["↻", "←", "↑", "→", "◼"]
-
-    buffer = ""
-    for action, action_prob in zip(actions, action_probs):
-        buffer += action + " " + "{:.3f}".format(action_prob) + " "
-
-    return buffer
-
-
-def eval_policy(env, tree_observation, policy, train_params, obs_params):
-    n_eval_episodes = train_params.n_evaluation_episodes
-    max_steps = env._max_episode_steps
-    tree_depth = obs_params.observation_tree_depth
-    observation_radius = obs_params.observation_radius
-
-    action_dict = dict()
-    scores = []
-    completions = []
-    nb_steps = []
-
-    for episode_idx in range(n_eval_episodes):
-        agent_obs = [None] * env.get_num_agents()
-        score = 0.0
-
-        obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
-        final_step = 0
-
-        for step in range(max_steps - 1):
-            policy.start_step()
-            for agent in env.get_agent_handles():
-                if tree_observation.check_is_observation_valid(agent_obs[agent]):
-                    agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth,
-                                                                                   observation_radius=observation_radius)
-
-                action = 0
-                if info['action_required'][agent]:
-                    if tree_observation.check_is_observation_valid(agent_obs[agent]):
-                        action = policy.act(agent_obs[agent], eps=0.0)
-                action_dict.update({agent: action})
-            policy.end_step()
-            obs, all_rewards, done, info = env.step(action_dict)
-
-            for agent in env.get_agent_handles():
-                score += all_rewards[agent]
-
-            final_step = step
-
-            if done['__all__']:
-                break
-
-        normalized_score = score / (max_steps * env.get_num_agents())
-        scores.append(normalized_score)
-
-        tasks_finished = sum(done[idx] for idx in env.get_agent_handles())
-        completion = tasks_finished / max(1, env.get_num_agents())
-        completions.append(completion)
-
-        nb_steps.append(final_step)
-
-    print("\t✅ Eval: score {:.3f} done {:.1f}%".format(np.mean(scores), np.mean(completions) * 100.0))
-
-    return scores, completions, nb_steps
-
-
-if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
-                        type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=1, type=int)
-    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
-    parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
-    parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
-    parser.add_argument("--eps_decay", help="exploration decay", default=0.9998, type=float)
-    parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e7), type=int)
-    parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
-    parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
-    parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False,
-                        type=bool)
-    parser.add_argument("--batch_size", help="minibatch size", default=128, type=int)
-    parser.add_argument("--gamma", help="discount factor", default=0.99, type=float)
-    parser.add_argument("--tau", help="soft update of target parameters", default=1e-3, type=float)
-    parser.add_argument("--learning_rate", help="learning rate", default=0.5e-4, type=float)
-    parser.add_argument("--hidden_size", help="hidden size (2 fc layers)", default=128, type=int)
-    parser.add_argument("--update_every", help="how often to update the network", default=8, type=int)
-    parser.add_argument("--use_gpu", help="use GPU if available", default=False, type=bool)
-    parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
-    parser.add_argument("--render", help="render 1 episode in 100", action='store_true')
-    parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
-    parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
-                        action='store_true')
-    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
-
-    training_params = parser.parse_args()
-    env_params = [
-        {
-            # Test_0
-            "n_agents": 5,
-            "x_dim": 25,
-            "y_dim": 25,
-            "n_cities": 2,
-            "max_rails_between_cities": 2,
-            "max_rails_in_city": 3,
-            "malfunction_rate": 1 / 50,
-            "seed": 0
-        },
-        {
-            # Test_1
-            "n_agents": 10,
-            "x_dim": 30,
-            "y_dim": 30,
-            "n_cities": 2,
-            "max_rails_between_cities": 2,
-            "max_rails_in_city": 3,
-            "malfunction_rate": 1 / 100,
-            "seed": 0
-        },
-        {
-            # Test_2
-            "n_agents": 20,
-            "x_dim": 30,
-            "y_dim": 30,
-            "n_cities": 3,
-            "max_rails_between_cities": 2,
-            "max_rails_in_city": 3,
-            "malfunction_rate": 1 / 200,
-            "seed": 0
-        },
-    ]
-
-    obs_params = {
-        "observation_tree_depth": training_params.max_depth,
-        "observation_radius": 10,
-        "observation_max_path_depth": 30
-    }
-
-
-    def check_env_config(id):
-        if id >= len(env_params) or id < 0:
-            print("\n🛑 Invalid environment configuration, only Test_0 to Test_{} are supported.".format(
-                len(env_params) - 1))
-            exit(1)
-
-
-    check_env_config(training_params.training_env_config)
-    check_env_config(training_params.evaluation_env_config)
-
-    training_env_params = env_params[training_params.training_env_config]
-    evaluation_env_params = env_params[training_params.evaluation_env_config]
-
-    # FIXME hard-coded for sweep search
-    # see https://wb-forum.slack.com/archives/CL4V2QE59/p1602931982236600 to implement properly
-    # training_params.use_fast_tree_observation = True
-
-    print("\nTraining parameters:")
-    pprint(vars(training_params))
-    print("\nTraining environment parameters (Test_{}):".format(training_params.training_env_config))
-    pprint(training_env_params)
-    print("\nEvaluation environment parameters (Test_{}):".format(training_params.evaluation_env_config))
-    pprint(evaluation_env_params)
-    print("\nObservation parameters:")
-    pprint(obs_params)
-
-    os.environ["OMP_NUM_THREADS"] = str(training_params.num_threads)
-    train_agent(training_params, Namespace(**training_env_params), Namespace(**evaluation_env_params),
-                Namespace(**obs_params))
+import os
+import random
+import sys
+from argparse import ArgumentParser, Namespace
+from collections import deque
+from datetime import datetime
+from pathlib import Path
+from pprint import pprint
+
+import numpy as np
+import psutil
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.utils.rendertools import RenderTool
+from torch.utils.tensorboard import SummaryWriter
+
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from utils.timer import Timer
+from utils.observation_utils import normalize_observation
+from utils.fast_tree_obs import FastTreeObs
+
+try:
+    import wandb
+
+    wandb.init(sync_tensorboard=True)
+except ImportError:
+    print("Install wandb to log to Weights & Biases")
+
+"""
+This file shows how to train multiple agents using a reinforcement learning approach.
+After training an agent, you can submit it straight away to the NeurIPS 2020 Flatland challenge!
+
+Agent documentation: https://flatland.aicrowd.com/getting-started/rl/multi-agent.html
+Submission documentation: https://flatland.aicrowd.com/getting-started/first-submission.html
+"""
+
+
+def create_rail_env(env_params, tree_observation):
+    n_agents = env_params.n_agents
+    x_dim = env_params.x_dim
+    y_dim = env_params.y_dim
+    n_cities = env_params.n_cities
+    max_rails_between_cities = env_params.max_rails_between_cities
+    max_rails_in_city = env_params.max_rails_in_city
+    seed = env_params.seed
+
+    # Break agents from time to time
+    malfunction_parameters = MalfunctionParameters(
+        malfunction_rate=env_params.malfunction_rate,
+        min_duration=20,
+        max_duration=50
+    )
+
+    return RailEnv(
+        width=x_dim, height=y_dim,
+        rail_generator=sparse_rail_generator(
+            max_num_cities=n_cities,
+            grid_mode=False,
+            max_rails_between_cities=max_rails_between_cities,
+            max_rails_in_city=max_rails_in_city
+        ),
+        schedule_generator=sparse_schedule_generator(),
+        number_of_agents=n_agents,
+        malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters),
+        obs_builder_object=tree_observation,
+        random_seed=seed
+    )
+
+
+def train_agent(train_params, train_env_params, eval_env_params, obs_params):
+    # Environment parameters
+    n_agents = train_env_params.n_agents
+    x_dim = train_env_params.x_dim
+    y_dim = train_env_params.y_dim
+    n_cities = train_env_params.n_cities
+    max_rails_between_cities = train_env_params.max_rails_between_cities
+    max_rails_in_city = train_env_params.max_rails_in_city
+    seed = train_env_params.seed
+
+    # Unique ID for this training
+    now = datetime.now()
+    training_id = now.strftime('%y%m%d%H%M%S')
+
+    # Observation parameters
+    observation_tree_depth = obs_params.observation_tree_depth
+    observation_radius = obs_params.observation_radius
+    observation_max_path_depth = obs_params.observation_max_path_depth
+
+    # Training parameters
+    eps_start = train_params.eps_start
+    eps_end = train_params.eps_end
+    eps_decay = train_params.eps_decay
+    n_episodes = train_params.n_episodes
+    checkpoint_interval = train_params.checkpoint_interval
+    n_eval_episodes = train_params.n_evaluation_episodes
+    restore_replay_buffer = train_params.restore_replay_buffer
+    save_replay_buffer = train_params.save_replay_buffer
+
+    # Set the seeds
+    random.seed(seed)
+    np.random.seed(seed)
+
+    # Observation builder
+    predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
+    if not train_params.use_fast_tree_observation:
+        print("\nUsing standard TreeObs")
+
+        def check_is_observation_valid(observation):
+            return observation
+
+        def get_normalized_observation(observation, tree_depth: int, observation_radius=0):
+            return normalize_observation(observation, tree_depth, observation_radius)
+
+        tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+        tree_observation.check_is_observation_valid = check_is_observation_valid
+        tree_observation.get_normalized_observation = get_normalized_observation
+    else:
+        print("\nUsing FastTreeObs")
+
+        def check_is_observation_valid(observation):
+            return True
+
+        def get_normalized_observation(observation, tree_depth: int, observation_radius=0):
+            return observation
+
+        tree_observation = FastTreeObs(max_depth=observation_tree_depth)
+        tree_observation.check_is_observation_valid = check_is_observation_valid
+        tree_observation.get_normalized_observation = get_normalized_observation
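+    # Both observation builders are given the same two helpers (check_is_observation_valid /
+    # get_normalized_observation) so the training and evaluation loops can treat the stock
+    # TreeObs and FastTreeObs uniformly; for FastTreeObs both helpers are trivial and the
+    # observation is used as-is.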
+
+    # Setup the environments
+    train_env = create_rail_env(train_env_params, tree_observation)
+    train_env.reset(regenerate_schedule=True, regenerate_rail=True)
+    eval_env = create_rail_env(eval_env_params, tree_observation)
+    eval_env.reset(regenerate_schedule=True, regenerate_rail=True)
+
+    if not train_params.use_fast_tree_observation:
+        # Calculate the state size given the depth of the tree observation and the number of features
+        n_features_per_node = train_env.obs_builder.observation_dim
+        n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
+        state_size = n_features_per_node * n_nodes
+    else:
+        # Calculate the state size given the depth of the tree observation and the number of features
+        state_size = tree_observation.observation_dim
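+    # Quick sanity check of the arithmetic above: the tree observation branches four ways per node,
+    # so for observation_tree_depth = 2 there are 1 + 4 + 16 = 21 nodes and
+    # state_size = 21 * n_features_per_node. FastTreeObs instead reports its own observation_dim.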
+
+    # Setup renderer
+    if train_params.render:
+        env_renderer = RenderTool(train_env, gl="PGL")
+
+    # The action space of flatland is 5 discrete actions
+    action_size = 5
+
+    action_count = [0] * action_size
+    action_dict = dict()
+    agent_obs = [None] * n_agents
+    agent_prev_obs = [None] * n_agents
+    agent_prev_action = [2] * n_agents
+    update_values = [False] * n_agents
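+    # Per-agent bookkeeping: agent_prev_action defaults to 2 (RailEnvActions.MOVE_FORWARD) and
+    # update_values marks agents whose transition should be pushed to the replay buffer this step.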
+
+    # Smoothed values used as target for hyperparameter tuning
+    smoothed_eval_normalized_score = -1.0
+    smoothed_eval_completion = 0.0
+
+    scores_window = deque(maxlen=checkpoint_interval)  # todo smooth when rendering instead
+    completion_window = deque(maxlen=checkpoint_interval)
+
+    # Double Dueling DQN policy
+    policy = DDDQNPolicy(state_size, action_size, train_params)
+    # policy = PPOAgent(state_size, action_size, n_agents)
+    # Load existing policy
+    if train_params.load_policy != "":
+        policy.load(train_params.load_policy)
+
+    # Loads existing replay buffer
+    if restore_replay_buffer:
+        try:
+            policy.load_replay_buffer(restore_replay_buffer)
+            policy.test()
+        except RuntimeError as e:
+            print("\n🛑 Could't load replay buffer, were the experiences generated using the same tree depth?")
+            print(e)
+            exit(1)
+
+    print("\n💾 Replay buffer status: {}/{} experiences".format(len(policy.memory.memory), train_params.buffer_size))
+
+    hdd = psutil.disk_usage('/')
+    if save_replay_buffer and (hdd.free / (2 ** 30)) < 500.0:
+        print(
+            "⚠️  Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f}gb left.".format(
+                hdd.free / (2 ** 30)))
+
+    # TensorBoard writer
+    writer = SummaryWriter()
+
+    training_timer = Timer()
+    training_timer.start()
+
+    print(
+        "\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes. Training id '{}'.\n".format(
+            train_env.get_num_agents(),
+            x_dim, y_dim,
+            n_episodes,
+            n_eval_episodes,
+            checkpoint_interval,
+            training_id
+        ))
+
+    for episode_idx in range(n_episodes + 1):
+        step_timer = Timer()
+        reset_timer = Timer()
+        learn_timer = Timer()
+        preproc_timer = Timer()
+        inference_timer = Timer()
+
+        # Reset environment
+        reset_timer.start()
+        train_env_params.n_agents = episode_idx % n_agents + 1
+        train_env = create_rail_env(train_env_params, tree_observation)
+        obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
+        policy.reset()
+
+        policy2 = DeadLockAvoidanceAgent(train_env)
+        policy2.reset()
+
+        reset_timer.end()
+
+        if train_params.render:
+            env_renderer.set_new_rail()
+
+        score = 0
+        nb_steps = 0
+        actions_taken = []
+
+        # Build initial agent-specific observations
+        for agent in train_env.get_agent_handles():
+            if tree_observation.check_is_observation_valid(obs[agent]):
+                agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], observation_tree_depth,
+                                                                               observation_radius=observation_radius)
+                agent_prev_obs[agent] = agent_obs[agent].copy()
+
+        # Max number of steps per episode
+        # This is the official formula used during evaluations
+        # See details in flatland.envs.schedule_generators.sparse_schedule_generator
+        # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
+        max_steps = train_env._max_episode_steps
+
+        # Run episode
+        agent_to_learn = 0
+        if train_env.get_num_agents() > 1:
+            agent_to_learn = np.random.choice(train_env.get_num_agents())
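+        # Only the randomly chosen agent_to_learn is controlled by the trainable DDDQN policy in
+        # this episode; every other train is driven by the heuristic DeadLockAvoidanceAgent (policy2).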
+        for step in range(max_steps - 1):
+            inference_timer.start()
+            policy.start_step()
+            policy2.start_step()
+            for agent in train_env.get_agent_handles():
+                if info['action_required'][agent]:
+                    update_values[agent] = True
+
+                    if agent == agent_to_learn:
+                        action = policy.act(agent_obs[agent], eps=eps_start)
+                    else:
+                        action = policy2.act([agent], eps=eps_start)
+                    action_count[action] += 1
+                    actions_taken.append(action)
+                else:
+                    # An action is not required if the train hasn't joined the railway network,
+                    # if it has already reached its target, or if it is currently malfunctioning.
+                    update_values[agent] = False
+                    action = 0
+                action_dict.update({agent: action})
+            policy.end_step()
+            policy2.end_step()
+            inference_timer.end()
+
+            # Environment step
+            step_timer.start()
+            next_obs, all_rewards, done, info = train_env.step(action_dict)
+
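+            # Hand-tuned reward shaping on top of the raw environment reward: per-step rewards are
+            # rescaled depending on the chosen action and on observation feature flags (the exact
+            # meaning of indices 7, 8 and 26 is assumed to come from FastTreeObs), and a large
+            # terminal bonus is added once an agent is done.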
+            for agent in train_env.get_agent_handles():
+                act = action_dict.get(agent, RailEnvActions.DO_NOTHING)
+                if agent_obs[agent][26] == 1:
+                    if act == RailEnvActions.STOP_MOVING:
+                        all_rewards[agent] *= 0.01
+                else:
+                    if act == RailEnvActions.MOVE_LEFT:
+                        all_rewards[agent] *= 0.9
+                    else:
+                        if agent_obs[agent][7] == 0 and agent_obs[agent][8] == 0:
+                            if act == RailEnvActions.MOVE_FORWARD:
+                                all_rewards[agent] *= 0.01
+                if done[agent]:
+                    all_rewards[agent] += 100.0
+
+            step_timer.end()
+
+            # Render an episode at some interval
+            if train_params.render and episode_idx % checkpoint_interval == 0:
+                env_renderer.render_env(
+                    show=True,
+                    frames=False,
+                    show_observations=False,
+                    show_predictions=False
+                )
+
+            # Update replay buffer and train agent
+            for agent in train_env.get_agent_handles():
+                if update_values[agent] or done['__all__']:
+                    # Only learn from timesteps where something happened
+                    learn_timer.start()
+                    if agent == agent_to_learn:
+                        policy.step(agent,
+                                    agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                                    agent_obs[agent],
+                                    done[agent])
+                    learn_timer.end()
+
+                    agent_prev_obs[agent] = agent_obs[agent].copy()
+                    agent_prev_action[agent] = action_dict[agent]
+
+                # Preprocess the new observations
+                if tree_observation.check_is_observation_valid(next_obs[agent]):
+                    preproc_timer.start()
+                    agent_obs[agent] = tree_observation.get_normalized_observation(next_obs[agent],
+                                                                                   observation_tree_depth,
+                                                                                   observation_radius=observation_radius)
+                    preproc_timer.end()
+
+                score += all_rewards[agent]
+
+            nb_steps = step
+
+            if done['__all__']:
+                break
+
+        # Epsilon decay
+        eps_start = max(eps_end, eps_decay * eps_start)
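+        # Exponential epsilon decay, e.g. with the defaults eps_start=0.1 and eps_decay=0.9998,
+        # epsilon drops to roughly 0.1 * 0.9998**5000 ≈ 0.037 after 5000 episodes, floored at eps_end.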
+
+        # Collect information about training
+        tasks_finished = sum(done[idx] for idx in train_env.get_agent_handles())
+        completion = tasks_finished / max(1, train_env.get_num_agents())
+        normalized_score = score / (max_steps * train_env.get_num_agents())
+        action_probs = action_count / max(1, np.sum(action_count))
+
+        scores_window.append(normalized_score)
+        completion_window.append(completion)
+        smoothed_normalized_score = np.mean(scores_window)
+        smoothed_completion = np.mean(completion_window)
+
+        # Print logs
+        if episode_idx % checkpoint_interval == 0:
+            policy.save('./checkpoints/' + training_id + '-' + str(episode_idx) + '.pth')
+
+            if save_replay_buffer:
+                policy.save_replay_buffer('./replay_buffers/' + training_id + '-' + str(episode_idx) + '.pkl')
+
+            if train_params.render:
+                env_renderer.close_window()
+
+            # reset action count
+            action_count = [0] * action_size
+
+        print(
+            '\r🚂 Episode {}'
+            '\t 🏆 Score: {:7.3f}'
+            ' Avg: {:7.3f}'
+            '\t 💯 Done: {:6.2f}%'
+            ' Avg: {:6.2f}%'
+            '\t 🎲 Epsilon: {:.3f} '
+            '\t 🔀 Action Probs: {}'.format(
+                episode_idx,
+                normalized_score,
+                smoothed_normalized_score,
+                100 * completion,
+                100 * smoothed_completion,
+                eps_start,
+                format_action_prob(action_probs)
+            ), end=" ")
+
+        # Evaluate policy and log results at some interval
+        if episode_idx % checkpoint_interval == 0 and n_eval_episodes > 0:
+            scores, completions, nb_steps_eval = eval_policy(eval_env,
+                                                             tree_observation,
+                                                             policy,
+                                                             train_params,
+                                                             obs_params)
+
+            writer.add_scalar("evaluation/scores_min", np.min(scores), episode_idx)
+            writer.add_scalar("evaluation/scores_max", np.max(scores), episode_idx)
+            writer.add_scalar("evaluation/scores_mean", np.mean(scores), episode_idx)
+            writer.add_scalar("evaluation/scores_std", np.std(scores), episode_idx)
+            writer.add_histogram("evaluation/scores", np.array(scores), episode_idx)
+            writer.add_scalar("evaluation/completions_min", np.min(completions), episode_idx)
+            writer.add_scalar("evaluation/completions_max", np.max(completions), episode_idx)
+            writer.add_scalar("evaluation/completions_mean", np.mean(completions), episode_idx)
+            writer.add_scalar("evaluation/completions_std", np.std(completions), episode_idx)
+            writer.add_histogram("evaluation/completions", np.array(completions), episode_idx)
+            writer.add_scalar("evaluation/nb_steps_min", np.min(nb_steps_eval), episode_idx)
+            writer.add_scalar("evaluation/nb_steps_max", np.max(nb_steps_eval), episode_idx)
+            writer.add_scalar("evaluation/nb_steps_mean", np.mean(nb_steps_eval), episode_idx)
+            writer.add_scalar("evaluation/nb_steps_std", np.std(nb_steps_eval), episode_idx)
+            writer.add_histogram("evaluation/nb_steps", np.array(nb_steps_eval), episode_idx)
+
+            smoothing = 0.9
+            smoothed_eval_normalized_score = smoothed_eval_normalized_score * smoothing + np.mean(scores) * (
+                    1.0 - smoothing)
+            smoothed_eval_completion = smoothed_eval_completion * smoothing + np.mean(completions) * (1.0 - smoothing)
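+            # Exponential moving average of the evaluation metrics: new = 0.9 * old + 0.1 * current.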
+            writer.add_scalar("evaluation/smoothed_score", smoothed_eval_normalized_score, episode_idx)
+            writer.add_scalar("evaluation/smoothed_completion", smoothed_eval_completion, episode_idx)
+
+        # Save logs to tensorboard
+        writer.add_scalar("training/score", normalized_score, episode_idx)
+        writer.add_scalar("training/smoothed_score", smoothed_normalized_score, episode_idx)
+        writer.add_scalar("training/completion", np.mean(completion), episode_idx)
+        writer.add_scalar("training/smoothed_completion", np.mean(smoothed_completion), episode_idx)
+        writer.add_scalar("training/nb_steps", nb_steps, episode_idx)
+        writer.add_histogram("actions/distribution", np.array(actions_taken), episode_idx)
+        writer.add_scalar("actions/nothing", action_probs[RailEnvActions.DO_NOTHING], episode_idx)
+        writer.add_scalar("actions/left", action_probs[RailEnvActions.MOVE_LEFT], episode_idx)
+        writer.add_scalar("actions/forward", action_probs[RailEnvActions.MOVE_FORWARD], episode_idx)
+        writer.add_scalar("actions/right", action_probs[RailEnvActions.MOVE_RIGHT], episode_idx)
+        writer.add_scalar("actions/stop", action_probs[RailEnvActions.STOP_MOVING], episode_idx)
+        writer.add_scalar("training/epsilon", eps_start, episode_idx)
+        writer.add_scalar("training/buffer_size", len(policy.memory), episode_idx)
+        writer.add_scalar("training/loss", policy.loss, episode_idx)
+        writer.add_scalar("timer/reset", reset_timer.get(), episode_idx)
+        writer.add_scalar("timer/step", step_timer.get(), episode_idx)
+        writer.add_scalar("timer/learn", learn_timer.get(), episode_idx)
+        writer.add_scalar("timer/preproc", preproc_timer.get(), episode_idx)
+        writer.add_scalar("timer/total", training_timer.get_current(), episode_idx)
+
+
+def format_action_prob(action_probs):
+    action_probs = np.round(action_probs, 3)
+    actions = ["↻", "←", "↑", "→", "◼"]
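+    # Glyph order mirrors RailEnvActions: DO_NOTHING, MOVE_LEFT, MOVE_FORWARD, MOVE_RIGHT, STOP_MOVING.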
+
+    buffer = ""
+    for action, action_prob in zip(actions, action_probs):
+        buffer += action + " " + "{:.3f}".format(action_prob) + " "
+
+    return buffer
+
+
+def eval_policy(env, tree_observation, policy, train_params, obs_params):
+    n_eval_episodes = train_params.n_evaluation_episodes
+    max_steps = env._max_episode_steps
+    tree_depth = obs_params.observation_tree_depth
+    observation_radius = obs_params.observation_radius
+
+    action_dict = dict()
+    scores = []
+    completions = []
+    nb_steps = []
+
+    for episode_idx in range(n_eval_episodes):
+        agent_obs = [None] * env.get_num_agents()
+        score = 0.0
+
+        obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
+        final_step = 0
+
+        for step in range(max_steps - 1):
+            policy.start_step()
+            for agent in env.get_agent_handles():
+                if tree_observation.check_is_observation_valid(obs[agent]):
+                    agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth,
+                                                                                   observation_radius=observation_radius)
+
+                action = 0
+                if info['action_required'][agent]:
+                    if agent_obs[agent] is not None:
+                        action = policy.act(agent_obs[agent], eps=0.0)
+                action_dict.update({agent: action})
+            policy.end_step()
+            obs, all_rewards, done, info = env.step(action_dict)
+
+            for agent in env.get_agent_handles():
+                score += all_rewards[agent]
+
+            final_step = step
+
+            if done['__all__']:
+                break
+
+        normalized_score = score / (max_steps * env.get_num_agents())
+        scores.append(normalized_score)
+
+        tasks_finished = sum(done[idx] for idx in env.get_agent_handles())
+        completion = tasks_finished / max(1, env.get_num_agents())
+        completions.append(completion)
+
+        nb_steps.append(final_step)
+
+    print("\t✅ Eval: score {:.3f} done {:.1f}%".format(np.mean(scores), np.mean(completions) * 100.0))
+
+    return scores, completions, nb_steps
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5400, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1, type=int)
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
+                        type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=1, type=int)
+    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
+    parser.add_argument("--eps_start", help="max exploration", default=0.1, type=float)
+    parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
+    parser.add_argument("--eps_decay", help="exploration decay", default=0.9998, type=float)
+    parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e7), type=int)
+    parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
+    parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
+    parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False,
+                        type=bool)
+    parser.add_argument("--batch_size", help="minibatch size", default=128, type=int)
+    parser.add_argument("--gamma", help="discount factor", default=0.99, type=float)
+    parser.add_argument("--tau", help="soft update of target parameters", default=1e-3, type=float)
+    parser.add_argument("--learning_rate", help="learning rate", default=0.5e-4, type=float)
+    parser.add_argument("--hidden_size", help="hidden size (2 fc layers)", default=128, type=int)
+    parser.add_argument("--update_every", help="how often to update the network", default=8, type=int)
+    parser.add_argument("--use_gpu", help="use GPU if available", default=False, type=bool)
+    parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
+    parser.add_argument("--render", help="render 1 episode in 100", action='store_true')
+    parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
+    parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
+                        action='store_true')
+    parser.add_argument("--max_depth", help="max depth", default=1, type=int)
+
+    training_params = parser.parse_args()
+    env_params = [
+        {
+            # Test_0
+            "n_agents": 5,
+            "x_dim": 25,
+            "y_dim": 25,
+            "n_cities": 2,
+            "max_rails_between_cities": 2,
+            "max_rails_in_city": 3,
+            "malfunction_rate": 1 / 50,
+            "seed": 0
+        },
+        {
+            # Test_1
+            "n_agents": 10,
+            "x_dim": 30,
+            "y_dim": 30,
+            "n_cities": 2,
+            "max_rails_between_cities": 2,
+            "max_rails_in_city": 3,
+            "malfunction_rate": 1 / 100,
+            "seed": 0
+        },
+        {
+            # Test_2
+            "n_agents": 20,
+            "x_dim": 30,
+            "y_dim": 30,
+            "n_cities": 3,
+            "max_rails_between_cities": 2,
+            "max_rails_in_city": 3,
+            "malfunction_rate": 1 / 200,
+            "seed": 0
+        },
+    ]
+
+    obs_params = {
+        "observation_tree_depth": training_params.max_depth,
+        "observation_radius": 10,
+        "observation_max_path_depth": 30
+    }
+
+
+    def check_env_config(config_id):
+        if config_id >= len(env_params) or config_id < 0:
+            print("\n🛑 Invalid environment configuration, only Test_0 to Test_{} are supported.".format(
+                len(env_params) - 1))
+            sys.exit(1)
+
+
+    check_env_config(training_params.training_env_config)
+    check_env_config(training_params.evaluation_env_config)
+
+    training_env_params = env_params[training_params.training_env_config]
+    evaluation_env_params = env_params[training_params.evaluation_env_config]
+
+    # FIXME hard-coded for sweep search
+    # see https://wb-forum.slack.com/archives/CL4V2QE59/p1602931982236600 to implement properly
+    # training_params.use_fast_tree_observation = True
+
+    print("\nTraining parameters:")
+    pprint(vars(training_params))
+    print("\nTraining environment parameters (Test_{}):".format(training_params.training_env_config))
+    pprint(training_env_params)
+    print("\nEvaluation environment parameters (Test_{}):".format(training_params.evaluation_env_config))
+    pprint(evaluation_env_params)
+    print("\nObservation parameters:")
+    pprint(obs_params)
+
+    os.environ["OMP_NUM_THREADS"] = str(training_params.num_threads)
+    train_agent(training_params, Namespace(**training_env_params), Namespace(**evaluation_env_params),
+                Namespace(**obs_params))
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index b605aa3ddaf43ad1a496e44a3fac367be4bd8234..c7621a62e84d17081018975e100f3b1f64f7ab66 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -1,27 +1,27 @@
-class Policy:
-    def step(self, handle, state, action, reward, next_state, done):
-        raise NotImplementedError
-
-    def act(self, state, eps=0.):
-        raise NotImplementedError
-
-    def save(self, filename):
-        raise NotImplementedError
-
-    def load(self, filename):
-        raise NotImplementedError
-
-    def start_step(self):
-        pass
-
-    def end_step(self):
-        pass
-
-    def load_replay_buffer(self, filename):
-        pass
-
-    def test(self):
-        pass
-
-    def reset(self):
+class Policy:
+    def step(self, handle, state, action, reward, next_state, done):
+        raise NotImplementedError
+
+    def act(self, state, eps=0.):
+        raise NotImplementedError
+
+    def save(self, filename):
+        raise NotImplementedError
+
+    def load(self, filename):
+        raise NotImplementedError
+
+    def start_step(self):
+        pass
+
+    def end_step(self):
+        pass
+
+    def load_replay_buffer(self, filename):
+        pass
+
+    def test(self):
+        pass
+
+    def reset(self):
         pass
\ No newline at end of file
diff --git a/reinforcement_learning/ppo/model.py b/reinforcement_learning/ppo/model.py
index 03d72c9ca4059d45a139305b69ee95f709977e07..51b86ff16691c03f6a754405352bb4cf48e4b914 100644
--- a/reinforcement_learning/ppo/model.py
+++ b/reinforcement_learning/ppo/model.py
@@ -1,20 +1,20 @@
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class PolicyNetwork(nn.Module):
-    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32):
-        super().__init__()
-        self.fc1 = nn.Linear(state_size, hidsize1)
-        self.fc2 = nn.Linear(hidsize1, hidsize2)
-        # self.fc3 = nn.Linear(hidsize2, hidsize3)
-        self.output = nn.Linear(hidsize2, action_size)
-        self.softmax = nn.Softmax(dim=1)
-        self.bn0 = nn.BatchNorm1d(state_size, affine=False)
-
-    def forward(self, inputs):
-        x = self.bn0(inputs.float())
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        # x = F.relu(self.fc3(x))
-        return self.softmax(self.output(x))
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PolicyNetwork(nn.Module):
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32):
+        super().__init__()
+        self.fc1 = nn.Linear(state_size, hidsize1)
+        self.fc2 = nn.Linear(hidsize1, hidsize2)
+        # self.fc3 = nn.Linear(hidsize2, hidsize3)
+        self.output = nn.Linear(hidsize2, action_size)
+        self.softmax = nn.Softmax(dim=1)
+        self.bn0 = nn.BatchNorm1d(state_size, affine=False)
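+        # Inputs are batch-normalized before the MLP; note that BatchNorm1d needs more than one
+        # sample in training mode, so single-state inference should run with the module in eval() mode.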
+
+    def forward(self, inputs):
+        x = self.bn0(inputs.float())
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        # x = F.relu(self.fc3(x))
+        return self.softmax(self.output(x))
diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
index a7431f85201def6f189ccdc6101a89428b598e47..49fe7e6f6c02a20e4ea3d7c6e7a9d3e33cff9742 100644
--- a/reinforcement_learning/ppo/ppo_agent.py
+++ b/reinforcement_learning/ppo/ppo_agent.py
@@ -1,131 +1,131 @@
-import os
-
-import numpy as np
-import torch
-from torch.distributions.categorical import Categorical
-
-from reinforcement_learning.policy import Policy
-from reinforcement_learning.ppo.model import PolicyNetwork
-from reinforcement_learning.ppo.replay_memory import Episode, ReplayBuffer
-
-BUFFER_SIZE = 128_000
-BATCH_SIZE = 8192
-GAMMA = 0.95
-LR = 0.5e-4
-CLIP_FACTOR = .005
-UPDATE_EVERY = 30
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-print("device:", device)
-
-
-class PPOAgent(Policy):
-    def __init__(self, state_size, action_size, num_agents):
-        self.action_size = action_size
-        self.state_size = state_size
-        self.num_agents = num_agents
-        self.policy = PolicyNetwork(state_size, action_size).to(device)
-        self.old_policy = PolicyNetwork(state_size, action_size).to(device)
-        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR)
-        self.episodes = [Episode() for _ in range(num_agents)]
-        self.memory = ReplayBuffer(BUFFER_SIZE)
-        self.t_step = 0
-        self.loss = 0
-
-    def reset(self):
-        self.finished = [False] * len(self.episodes)
-        self.tot_reward = [0] * self.num_agents
-
-    # Decide on an action to take in the environment
-
-    def act(self, state, eps=None):
-        self.policy.eval()
-        with torch.no_grad():
-            output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
-            ret = Categorical(output).sample().item()
-            return ret
-
-    # Record the results of the agent's action and update the model
-    def step(self, handle, state, action, reward, next_state, done):
-        if not self.finished[handle]:
-            # Push experience into Episode memory
-            self.tot_reward[handle] += reward
-            if done == 1:
-                reward = 1  # self.tot_reward[handle]
-            else:
-                reward = 0
-
-            self.episodes[handle].push(state, action, reward, next_state, done)
-
-            # When we finish the episode, discount rewards and push the experience into replay memory
-            if done:
-                self.episodes[handle].discount_rewards(GAMMA)
-                self.memory.push_episode(self.episodes[handle])
-                self.episodes[handle].reset()
-                self.finished[handle] = True
-
-        # Perform a gradient update every UPDATE_EVERY time steps
-        self.t_step = (self.t_step + 1) % UPDATE_EVERY
-        if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4:
-            self._learn(*self.memory.sample(BATCH_SIZE, device))
-
-    def _clip_gradient(self, model, clip):
-
-        for p in model.parameters():
-            p.grad.data.clamp_(-clip, clip)
-        return
-
-        """Computes a gradient clipping coefficient based on gradient norm."""
-        totalnorm = 0
-        for p in model.parameters():
-            if p.grad is not None:
-                modulenorm = p.grad.data.norm()
-                totalnorm += modulenorm ** 2
-        totalnorm = np.sqrt(totalnorm)
-        coeff = min(1, clip / (totalnorm + 1e-6))
-
-        for p in model.parameters():
-            if p.grad is not None:
-                p.grad.mul_(coeff)
-
-    def _learn(self, states, actions, rewards, next_state, done):
-        self.policy.train()
-
-        responsible_outputs = torch.gather(self.policy(states), 1, actions)
-        old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach()
-
-        # rewards = rewards - rewards.mean()
-        ratio = responsible_outputs / (old_responsible_outputs + 1e-5)
-        clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR)
-        loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean()
-        self.loss = loss
-
-        # Compute loss and perform a gradient step
-        self.old_policy.load_state_dict(self.policy.state_dict())
-        self.optimizer.zero_grad()
-        loss.backward()
-        # self._clip_gradient(self.policy, 1.0)
-        self.optimizer.step()
-
-    # Checkpointing methods
-    def save(self, filename):
-        # print("Saving model from checkpoint:", filename)
-        torch.save(self.policy.state_dict(), filename + ".policy")
-        torch.save(self.optimizer.state_dict(), filename + ".optimizer")
-
-    def load(self, filename):
-        print("load policy from file", filename)
-        if os.path.exists(filename + ".policy"):
-            print(' >> ', filename + ".policy")
-            try:
-                self.policy.load_state_dict(torch.load(filename + ".policy", map_location=device))
-            except:
-                print(" >> failed!")
-                pass
-        if os.path.exists(filename + ".optimizer"):
-            print(' >> ', filename + ".optimizer")
-            try:
-                self.optimizer.load_state_dict(torch.load(filename + ".optimizer", map_location=device))
-            except:
-                print(" >> failed!")
-                pass
+import os
+
+import numpy as np
+import torch
+from torch.distributions.categorical import Categorical
+
+from reinforcement_learning.policy import Policy
+from reinforcement_learning.ppo.model import PolicyNetwork
+from reinforcement_learning.ppo.replay_memory import Episode, ReplayBuffer
+
+BUFFER_SIZE = 128_000
+BATCH_SIZE = 8192
+GAMMA = 0.95
+LR = 0.5e-4
+CLIP_FACTOR = .005
+UPDATE_EVERY = 30
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print("device:", device)
+
+
+class PPOAgent(Policy):
+    def __init__(self, state_size, action_size, num_agents):
+        self.action_size = action_size
+        self.state_size = state_size
+        self.num_agents = num_agents
+        self.policy = PolicyNetwork(state_size, action_size).to(device)
+        self.old_policy = PolicyNetwork(state_size, action_size).to(device)
+        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR)
+        self.episodes = [Episode() for _ in range(num_agents)]
+        self.memory = ReplayBuffer(BUFFER_SIZE)
+        self.t_step = 0
+        self.loss = 0
+
+    def reset(self):
+        self.finished = [False] * len(self.episodes)
+        self.tot_reward = [0] * self.num_agents
+
+    # Decide on an action to take in the environment
+
+    def act(self, state, eps=None):
+        self.policy.eval()
+        with torch.no_grad():
+            output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
+            ret = Categorical(output).sample().item()
+            return ret
+
+    # Record the results of the agent's action and update the model
+    def step(self, handle, state, action, reward, next_state, done):
+        if not self.finished[handle]:
+            # Push experience into Episode memory
+            self.tot_reward[handle] += reward
+            if done == 1:
+                reward = 1  # self.tot_reward[handle]
+            else:
+                reward = 0
+
+            self.episodes[handle].push(state, action, reward, next_state, done)
+
+            # When we finish the episode, discount rewards and push the experience into replay memory
+            if done:
+                self.episodes[handle].discount_rewards(GAMMA)
+                self.memory.push_episode(self.episodes[handle])
+                self.episodes[handle].reset()
+                self.finished[handle] = True
+
+        # Perform a gradient update every UPDATE_EVERY time steps
+        self.t_step = (self.t_step + 1) % UPDATE_EVERY
+        if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4:
+            self._learn(*self.memory.sample(BATCH_SIZE, device))
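+        # A gradient update is triggered every UPDATE_EVERY step() calls, once the replay buffer
+        # holds more than 4 * BATCH_SIZE transitions.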
+
+    def _clip_gradient(self, model, clip):
+
+        for p in model.parameters():
+            p.grad.data.clamp_(-clip, clip)
+        return
+
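+        # NOTE: the norm-based clipping variant below is unreachable because of the early return
+        # above; it is kept in the source as an alternative implementation.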
+        """Computes a gradient clipping coefficient based on gradient norm."""
+        totalnorm = 0
+        for p in model.parameters():
+            if p.grad is not None:
+                modulenorm = p.grad.data.norm()
+                totalnorm += modulenorm ** 2
+        totalnorm = np.sqrt(totalnorm)
+        coeff = min(1, clip / (totalnorm + 1e-6))
+
+        for p in model.parameters():
+            if p.grad is not None:
+                p.grad.mul_(coeff)
+
+    def _learn(self, states, actions, rewards, next_state, done):
+        self.policy.train()
+
+        responsible_outputs = torch.gather(self.policy(states), 1, actions)
+        old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach()
+
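+        # Clipped PPO-style surrogate: loss = -E[min(r * R, clip(r, 1 - CLIP_FACTOR, 1 + CLIP_FACTOR) * R)],
+        # where r is the probability ratio between the current and old policy and R is the discounted
+        # return stored in the buffer (used here in place of a separate advantage estimate).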
+        # rewards = rewards - rewards.mean()
+        ratio = responsible_outputs / (old_responsible_outputs + 1e-5)
+        clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR)
+        loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean()
+        self.loss = loss
+
+        # Compute loss and perform a gradient step
+        self.old_policy.load_state_dict(self.policy.state_dict())
+        self.optimizer.zero_grad()
+        loss.backward()
+        # self._clip_gradient(self.policy, 1.0)
+        self.optimizer.step()
+
+    # Checkpointing methods
+    def save(self, filename):
+        # print("Saving model from checkpoint:", filename)
+        torch.save(self.policy.state_dict(), filename + ".policy")
+        torch.save(self.optimizer.state_dict(), filename + ".optimizer")
+
+    def load(self, filename):
+        print("load policy from file", filename)
+        if os.path.exists(filename + ".policy"):
+            print(' >> ', filename + ".policy")
+            try:
+                self.policy.load_state_dict(torch.load(filename + ".policy", map_location=device))
+            except Exception:
+                print(" >> failed!")
+        if os.path.exists(filename + ".optimizer"):
+            print(' >> ', filename + ".optimizer")
+            try:
+                self.optimizer.load_state_dict(torch.load(filename + ".optimizer", map_location=device))
+            except Exception:
+                print(" >> failed!")
diff --git a/reinforcement_learning/ppo/replay_memory.py b/reinforcement_learning/ppo/replay_memory.py
index 61a1b81bf3a338e031f3d0441058268206426ce7..3e6619b40169597d7a4b379f4ce2c9ddccd4cd9b 100644
--- a/reinforcement_learning/ppo/replay_memory.py
+++ b/reinforcement_learning/ppo/replay_memory.py
@@ -1,53 +1,53 @@
-import torch
-import random
-import numpy as np
-from collections import namedtuple, deque, Iterable
-
-
-Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done"))
-
-
-class Episode:
-    memory = []
-
-    def reset(self):
-        self.memory = []
-
-    def push(self, *args):
-        self.memory.append(tuple(args))
-
-    def discount_rewards(self, gamma):
-        running_add = 0.
-        for i, (state, action, reward, *rest) in list(enumerate(self.memory))[::-1]:
-            running_add = running_add * gamma + reward
-            self.memory[i] = (state, action, running_add, *rest)
-
-
-class ReplayBuffer:
-    def __init__(self, buffer_size):
-        self.memory = deque(maxlen=buffer_size)
-
-    def push(self, state, action, reward, next_state, done):
-        self.memory.append(Transition(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done))
-
-    def push_episode(self, episode):
-        for step in episode.memory:
-            self.push(*step)
-
-    def sample(self, batch_size, device):
-        experiences = random.sample(self.memory, k=batch_size)
-
-        states      = torch.from_numpy(self.stack([e.state      for e in experiences])).float().to(device)
-        actions     = torch.from_numpy(self.stack([e.action     for e in experiences])).long().to(device)
-        rewards     = torch.from_numpy(self.stack([e.reward     for e in experiences])).float().to(device)
-        next_states = torch.from_numpy(self.stack([e.next_state for e in experiences])).float().to(device)
-        dones       = torch.from_numpy(self.stack([e.done       for e in experiences]).astype(np.uint8)).float().to(device)
-
-        return states, actions, rewards, next_states, dones
-
-    def stack(self, states):
-        sub_dims = states[0].shape[1:] if isinstance(states[0], Iterable) else [1]
-        return np.reshape(np.array(states), (len(states), *sub_dims))
-
-    def __len__(self):
-        return len(self.memory)
+import torch
+import random
+import numpy as np
+from collections import namedtuple, deque
+from collections.abc import Iterable
+
+
+Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done"))
+
+
+class Episode:
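+    # Note: `memory` starts out as a class attribute; reset() gives each instance its own list.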
+    memory = []
+
+    def reset(self):
+        self.memory = []
+
+    def push(self, *args):
+        self.memory.append(tuple(args))
+
+    def discount_rewards(self, gamma):
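+        # Walk the episode backwards and replace each step's reward with its discounted return.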
+        running_add = 0.
+        for i, (state, action, reward, *rest) in list(enumerate(self.memory))[::-1]:
+            running_add = running_add * gamma + reward
+            self.memory[i] = (state, action, running_add, *rest)
+
+
+class ReplayBuffer:
+    def __init__(self, buffer_size):
+        self.memory = deque(maxlen=buffer_size)
+
+    def push(self, state, action, reward, next_state, done):
+        self.memory.append(Transition(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done))
+
+    def push_episode(self, episode):
+        for step in episode.memory:
+            self.push(*step)
+
+    def sample(self, batch_size, device):
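+        # Draw a random minibatch and stack each field into a batch tensor on the target device.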
+        experiences = random.sample(self.memory, k=batch_size)
+
+        states      = torch.from_numpy(self.stack([e.state      for e in experiences])).float().to(device)
+        actions     = torch.from_numpy(self.stack([e.action     for e in experiences])).long().to(device)
+        rewards     = torch.from_numpy(self.stack([e.reward     for e in experiences])).float().to(device)
+        next_states = torch.from_numpy(self.stack([e.next_state for e in experiences])).float().to(device)
+        dones       = torch.from_numpy(self.stack([e.done       for e in experiences]).astype(np.uint8)).float().to(device)
+
+        return states, actions, rewards, next_states, dones
+
+    def stack(self, states):
+        sub_dims = states[0].shape[1:] if isinstance(states[0], Iterable) else [1]
+        return np.reshape(np.array(states), (len(states), *sub_dims))
+
+    def __len__(self):
+        return len(self.memory)
diff --git a/run.py b/run.py
index c5e879e45f830b6079bc5e7d1e882146a6f372bc..1094d1bfbe0a5290be7f11f602b37740f68ae928 100644
--- a/run.py
+++ b/run.py
@@ -1,215 +1,215 @@
-import sys
-import time
-from argparse import Namespace
-from pathlib import Path
-
-import numpy as np
-from flatland.core.env_observation_builder import DummyObservationBuilder
-from flatland.envs.predictions import ShortestPathPredictorForRailEnv
-from flatland.envs.rail_env import RailEnvActions
-from flatland.evaluators.client import FlatlandRemoteClient
-from flatland.evaluators.client import TimeoutException
-
-from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
-from utils.deadlock_check import check_if_all_blocked
-from utils.fast_tree_obs import FastTreeObs
-
-base_dir = Path(__file__).resolve().parent.parent
-sys.path.append(str(base_dir))
-
-from reinforcement_learning.dddqn_policy import DDDQNPolicy
-
-####################################################
-# EVALUATION PARAMETERS
-
-# Print per-step logs
-VERBOSE = True
-
-# Checkpoint to use (remember to push it!)
-checkpoint = "./checkpoints/201106234244-400.pth"  # 15.64082361736683 Depth 1
-checkpoint = "./checkpoints/201106234900-200.pth"  # 15.64082361736683 Depth 1
-
-# Use last action cache
-USE_ACTION_CACHE = False
-USE_DEAD_LOCK_AVOIDANCE_AGENT = False
-
-# Observation parameters (must match training parameters!)
-observation_tree_depth = 1
-observation_radius = 10
-observation_max_path_depth = 30
-
-####################################################
-
-remote_client = FlatlandRemoteClient()
-
-# Observation builder
-predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
-tree_observation = FastTreeObs(max_depth=observation_tree_depth)
-
-# Calculates state and action sizes
-state_size = tree_observation.observation_dim
-action_size = 5
-
-# Creates the policy. No GPU on evaluation server.
-policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
-# policy = PPOAgent(state_size, action_size, 10)
-policy.load(checkpoint)
-
-#####################################################################
-# Main evaluation loop
-#####################################################################
-evaluation_number = 0
-
-while True:
-    evaluation_number += 1
-
-    # We use a dummy observation and call TreeObsForRailEnv ourselves when needed.
-    # This way we decide if we want to calculate the observations or not instead
-    # of having them calculated every time we perform an env step.
-    time_start = time.time()
-    observation, info = remote_client.env_create(
-        obs_builder_object=DummyObservationBuilder()
-    )
-    env_creation_time = time.time() - time_start
-
-    if not observation:
-        # If the remote_client returns False on a `env_create` call,
-        # then it basically means that your agent has already been
-        # evaluated on all the required evaluation environments,
-        # and hence it's safe to break out of the main evaluation loop.
-        break
-
-    print("Env Path : ", remote_client.current_env_path)
-    print("Env Creation Time : ", env_creation_time)
-
-    local_env = remote_client.env
-    nb_agents = len(local_env.agents)
-    max_nb_steps = local_env._max_episode_steps
-
-    tree_observation.set_env(local_env)
-    tree_observation.reset()
-    observation = tree_observation.get_many(list(range(nb_agents)))
-
-    print("Evaluation {}: {} agents in {}x{}".format(evaluation_number, nb_agents, local_env.width, local_env.height))
-
-    # Now we enter into another infinite loop where we
-    # compute the actions for all the individual steps in this episode
-    # until the episode is `done`
-    steps = 0
-
-    # Bookkeeping
-    time_taken_by_controller = []
-    time_taken_per_step = []
-
-    # Action cache: keep track of last observation to avoid running the same inferrence multiple times.
-    # This only makes sense for deterministic policies.
-    agent_last_obs = {}
-    agent_last_action = {}
-    nb_hit = 0
-
-    if USE_DEAD_LOCK_AVOIDANCE_AGENT:
-        policy = DeadLockAvoidanceAgent(local_env)
-
-    while True:
-        try:
-            #####################################################################
-            # Evaluation of a single episode
-            #####################################################################
-            steps += 1
-            obs_time, agent_time, step_time = 0.0, 0.0, 0.0
-            no_ops_mode = False
-
-            if not check_if_all_blocked(env=local_env):
-                time_start = time.time()
-                action_dict = {}
-                policy.start_step()
-                if USE_DEAD_LOCK_AVOIDANCE_AGENT:
-                    observation = np.zeros((local_env.get_num_agents(), 2))
-                for agent in range(nb_agents):
-
-                    if USE_DEAD_LOCK_AVOIDANCE_AGENT:
-                        observation[agent][0] = agent
-                        observation[agent][1] = steps
-
-                    if info['action_required'][agent]:
-                        if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
-                            # cache hit
-                            action = agent_last_action[agent]
-                            nb_hit += 1
-                        else:
-                            action = policy.act(observation[agent], eps=0.01)
-                            if observation[agent][26] == 1:
-                                action = RailEnvActions.STOP_MOVING
-
-                        action_dict[agent] = action
-
-                        if USE_ACTION_CACHE:
-                            agent_last_obs[agent] = observation[agent]
-                            agent_last_action[agent] = action
-                policy.end_step()
-                agent_time = time.time() - time_start
-                time_taken_by_controller.append(agent_time)
-
-                time_start = time.time()
-                _, all_rewards, done, info = remote_client.env_step(action_dict)
-                step_time = time.time() - time_start
-                time_taken_per_step.append(step_time)
-
-                time_start = time.time()
-                observation = tree_observation.get_many(list(range(nb_agents)))
-                obs_time = time.time() - time_start
-
-            else:
-                # Fully deadlocked: perform no-ops
-                no_ops_mode = True
-
-                time_start = time.time()
-                _, all_rewards, done, info = remote_client.env_step({})
-                step_time = time.time() - time_start
-                time_taken_per_step.append(step_time)
-
-            nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles())
-
-            if VERBOSE or done['__all__']:
-                print(
-                    "Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? {}".format(
-                        str(steps).zfill(4),
-                        max_nb_steps,
-                        nb_agents_done,
-                        obs_time,
-                        agent_time,
-                        step_time,
-                        nb_hit,
-                        no_ops_mode
-                    ), end="\r")
-
-            if done['__all__']:
-                # When done['__all__'] == True, then the evaluation of this
-                # particular Env instantiation is complete, and we can break out
-                # of this loop, and move onto the next Env evaluation
-                print()
-                break
-
-        except TimeoutException as err:
-            # A timeout occurs, won't get any reward for this episode :-(
-            # Skip to next episode as further actions in this one will be ignored.
-            # The whole evaluation will be stopped if there are 10 consecutive timeouts.
-            print("Timeout! Will skip this episode and go to the next.", err)
-            break
-
-    np_time_taken_by_controller = np.array(time_taken_by_controller)
-    np_time_taken_per_step = np.array(time_taken_per_step)
-    print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(),
-          np_time_taken_by_controller.std())
-    print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
-    print("=" * 100)
-
-print("Evaluation of all environments complete!")
-########################################################################
-# Submit your Results
-#
-# Please do not forget to include this call, as this triggers the
-# final computation of the score statistics, video generation, etc
-# and is necessary to have your submission marked as successfully evaluated
-########################################################################
-print(remote_client.submit())
+import sys
+import time
+from argparse import Namespace
+from pathlib import Path
+
+import numpy as np
+from flatland.core.env_observation_builder import DummyObservationBuilder
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnvActions
+from flatland.evaluators.client import FlatlandRemoteClient
+from flatland.evaluators.client import TimeoutException
+
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
+from utils.deadlock_check import check_if_all_blocked
+from utils.fast_tree_obs import FastTreeObs
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+
+####################################################
+# EVALUATION PARAMETERS
+
+# Print per-step logs
+VERBOSE = True
+
+# Checkpoint to use (remember to push it!)
+checkpoint = "./checkpoints/201106234244-400.pth"  # 15.64082361736683 Depth 1
+checkpoint = "./checkpoints/201106234900-200.pth"  # 15.64082361736683 Depth 1
+
+# Use last action cache
+USE_ACTION_CACHE = False
+USE_DEAD_LOCK_AVOIDANCE_AGENT = False
+
+# Observation parameters (must match training parameters!)
+observation_tree_depth = 1
+observation_radius = 10
+observation_max_path_depth = 30
+
+####################################################
+
+remote_client = FlatlandRemoteClient()
+
+# Observation builder
+predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
+tree_observation = FastTreeObs(max_depth=observation_tree_depth)
+
+# Calculates state and action sizes
+state_size = tree_observation.observation_dim
+action_size = 5
+
+# Creates the policy. No GPU on evaluation server.
+policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
+# policy = PPOAgent(state_size, action_size, 10)
+policy.load(checkpoint)
+
+#####################################################################
+# Main evaluation loop
+#####################################################################
+evaluation_number = 0
+
+while True:
+    evaluation_number += 1
+
+    # We use a dummy observation builder and compute the FastTreeObs observations ourselves when needed.
+    # This way we decide if we want to calculate the observations or not instead
+    # of having them calculated every time we perform an env step.
+    time_start = time.time()
+    observation, info = remote_client.env_create(
+        obs_builder_object=DummyObservationBuilder()
+    )
+    env_creation_time = time.time() - time_start
+
+    if not observation:
+        # If the remote_client returns False on an `env_create` call,
+        # then it basically means that your agent has already been
+        # evaluated on all the required evaluation environments,
+        # and hence it's safe to break out of the main evaluation loop.
+        break
+
+    print("Env Path : ", remote_client.current_env_path)
+    print("Env Creation Time : ", env_creation_time)
+
+    local_env = remote_client.env
+    nb_agents = len(local_env.agents)
+    max_nb_steps = local_env._max_episode_steps
+
+    tree_observation.set_env(local_env)
+    tree_observation.reset()
+    observation = tree_observation.get_many(list(range(nb_agents)))
+
+    print("Evaluation {}: {} agents in {}x{}".format(evaluation_number, nb_agents, local_env.width, local_env.height))
+
+    # Now we enter another loop where we
+    # compute the actions for all the individual steps in this episode
+    # until the episode is `done`
+    steps = 0
+
+    # Bookkeeping
+    time_taken_by_controller = []
+    time_taken_per_step = []
+
+    # Action cache: keep track of the last observation to avoid running the same inference multiple times.
+    # This only makes sense for deterministic policies.
+    agent_last_obs = {}
+    agent_last_action = {}
+    nb_hit = 0
+
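+    # Optionally replace the trained policy with the rule-based deadlock-avoidance heuristic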
+    if USE_DEAD_LOCK_AVOIDANCE_AGENT:
+        policy = DeadLockAvoidanceAgent(local_env)
+
+    while True:
+        try:
+            #####################################################################
+            # Evaluation of a single episode
+            #####################################################################
+            steps += 1
+            obs_time, agent_time, step_time = 0.0, 0.0, 0.0
+            no_ops_mode = False
+
+            if not check_if_all_blocked(env=local_env):
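+                # At least one agent can still move: query the policy for every agent that requires an action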
+                time_start = time.time()
+                action_dict = {}
+                policy.start_step()
+                if USE_DEAD_LOCK_AVOIDANCE_AGENT:
+                    observation = np.zeros((local_env.get_num_agents(), 2))
+                for agent in range(nb_agents):
+
+                    if USE_DEAD_LOCK_AVOIDANCE_AGENT:
+                        observation[agent][0] = agent
+                        observation[agent][1] = steps
+
+                    if info['action_required'][agent]:
+                        if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
+                            # cache hit
+                            action = agent_last_action[agent]
+                            nb_hit += 1
+                        else:
+                            action = policy.act(observation[agent], eps=0.0)
+                            #if observation[agent][26] == 1:
+                            #    action = RailEnvActions.STOP_MOVING
+
+                        action_dict[agent] = action
+
+                        if USE_ACTION_CACHE:
+                            agent_last_obs[agent] = observation[agent]
+                            agent_last_action[agent] = action
+                policy.end_step()
+                agent_time = time.time() - time_start
+                time_taken_by_controller.append(agent_time)
+
+                time_start = time.time()
+                _, all_rewards, done, info = remote_client.env_step(action_dict)
+                step_time = time.time() - time_start
+                time_taken_per_step.append(step_time)
+
+                time_start = time.time()
+                observation = tree_observation.get_many(list(range(nb_agents)))
+                obs_time = time.time() - time_start
+
+            else:
+                # Fully deadlocked: perform no-ops
+                no_ops_mode = True
+
+                time_start = time.time()
+                _, all_rewards, done, info = remote_client.env_step({})
+                step_time = time.time() - time_start
+                time_taken_per_step.append(step_time)
+
+            nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles())
+
+            if VERBOSE or done['__all__']:
+                print(
+                    "Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? {}".format(
+                        str(steps).zfill(4),
+                        max_nb_steps,
+                        nb_agents_done,
+                        obs_time,
+                        agent_time,
+                        step_time,
+                        nb_hit,
+                        no_ops_mode
+                    ), end="\r")
+
+            if done['__all__']:
+                # When done['__all__'] == True, then the evaluation of this
+                # particular Env instantiation is complete, and we can break out
+                # of this loop, and move onto the next Env evaluation
+                print()
+                break
+
+        except TimeoutException as err:
+            # A timeout occurred: we won't get any reward for this episode :-(
+            # Skip to next episode as further actions in this one will be ignored.
+            # The whole evaluation will be stopped if there are 10 consecutive timeouts.
+            print("Timeout! Will skip this episode and go to the next.", err)
+            break
+
+    np_time_taken_by_controller = np.array(time_taken_by_controller)
+    np_time_taken_per_step = np.array(time_taken_per_step)
+    print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(),
+          np_time_taken_by_controller.std())
+    print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
+    print("=" * 100)
+
+print("Evaluation of all environments complete!")
+########################################################################
+# Submit your Results
+#
+# Please do not forget to include this call, as this triggers the
+# final computation of the score statistics, video generation, etc
+# and is necessary to have your submission marked as successfully evaluated
+########################################################################
+print(remote_client.submit())
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index 700600c337882271eebe519233b473449be3b1ab..1d0b52ccd915f730fa98fe79db9336f66cb70116 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -1,175 +1,175 @@
-from typing import Optional, List
-
-import matplotlib.pyplot as plt
-import numpy as np
-from flatland.core.env_observation_builder import DummyObservationBuilder
-from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
-
-from reinforcement_learning.policy import Policy
-from utils.shortest_distance_walker import ShortestDistanceWalker
-
-
-class DeadlockAvoidanceObservation(DummyObservationBuilder):
-    def __init__(self):
-        self.counter = 0
-
-    def get_many(self, handles: Optional[List[int]] = None) -> bool:
-        self.counter += 1
-        obs = np.ones(len(handles), 2)
-        for handle in handles:
-            obs[handle][0] = handle
-            obs[handle][1] = self.counter
-        return obs
-
-
-class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
-    def __init__(self, env: RailEnv, agent_positions, switches):
-        super().__init__(env)
-        self.shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
-                                                     self.env.height,
-                                                     self.env.width),
-                                                    dtype=int) - 1
-
-        self.full_shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
-                                                          self.env.height,
-                                                          self.env.width),
-                                                         dtype=int) - 1
-
-        self.agent_positions = agent_positions
-
-        self.opp_agent_map = {}
-        self.same_agent_map = {}
-        self.switches = switches
-
-    def getData(self):
-        return self.shortest_distance_agent_map, self.full_shortest_distance_agent_map
-
-    def callback(self, handle, agent, position, direction, action, possible_transitions):
-        opp_a = self.agent_positions[position]
-        if opp_a != -1 and opp_a != handle:
-            if self.env.agents[opp_a].direction != direction:
-                d = self.opp_agent_map.get(handle, [])
-                if opp_a not in d:
-                    d.append(opp_a)
-                self.opp_agent_map.update({handle: d})
-            else:
-                if len(self.opp_agent_map.get(handle, [])) == 0:
-                    d = self.same_agent_map.get(handle, [])
-                    if opp_a not in d:
-                        d.append(opp_a)
-                    self.same_agent_map.update({handle: d})
-
-        if len(self.opp_agent_map.get(handle, [])) == 0:
-            if self.switches.get(position, None) is None:
-                self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
-        self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1
-
-
-class DeadLockAvoidanceAgent(Policy):
-    def __init__(self, env: RailEnv, show_debug_plot=False):
-        self.env = env
-        self.memory = None
-        self.loss = 0
-        self.agent_can_move = {}
-        self.switches = {}
-        self.show_debug_plot = show_debug_plot
-
-    def step(self, state, action, reward, next_state, done):
-        pass
-
-    def act(self, state, eps=0.):
-        # agent = self.env.agents[state[0]]
-        check = self.agent_can_move.get(state[0], None)
-        if check is None:
-            return RailEnvActions.STOP_MOVING
-        return check[3]
-
-    def reset(self):
-        self.agent_positions = None
-        self.shortest_distance_walker = None
-        self.switches = {}
-        for h in range(self.env.height):
-            for w in range(self.env.width):
-                pos = (h, w)
-                for dir in range(4):
-                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
-                    num_transitions = fast_count_nonzero(possible_transitions)
-                    if num_transitions > 1:
-                        if pos not in self.switches.keys():
-                            self.switches.update({pos: [dir]})
-                        else:
-                            self.switches[pos].append(dir)
-
-    def start_step(self):
-        self.build_agent_position_map()
-        self.shortest_distance_mapper()
-        self.extract_agent_can_move()
-
-    def end_step(self):
-        pass
-
-    def get_actions(self):
-        pass
-
-    def build_agent_position_map(self):
-        # build map with agent positions (only active agents)
-        self.agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1
-        for handle in range(self.env.get_num_agents()):
-            agent = self.env.agents[handle]
-            if agent.status == RailAgentStatus.ACTIVE:
-                if agent.position is not None:
-                    self.agent_positions[agent.position] = handle
-
-    def shortest_distance_mapper(self):
-        self.shortest_distance_walker = DeadlockAvoidanceShortestDistanceWalker(self.env,
-                                                                                self.agent_positions,
-                                                                                self.switches)
-        for handle in range(self.env.get_num_agents()):
-            agent = self.env.agents[handle]
-            if agent.status <= RailAgentStatus.ACTIVE:
-                self.shortest_distance_walker.walk_to_target(handle)
-
-    def extract_agent_can_move(self):
-        self.agent_can_move = {}
-        shortest_distance_agent_map, full_shortest_distance_agent_map = self.shortest_distance_walker.getData()
-        for handle in range(self.env.get_num_agents()):
-            agent = self.env.agents[handle]
-            if agent.status < RailAgentStatus.DONE:
-                next_step_ok = self.check_agent_can_move(shortest_distance_agent_map[handle],
-                                                         self.shortest_distance_walker.same_agent_map.get(handle, []),
-                                                         self.shortest_distance_walker.opp_agent_map.get(handle, []),
-                                                         full_shortest_distance_agent_map)
-                if next_step_ok:
-                    next_position, next_direction, action, _ = self.shortest_distance_walker.walk_one_step(handle)
-                    self.agent_can_move.update({handle: [next_position[0], next_position[1], next_direction, action]})
-
-        if self.show_debug_plot:
-            a = np.floor(np.sqrt(self.env.get_num_agents()))
-            b = np.ceil(self.env.get_num_agents() / a)
-            for handle in range(self.env.get_num_agents()):
-                plt.subplot(a, b, handle + 1)
-                plt.imshow(full_shortest_distance_agent_map[handle] + shortest_distance_agent_map[handle])
-            plt.show(block=False)
-            plt.pause(0.01)
-
-    def check_agent_can_move(self,
-                             my_shortest_walking_path,
-                             same_agents,
-                             opp_agents,
-                             full_shortest_distance_agent_map):
-        agent_positions_map = (self.agent_positions > -1).astype(int)
-        delta = my_shortest_walking_path
-        next_step_ok = True
-        for opp_a in opp_agents:
-            opp = full_shortest_distance_agent_map[opp_a]
-            delta = ((my_shortest_walking_path - opp - agent_positions_map) > 0).astype(int)
-            if np.sum(delta) < (3 + len(opp_agents)):
-                next_step_ok = False
-        return next_step_ok
-
-    def save(self, filename):
-        pass
-
-    def load(self, filename):
-        pass
+from typing import Optional, List
+
+import matplotlib.pyplot as plt
+import numpy as np
+from flatland.core.env_observation_builder import DummyObservationBuilder
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
+
+from reinforcement_learning.policy import Policy
+from utils.shortest_distance_walker import ShortestDistanceWalker
+
+
+class DeadlockAvoidanceObservation(DummyObservationBuilder):
+    def __init__(self):
+        self.counter = 0
+
+    def get_many(self, handles: Optional[List[int]] = None) -> np.ndarray:
+        self.counter += 1
+        obs = np.ones((len(handles), 2))
+        for handle in handles:
+            obs[handle][0] = handle
+            obs[handle][1] = self.counter
+        return obs
+
+
+class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
+    def __init__(self, env: RailEnv, agent_positions, switches):
+        super().__init__(env)
+        self.shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
+                                                     self.env.height,
+                                                     self.env.width),
+                                                    dtype=int) - 1
+
+        self.full_shortest_distance_agent_map = np.zeros((self.env.get_num_agents(),
+                                                          self.env.height,
+                                                          self.env.width),
+                                                         dtype=int) - 1
+
+        self.agent_positions = agent_positions
+
+        self.opp_agent_map = {}
+        self.same_agent_map = {}
+        self.switches = switches
+
+    def getData(self):
+        return self.shortest_distance_agent_map, self.full_shortest_distance_agent_map
+
+    def callback(self, handle, agent, position, direction, action, possible_transitions):
+        opp_a = self.agent_positions[position]
+        if opp_a != -1 and opp_a != handle:
+            if self.env.agents[opp_a].direction != direction:
+                d = self.opp_agent_map.get(handle, [])
+                if opp_a not in d:
+                    d.append(opp_a)
+                self.opp_agent_map.update({handle: d})
+            else:
+                if len(self.opp_agent_map.get(handle, [])) == 0:
+                    d = self.same_agent_map.get(handle, [])
+                    if opp_a not in d:
+                        d.append(opp_a)
+                    self.same_agent_map.update({handle: d})
+
+        if len(self.opp_agent_map.get(handle, [])) == 0:
+            if self.switches.get(position, None) is None:
+                self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
+        self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1
+
+
+class DeadLockAvoidanceAgent(Policy):
+    def __init__(self, env: RailEnv, show_debug_plot=False):
+        self.env = env
+        self.memory = None
+        self.loss = 0
+        self.agent_can_move = {}
+        self.switches = {}
+        self.show_debug_plot = show_debug_plot
+
+    def step(self, state, action, reward, next_state, done):
+        pass
+
+    def act(self, state, eps=0.):
+        # agent = self.env.agents[state[0]]
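+        # state[0] carries the agent handle; return the precomputed safe action, or stop if none exists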
+        check = self.agent_can_move.get(state[0], None)
+        if check is None:
+            return RailEnvActions.STOP_MOVING
+        return check[3]
+
+    def reset(self):
+        self.agent_positions = None
+        self.shortest_distance_walker = None
+        self.switches = {}
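+        # Collect all switch cells, i.e. cells with more than one outgoing transition for some direction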
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                pos = (h, w)
+                for dir in range(4):
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    num_transitions = fast_count_nonzero(possible_transitions)
+                    if num_transitions > 1:
+                        if pos not in self.switches.keys():
+                            self.switches.update({pos: [dir]})
+                        else:
+                            self.switches[pos].append(dir)
+
+    def start_step(self):
+        self.build_agent_position_map()
+        self.shortest_distance_mapper()
+        self.extract_agent_can_move()
+
+    def end_step(self):
+        pass
+
+    def get_actions(self):
+        pass
+
+    def build_agent_position_map(self):
+        # build map with agent positions (only active agents)
+        self.agent_positions = np.zeros((self.env.height, self.env.width), dtype=int) - 1
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status == RailAgentStatus.ACTIVE:
+                if agent.position is not None:
+                    self.agent_positions[agent.position] = handle
+
+    def shortest_distance_mapper(self):
+        self.shortest_distance_walker = DeadlockAvoidanceShortestDistanceWalker(self.env,
+                                                                                self.agent_positions,
+                                                                                self.switches)
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status <= RailAgentStatus.ACTIVE:
+                self.shortest_distance_walker.walk_to_target(handle)
+
+    def extract_agent_can_move(self):
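+        # For every agent that is not done, store the next safe step (position, direction, action) if one exists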
+        self.agent_can_move = {}
+        shortest_distance_agent_map, full_shortest_distance_agent_map = self.shortest_distance_walker.getData()
+        for handle in range(self.env.get_num_agents()):
+            agent = self.env.agents[handle]
+            if agent.status < RailAgentStatus.DONE:
+                next_step_ok = self.check_agent_can_move(shortest_distance_agent_map[handle],
+                                                         self.shortest_distance_walker.same_agent_map.get(handle, []),
+                                                         self.shortest_distance_walker.opp_agent_map.get(handle, []),
+                                                         full_shortest_distance_agent_map)
+                if next_step_ok:
+                    next_position, next_direction, action, _ = self.shortest_distance_walker.walk_one_step(handle)
+                    self.agent_can_move.update({handle: [next_position[0], next_position[1], next_direction, action]})
+
+        if self.show_debug_plot:
+            a = np.floor(np.sqrt(self.env.get_num_agents()))
+            b = np.ceil(self.env.get_num_agents() / a)
+            for handle in range(self.env.get_num_agents()):
+                plt.subplot(a, b, handle + 1)
+                plt.imshow(full_shortest_distance_agent_map[handle] + shortest_distance_agent_map[handle])
+            plt.show(block=False)
+            plt.pause(0.01)
+
+    def check_agent_can_move(self,
+                             my_shortest_walking_path,
+                             same_agents,
+                             opp_agents,
+                             full_shortest_distance_agent_map):
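+        # Heuristic: moving is only considered safe if, for every opposing agent, enough cells remain
+        # on this agent's shortest path after subtracting the opponent's path and all occupied cells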
+        agent_positions_map = (self.agent_positions > -1).astype(int)
+        delta = my_shortest_walking_path
+        next_step_ok = True
+        for opp_a in opp_agents:
+            opp = full_shortest_distance_agent_map[opp_a]
+            delta = ((my_shortest_walking_path - opp - agent_positions_map) > 0).astype(int)
+            if np.sum(delta) < (3 + len(opp_agents)):
+                next_step_ok = False
+        return next_step_ok
+
+    def save(self, filename):
+        pass
+
+    def load(self, filename):
+        pass
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
index b6d673d997b6e43ebdba22578f8a0a7e21831000..3b14b0f7be2379bc472a23122e982336a4421106 100755
--- a/utils/fast_tree_obs.py
+++ b/utils/fast_tree_obs.py
@@ -1,308 +1,308 @@
-import numpy as np
-from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions
-
-from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
-
-"""
-LICENCE for the FastTreeObs Observation Builder  
-
-The observation can be used freely and reused for further submissions. Only the author needs to be referred to
-/mentioned in any submissions - if the entire observation or parts, or the main idea is used.
-
-Author: Adrian Egli (adrian.egli@gmail.com)
-
-[Linkedin](https://www.researchgate.net/profile/Adrian_Egli2)
-[Researchgate](https://www.linkedin.com/in/adrian-egli-733a9544/)
-"""
-
-
-class FastTreeObs(ObservationBuilder):
-
-    def __init__(self, max_depth):
-        self.max_depth = max_depth
-        self.observation_dim = 27
-
-    def build_data(self):
-        if self.env is not None:
-            self.env.dev_obs_dict = {}
-        self.switches = {}
-        self.switches_neighbours = {}
-        self.debug_render_list = []
-        self.debug_render_path_list = []
-        if self.env is not None:
-            self.find_all_cell_where_agent_can_choose()
-            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env)
-        else:
-            self.dead_lock_avoidance_agent = None
-
-    def find_all_cell_where_agent_can_choose(self):
-        switches = {}
-        for h in range(self.env.height):
-            for w in range(self.env.width):
-                pos = (h, w)
-                for dir in range(4):
-                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
-                    num_transitions = fast_count_nonzero(possible_transitions)
-                    if num_transitions > 1:
-                        if pos not in switches.keys():
-                            switches.update({pos: [dir]})
-                        else:
-                            switches[pos].append(dir)
-
-        switches_neighbours = {}
-        for h in range(self.env.height):
-            for w in range(self.env.width):
-                # look one step forward
-                for dir in range(4):
-                    pos = (h, w)
-                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
-                    for d in range(4):
-                        if possible_transitions[d] == 1:
-                            new_cell = get_new_position(pos, d)
-                            if new_cell in switches.keys() and pos not in switches.keys():
-                                if pos not in switches_neighbours.keys():
-                                    switches_neighbours.update({pos: [dir]})
-                                else:
-                                    switches_neighbours[pos].append(dir)
-
-        self.switches = switches
-        self.switches_neighbours = switches_neighbours
-
-    def check_agent_decision(self, position, direction):
-        switches = self.switches
-        switches_neighbours = self.switches_neighbours
-        agents_on_switch = False
-        agents_on_switch_all = False
-        agents_near_to_switch = False
-        agents_near_to_switch_all = False
-        if position in switches.keys():
-            agents_on_switch = direction in switches[position]
-            agents_on_switch_all = True
-
-        if position in switches_neighbours.keys():
-            new_cell = get_new_position(position, direction)
-            if new_cell in switches.keys():
-                if not direction in switches[new_cell]:
-                    agents_near_to_switch = direction in switches_neighbours[position]
-            else:
-                agents_near_to_switch = direction in switches_neighbours[position]
-
-            agents_near_to_switch_all = direction in switches_neighbours[position]
-
-        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
-
-    def required_agent_decision(self):
-        agents_can_choose = {}
-        agents_on_switch = {}
-        agents_on_switch_all = {}
-        agents_near_to_switch = {}
-        agents_near_to_switch_all = {}
-        for a in range(self.env.get_num_agents()):
-            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
-                self.check_agent_decision(
-                    self.env.agents[a].position,
-                    self.env.agents[a].direction)
-            agents_on_switch.update({a: ret_agents_on_switch})
-            agents_on_switch_all.update({a: ret_agents_on_switch_all})
-            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
-            agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
-
-            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
-
-            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
-
-        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
-
-    def debug_render(self, env_renderer):
-        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
-            self.required_agent_decision()
-        self.env.dev_obs_dict = {}
-        for a in range(max(3, self.env.get_num_agents())):
-            self.env.dev_obs_dict.update({a: []})
-
-        selected_agent = None
-        if agents_can_choose[0]:
-            if self.env.agents[0].position is not None:
-                self.debug_render_list.append(self.env.agents[0].position)
-            else:
-                self.debug_render_list.append(self.env.agents[0].initial_position)
-
-        if self.env.agents[0].position is not None:
-            self.debug_render_path_list.append(self.env.agents[0].position)
-        else:
-            self.debug_render_path_list.append(self.env.agents[0].initial_position)
-
-        env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000")
-        env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600")
-        env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666")
-        env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")
-
-        self.env.dev_obs_dict[0] = self.debug_render_list
-        self.env.dev_obs_dict[1] = self.switches.keys()
-        self.env.dev_obs_dict[2] = self.switches_neighbours.keys()
-        self.env.dev_obs_dict[3] = self.debug_render_path_list
-
-    def reset(self):
-        self.build_data()
-        return
-
-    def fast_argmax(self, array):
-        if array[0] == 1:
-            return 0
-        if array[1] == 1:
-            return 1
-        if array[2] == 1:
-            return 2
-        return 3
-
-    def _explore(self, handle, new_position, new_direction, depth=0):
-        has_opp_agent = 0
-        has_same_agent = 0
-        has_switch = 0
-        visited = []
-
-        # stop exploring (max_depth reached)
-        if depth >= self.max_depth:
-            return has_opp_agent, has_same_agent, has_switch, visited
-
-        # max_explore_steps = 100
-        cnt = 0
-        while cnt < 100:
-            cnt += 1
-
-            visited.append(new_position)
-            opp_a = self.env.agent_positions[new_position]
-            if opp_a != -1 and opp_a != handle:
-                if self.env.agents[opp_a].direction != new_direction:
-                    # opp agent found
-                    has_opp_agent = 1
-                    return has_opp_agent, has_same_agent, has_switch, visited
-                else:
-                    has_same_agent = 1
-                    return has_opp_agent, has_same_agent, has_switch, visited
-
-            # convert one-hot encoding to 0,1,2,3
-            agents_on_switch, \
-            agents_near_to_switch, \
-            agents_near_to_switch_all, \
-            agents_on_switch_all = \
-                self.check_agent_decision(new_position, new_direction)
-            if agents_near_to_switch:
-                return has_opp_agent, has_same_agent, has_switch, visited
-
-            possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
-            if agents_on_switch:
-                f = 0
-                for dir_loop in range(4):
-                    if possible_transitions[dir_loop] == 1:
-                        f += 1
-                        hoa, hsa, hs, v = self._explore(handle,
-                                                        get_new_position(new_position, dir_loop),
-                                                        dir_loop,
-                                                        depth + 1)
-                        visited.append(v)
-                        has_opp_agent += hoa
-                        has_same_agent += hsa
-                        has_switch += hs
-                f = max(f, 1.0)
-                return has_opp_agent / f, has_same_agent / f, has_switch / f, visited
-            else:
-                new_direction = fast_argmax(possible_transitions)
-                new_position = get_new_position(new_position, new_direction)
-
-        return has_opp_agent, has_same_agent, has_switch, visited
-
-    def get(self, handle):
-        # all values are [0,1]
-        # observation[0]  : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
-        # observation[1]  : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
-        # observation[2]  : 1 path towards target (direction 2) / otherwise 0 -> path is longer or there is no path
-        # observation[3]  : 1 path towards target (direction 3) / otherwise 0 -> path is longer or there is no path
-        # observation[4]  : int(agent.status == RailAgentStatus.READY_TO_DEPART)
-        # observation[5]  : int(agent.status == RailAgentStatus.ACTIVE)
-        # observation[6]  : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED)
-        # observation[7]  : current agent is located at a switch, where it can take a routing decision
-        # observation[8]  : current agent is located at a cell, where it has to take a stop-or-go decision
-        # observation[9]  : current agent is located one step before/after a switch
-        # observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0)
-        # observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1)
-        # observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2)
-        # observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3)
-        # observation[14] : If there is a path with step (direction 0) and there is a agent with opposite direction -> 1
-        # observation[15] : If there is a path with step (direction 1) and there is a agent with opposite direction -> 1
-        # observation[16] : If there is a path with step (direction 2) and there is a agent with opposite direction -> 1
-        # observation[17] : If there is a path with step (direction 3) and there is a agent with opposite direction -> 1
-        # observation[18] : If there is a path with step (direction 0) and there is a agent with same direction -> 1
-        # observation[19] : If there is a path with step (direction 1) and there is a agent with same direction -> 1
-        # observation[20] : If there is a path with step (direction 2) and there is a agent with same direction -> 1
-        # observation[21] : If there is a path with step (direction 3) and there is a agent with same direction -> 1
-        # observation[22] : If there is a switch on the path which agent can not use -> 1
-        # observation[23] : If there is a switch on the path which agent can not use -> 1
-        # observation[24] : If there is a switch on the path which agent can not use -> 1
-        # observation[25] : If there is a switch on the path which agent can not use -> 1
-        # observation[26] : If there the dead-lock avoidance agent predicts a deadlock -> 1
-
-        if handle == 0:
-            self.dead_lock_avoidance_agent.start_step()
-
-        observation = np.zeros(self.observation_dim)
-        visited = []
-        agent = self.env.agents[handle]
-
-        agent_done = False
-        if agent.status == RailAgentStatus.READY_TO_DEPART:
-            agent_virtual_position = agent.initial_position
-            observation[4] = 1
-        elif agent.status == RailAgentStatus.ACTIVE:
-            agent_virtual_position = agent.position
-            observation[5] = 1
-        else:
-            observation[6] = 1
-            agent_virtual_position = (-1, -1)
-            agent_done = True
-
-        if not agent_done:
-            visited.append(agent_virtual_position)
-            distance_map = self.env.distance_map.get()
-            current_cell_dist = distance_map[handle,
-                                             agent_virtual_position[0], agent_virtual_position[1],
-                                             agent.direction]
-            possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
-            orientation = agent.direction
-            if fast_count_nonzero(possible_transitions) == 1:
-                orientation = fast_argmax(possible_transitions)
-
-            for dir_loop, branch_direction in enumerate([(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]):
-                if possible_transitions[branch_direction]:
-                    new_position = get_new_position(agent_virtual_position, branch_direction)
-                    new_cell_dist = distance_map[handle,
-                                                 new_position[0], new_position[1],
-                                                 branch_direction]
-                    if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
-                        observation[dir_loop] = int(new_cell_dist < current_cell_dist)
-
-                    has_opp_agent, has_same_agent, has_switch, v = self._explore(handle, new_position, branch_direction)
-                    visited.append(v)
-
-                    observation[10 + dir_loop] = int(not np.math.isinf(new_cell_dist))
-                    observation[14 + dir_loop] = has_opp_agent
-                    observation[18 + dir_loop] = has_same_agent
-                    observation[22 + dir_loop] = has_switch
-
-        agents_on_switch, \
-        agents_near_to_switch, \
-        agents_near_to_switch_all, \
-        agents_on_switch_all = \
-            self.check_agent_decision(agent_virtual_position, agent.direction)
-        observation[7] = int(agents_on_switch)
-        observation[8] = int(agents_near_to_switch)
-        observation[9] = int(agents_near_to_switch_all)
-
-        action = self.dead_lock_avoidance_agent.act([handle], 0.0)
-        observation[26] = int(action == RailEnvActions.STOP_MOVING)
-        self.env.dev_obs_dict.update({handle: visited})
-
-        return observation
+import numpy as np
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions
+
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
+
+"""
+LICENCE for the FastTreeObs Observation Builder  
+
+The observation can be used freely and reused for further submissions. The author only needs to be credited
+/mentioned in any submission if the entire observation, parts of it, or its main idea is used.
+
+Author: Adrian Egli (adrian.egli@gmail.com)
+
+[LinkedIn](https://www.linkedin.com/in/adrian-egli-733a9544/)
+[ResearchGate](https://www.researchgate.net/profile/Adrian_Egli2)
+"""
+
+
+class FastTreeObs(ObservationBuilder):
+
+    def __init__(self, max_depth):
+        self.max_depth = max_depth
+        self.observation_dim = 27
+
+    def build_data(self):
+        if self.env is not None:
+            self.env.dev_obs_dict = {}
+        self.switches = {}
+        self.switches_neighbours = {}
+        self.debug_render_list = []
+        self.debug_render_path_list = []
+        if self.env is not None:
+            self.find_all_cell_where_agent_can_choose()
+            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env)
+        else:
+            self.dead_lock_avoidance_agent = None
+
+    def find_all_cell_where_agent_can_choose(self):
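+        # Pre-compute all switch cells (more than one outgoing transition) and the cells directly in front of them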
+        switches = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                pos = (h, w)
+                for dir in range(4):
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    num_transitions = fast_count_nonzero(possible_transitions)
+                    if num_transitions > 1:
+                        if pos not in switches.keys():
+                            switches.update({pos: [dir]})
+                        else:
+                            switches[pos].append(dir)
+
+        switches_neighbours = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                # look one step forward
+                for dir in range(4):
+                    pos = (h, w)
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    for d in range(4):
+                        if possible_transitions[d] == 1:
+                            new_cell = get_new_position(pos, d)
+                            if new_cell in switches.keys() and pos not in switches.keys():
+                                if pos not in switches_neighbours.keys():
+                                    switches_neighbours.update({pos: [dir]})
+                                else:
+                                    switches_neighbours[pos].append(dir)
+
+        self.switches = switches
+        self.switches_neighbours = switches_neighbours
+
+    def check_agent_decision(self, position, direction):
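+        # Classify the given cell/direction: on a switch, one step before a switch, or neither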
+        switches = self.switches
+        switches_neighbours = self.switches_neighbours
+        agents_on_switch = False
+        agents_on_switch_all = False
+        agents_near_to_switch = False
+        agents_near_to_switch_all = False
+        if position in switches.keys():
+            agents_on_switch = direction in switches[position]
+            agents_on_switch_all = True
+
+        if position in switches_neighbours.keys():
+            new_cell = get_new_position(position, direction)
+            if new_cell in switches.keys():
+                if not direction in switches[new_cell]:
+                    agents_near_to_switch = direction in switches_neighbours[position]
+            else:
+                agents_near_to_switch = direction in switches_neighbours[position]
+
+            agents_near_to_switch_all = direction in switches_neighbours[position]
+
+        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
+
+    def required_agent_decision(self):
+        agents_can_choose = {}
+        agents_on_switch = {}
+        agents_on_switch_all = {}
+        agents_near_to_switch = {}
+        agents_near_to_switch_all = {}
+        for a in range(self.env.get_num_agents()):
+            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
+                self.check_agent_decision(
+                    self.env.agents[a].position,
+                    self.env.agents[a].direction)
+            agents_on_switch.update({a: ret_agents_on_switch})
+            agents_on_switch_all.update({a: ret_agents_on_switch_all})
+            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
+            agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
+
+            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
+
+            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
+
+        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
+
+    def debug_render(self, env_renderer):
+        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all = \
+            self.required_agent_decision()
+        self.env.dev_obs_dict = {}
+        for a in range(max(3, self.env.get_num_agents())):
+            self.env.dev_obs_dict.update({a: []})
+
+        selected_agent = None
+        if agents_can_choose[0]:
+            if self.env.agents[0].position is not None:
+                self.debug_render_list.append(self.env.agents[0].position)
+            else:
+                self.debug_render_list.append(self.env.agents[0].initial_position)
+
+        if self.env.agents[0].position is not None:
+            self.debug_render_path_list.append(self.env.agents[0].position)
+        else:
+            self.debug_render_path_list.append(self.env.agents[0].initial_position)
+
+        env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000")
+        env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600")
+        env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666")
+        env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")
+
+        self.env.dev_obs_dict[0] = self.debug_render_list
+        self.env.dev_obs_dict[1] = self.switches.keys()
+        self.env.dev_obs_dict[2] = self.switches_neighbours.keys()
+        self.env.dev_obs_dict[3] = self.debug_render_path_list
+
+    def reset(self):
+        self.build_data()
+        return
+
+    def fast_argmax(self, array):
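+        # Index of the first set entry of a 4-element transition tuple (3 if none of 0..2 is set).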
+        if array[0] == 1:
+            return 0
+        if array[1] == 1:
+            return 1
+        if array[2] == 1:
+            return 2
+        return 3
+
+    def _explore(self, handle, new_position, new_direction, depth=0):
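+        # Follow the rail from (new_position, new_direction) until another agent, a cell just
+        # before a switch, or the step limit is reached. At switches the walk branches recursively
+        # (bounded by self.max_depth) and the findings of the explored branches are averaged.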
+        has_opp_agent = 0
+        has_same_agent = 0
+        has_switch = 0
+        visited = []
+
+        # stop exploring (max_depth reached)
+        if depth >= self.max_depth:
+            return has_opp_agent, has_same_agent, has_switch, visited
+
+        # max_explore_steps = 100
+        cnt = 0
+        while cnt < 100:
+            cnt += 1
+
+            visited.append(new_position)
+            opp_a = self.env.agent_positions[new_position]
+            if opp_a != -1 and opp_a != handle:
+                if self.env.agents[opp_a].direction != new_direction:
+                    # opp agent found
+                    has_opp_agent = 1
+                    return has_opp_agent, has_same_agent, has_switch, visited
+                else:
+                    has_same_agent = 1
+                    return has_opp_agent, has_same_agent, has_switch, visited
+
+            # check whether the explored cell is on or next to a switch (i.e. a decision point)
+            agents_on_switch, \
+            agents_near_to_switch, \
+            agents_near_to_switch_all, \
+            agents_on_switch_all = \
+                self.check_agent_decision(new_position, new_direction)
+            if agents_near_to_switch:
+                return has_opp_agent, has_same_agent, has_switch, visited
+
+            possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
+            if agents_on_switch:
+                f = 0
+                for dir_loop in range(4):
+                    if possible_transitions[dir_loop] == 1:
+                        f += 1
+                        hoa, hsa, hs, v = self._explore(handle,
+                                                        get_new_position(new_position, dir_loop),
+                                                        dir_loop,
+                                                        depth + 1)
+                        visited.append(v)
+                        has_opp_agent += hoa
+                        has_same_agent += hsa
+                        has_switch += hs
+                f = max(f, 1.0)
+                return has_opp_agent / f, has_same_agent / f, has_switch / f, visited
+            else:
+                new_direction = fast_argmax(possible_transitions)
+                new_position = get_new_position(new_position, new_direction)
+
+        return has_opp_agent, has_same_agent, has_switch, visited
+
+    def get(self, handle):
+        # all values are [0,1]
+        # observation[0]  : 1 if the path in direction 0 leads towards the target (distance decreases), otherwise 0 (path is longer or there is no path)
+        # observation[1]  : 1 if the path in direction 1 leads towards the target (distance decreases), otherwise 0 (path is longer or there is no path)
+        # observation[2]  : 1 if the path in direction 2 leads towards the target (distance decreases), otherwise 0 (path is longer or there is no path)
+        # observation[3]  : 1 if the path in direction 3 leads towards the target (distance decreases), otherwise 0 (path is longer or there is no path)
+        # observation[4]  : int(agent.status == RailAgentStatus.READY_TO_DEPART)
+        # observation[5]  : int(agent.status == RailAgentStatus.ACTIVE)
+        # observation[6]  : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED)
+        # observation[7]  : current agent is located at a switch, where it can take a routing decision
+        # observation[8]  : current agent is located at a cell, where it has to take a stop-or-go decision
+        # observation[9]  : current agent is located one step before/after a switch
+        # observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0)
+        # observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1)
+        # observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2)
+        # observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3)
+        # observation[14] : If there is a path (direction 0) and an agent with the opposite direction was found on it -> 1
+        # observation[15] : If there is a path (direction 1) and an agent with the opposite direction was found on it -> 1
+        # observation[16] : If there is a path (direction 2) and an agent with the opposite direction was found on it -> 1
+        # observation[17] : If there is a path (direction 3) and an agent with the opposite direction was found on it -> 1
+        # observation[18] : If there is a path (direction 0) and an agent with the same direction was found on it -> 1
+        # observation[19] : If there is a path (direction 1) and an agent with the same direction was found on it -> 1
+        # observation[20] : If there is a path (direction 2) and an agent with the same direction was found on it -> 1
+        # observation[21] : If there is a path (direction 3) and an agent with the same direction was found on it -> 1
+        # observation[22] : If there is a switch on the path (direction 0) which the agent cannot use -> 1
+        # observation[23] : If there is a switch on the path (direction 1) which the agent cannot use -> 1
+        # observation[24] : If there is a switch on the path (direction 2) which the agent cannot use -> 1
+        # observation[25] : If there is a switch on the path (direction 3) which the agent cannot use -> 1
+        # observation[26] : If the dead-lock avoidance agent predicts a deadlock -> 1
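+        #
+        # Sketch of intended use (names below are illustrative, not defined in this module):
+        #     obs = obs_builder.get(handle)      # vector with the 27 entries described above
+        #     action = policy.act(obs, eps)      # e.g. consumed by an epsilon-greedy DQN policy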
+
+        if handle == 0:
+            self.dead_lock_avoidance_agent.start_step()
+
+        observation = np.zeros(self.observation_dim)
+        visited = []
+        agent = self.env.agents[handle]
+
+        agent_done = False
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+            observation[4] = 1
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+            observation[5] = 1
+        else:
+            observation[6] = 1
+            agent_virtual_position = (-1, -1)
+            agent_done = True
+
+        if not agent_done:
+            visited.append(agent_virtual_position)
+            distance_map = self.env.distance_map.get()
+            current_cell_dist = distance_map[handle,
+                                             agent_virtual_position[0], agent_virtual_position[1],
+                                             agent.direction]
+            possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
+            orientation = agent.direction
+            if fast_count_nonzero(possible_transitions) == 1:
+                orientation = fast_argmax(possible_transitions)
+
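+            # When only one transition is possible, it was adopted as the orientation above so the
+            # loop below enumerates the branches left/forward/right/behind relative to it.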
+            for dir_loop, branch_direction in enumerate([(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]):
+                if possible_transitions[branch_direction]:
+                    new_position = get_new_position(agent_virtual_position, branch_direction)
+                    new_cell_dist = distance_map[handle,
+                                                 new_position[0], new_position[1],
+                                                 branch_direction]
+                    if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
+                        observation[dir_loop] = int(new_cell_dist < current_cell_dist)
+
+                    has_opp_agent, has_same_agent, has_switch, v = self._explore(handle, new_position, branch_direction)
+                    visited.append(v)
+
+                    observation[10 + dir_loop] = int(not np.math.isinf(new_cell_dist))
+                    observation[14 + dir_loop] = has_opp_agent
+                    observation[18 + dir_loop] = has_same_agent
+                    observation[22 + dir_loop] = has_switch
+
+        agents_on_switch, \
+        agents_near_to_switch, \
+        agents_near_to_switch_all, \
+        agents_on_switch_all = \
+            self.check_agent_decision(agent_virtual_position, agent.direction)
+        observation[7] = int(agents_on_switch)
+        observation[8] = int(agents_near_to_switch)
+        observation[9] = int(agents_near_to_switch_all)
+
+        action = self.dead_lock_avoidance_agent.act([handle], 0.0)
+        observation[26] = int(action == RailEnvActions.STOP_MOVING)
+        self.env.dev_obs_dict.update({handle: visited})
+
+        return observation
diff --git a/utils/shortest_distance_walker.py b/utils/shortest_distance_walker.py
index 0b8121f46e681ebc37ea3d1afb6b4023d33f2e14..62b686fff0f61a13c48220565553d7e63067739a 100644
--- a/utils/shortest_distance_walker.py
+++ b/utils/shortest_distance_walker.py
@@ -1,87 +1,87 @@
-import numpy as np
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.rail_env import RailEnv, RailEnvActions
-from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
-
-
-class ShortestDistanceWalker:
-    def __init__(self, env: RailEnv):
-        self.env = env
-
-    def walk(self, handle, position, direction):
-        possible_transitions = self.env.rail.get_transitions(*position, direction)
-        num_transitions = fast_count_nonzero(possible_transitions)
-        if num_transitions == 1:
-            new_direction = fast_argmax(possible_transitions)
-            new_position = get_new_position(position, new_direction)
-
-            dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]
-            return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD, possible_transitions
-        else:
-            min_distances = []
-            positions = []
-            directions = []
-            for new_direction in [(direction + i) % 4 for i in range(-1, 2)]:
-                if possible_transitions[new_direction]:
-                    new_position = get_new_position(position, new_direction)
-                    min_distances.append(
-                        self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction])
-                    positions.append(new_position)
-                    directions.append(new_direction)
-                else:
-                    min_distances.append(np.inf)
-                    positions.append(None)
-                    directions.append(None)
-
-        a = self.get_action(handle, min_distances)
-        return positions[a], directions[a], min_distances[a], a + 1, possible_transitions
-
-    def get_action(self, handle, min_distances):
-        return np.argmin(min_distances)
-
-    def callback(self, handle, agent, position, direction, action, possible_transitions):
-        pass
-
-    def get_agent_position_and_direction(self, handle):
-        agent = self.env.agents[handle]
-        if agent.position is not None:
-            position = agent.position
-        else:
-            position = agent.initial_position
-        direction = agent.direction
-        return position, direction
-
-    def walk_to_target(self, handle, position=None, direction=None, max_step=500):
-        if position is None and direction is None:
-            position, direction = self.get_agent_position_and_direction(handle)
-        elif position is None:
-            position, _ = self.get_agent_position_and_direction(handle)
-        elif direction is None:
-            _, direction = self.get_agent_position_and_direction(handle)
-
-        agent = self.env.agents[handle]
-        step = 0
-        while (position != agent.target) and (step < max_step):
-            position, direction, dist, action, possible_transitions = self.walk(handle, position, direction)
-            if position is None:
-                break
-            self.callback(handle, agent, position, direction, action, possible_transitions)
-            step += 1
-
-    def callback_one_step(self, handle, agent, position, direction, action, possible_transitions):
-        pass
-
-    def walk_one_step(self, handle):
-        agent = self.env.agents[handle]
-        if agent.position is not None:
-            position = agent.position
-        else:
-            position = agent.initial_position
-        direction = agent.direction
-        possible_transitions = (0, 1, 0, 0)
-        if (position != agent.target):
-            new_position, new_direction, dist, action, possible_transitions = self.walk(handle, position, direction)
-            if new_position is None:
-                return position, direction, RailEnvActions.STOP_MOVING, possible_transitions
-            self.callback_one_step(handle, agent, new_position, new_direction, action, possible_transitions)
-        return new_position, new_direction, action, possible_transitions
+import numpy as np
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax
+
+
+class ShortestDistanceWalker:
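+    """Walks cell by cell along the shortest path given by the env's distance map."""
+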
+    def __init__(self, env: RailEnv):
+        self.env = env
+
+    def walk(self, handle, position, direction):
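+        # Take one step along the rail; returns (new_position, new_direction, remaining distance
+        # to the target, RailEnvActions value for that step, possible_transitions).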
+        possible_transitions = self.env.rail.get_transitions(*position, direction)
+        num_transitions = fast_count_nonzero(possible_transitions)
+        if num_transitions == 1:
+            new_direction = fast_argmax(possible_transitions)
+            new_position = get_new_position(position, new_direction)
+
+            dist = self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction]
+            return new_position, new_direction, dist, RailEnvActions.MOVE_FORWARD, possible_transitions
+        else:
+            min_distances = []
+            positions = []
+            directions = []
+            for new_direction in [(direction + i) % 4 for i in range(-1, 2)]:
+                if possible_transitions[new_direction]:
+                    new_position = get_new_position(position, new_direction)
+                    min_distances.append(
+                        self.env.distance_map.get()[handle, new_position[0], new_position[1], new_direction])
+                    positions.append(new_position)
+                    directions.append(new_direction)
+                else:
+                    min_distances.append(np.inf)
+                    positions.append(None)
+                    directions.append(None)
+
+        a = self.get_action(handle, min_distances)
+        return positions[a], directions[a], min_distances[a], a + 1, possible_transitions
+
+    def get_action(self, handle, min_distances):
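+        # Default branch choice: follow the direction with the smallest remaining distance.
+        # Subclasses can override this to implement a different routing strategy.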
+        return np.argmin(min_distances)
+
+    def callback(self, handle, agent, position, direction, action, possible_transitions):
+        pass
+
+    def get_agent_position_and_direction(self, handle):
+        agent = self.env.agents[handle]
+        if agent.position is not None:
+            position = agent.position
+        else:
+            position = agent.initial_position
+        direction = agent.direction
+        return position, direction
+
+    def walk_to_target(self, handle, position=None, direction=None, max_step=500):
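+        # Follow the shortest path until the target is reached or max_step steps were taken,
+        # invoking self.callback(...) once per visited cell.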
+        if position is None and direction is None:
+            position, direction = self.get_agent_position_and_direction(handle)
+        elif position is None:
+            position, _ = self.get_agent_position_and_direction(handle)
+        elif direction is None:
+            _, direction = self.get_agent_position_and_direction(handle)
+
+        agent = self.env.agents[handle]
+        step = 0
+        while (position != agent.target) and (step < max_step):
+            position, direction, dist, action, possible_transitions = self.walk(handle, position, direction)
+            if position is None:
+                break
+            self.callback(handle, agent, position, direction, action, possible_transitions)
+            step += 1
+
+    def callback_one_step(self, handle, agent, position, direction, action, possible_transitions):
+        pass
+
+    def walk_one_step(self, handle):
+        agent = self.env.agents[handle]
+        if agent.position is not None:
+            position = agent.position
+        else:
+            position = agent.initial_position
+        direction = agent.direction
+        possible_transitions = (0, 1, 0, 0)
+        # Default: stand still at the current cell (also covers an agent that already sits on its target).
+        new_position, new_direction, action = position, direction, RailEnvActions.STOP_MOVING
+        if position != agent.target:
+            new_position, new_direction, dist, action, possible_transitions = self.walk(handle, position, direction)
+            if new_position is None:
+                return position, direction, RailEnvActions.STOP_MOVING, possible_transitions
+            self.callback_one_step(handle, agent, new_position, new_direction, action, possible_transitions)
+        return new_position, new_direction, action, possible_transitions