diff --git a/.gitignore b/.gitignore
index da214e57a3cb40bddeb8e0e2d0b518b3de06f20e..981e6a55487ba3d561923a03601cdec88b214c3b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -125,3 +125,8 @@ dmypy.json
 
 scratch/test-envs/
 scratch/
+
+# Checkpoints and replay buffers
+!checkpoints/.gitkeep
+replay_buffers/*
+!replay_buffers/.gitkeep
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..58d76d29beaa073412d53bbeab08ebcddc7151b9
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Flatland
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index c4e93dc4ae84e6debde8a2483c7b3f94edbc0d5e..84c892d72ae9ece24eaf6e1975a24096a94b606a 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,93 @@
-![AIcrowd-Logo](https://raw.githubusercontent.com/AIcrowd/AIcrowd/master/app/assets/images/misc/aicrowd-horizontal.png)
+🚂 Starter Kit - NeurIPS 2020 Flatland Challenge
+===
 
-# Flatland Challenge Starter Kit
+This starter kit contains two example policies to get you started with this challenge:
+- a simple single-agent DQN method
+- a more robust multi-agent DQN method that you can submit out of the box to the challenge 🚀
 
-**[Follow these instructions to submit your solutions!](http://flatland.aicrowd.com/getting-started/first-submission.html)**
+**🔗 [Train the single-agent DQN policy](https://flatland.aicrowd.com/getting-started/rl/single-agent.html)**
 
+**🔗 [Train the multi-agent DQN policy](https://flatland.aicrowd.com/getting-started/rl/multi-agent.html)**
 
-![flatland](https://i.imgur.com/0rnbSLY.gif)
+**🔗 [Submit a trained policy](https://flatland.aicrowd.com/getting-started/first-submission.html)**
+
+The single-agent example is meant as a minimal demonstration of how to use DQN. The multi-agent example is a better starting point for creating your own solution.
+
+You can fully train the multi-agent policy in Colab for free! [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1GbPwZNQU7KJIJtilcGBTtpOAD3EabAzJ?usp=sharing)
+
+Sample training usage
+---
+
+Train the multi-agent policy for 150 episodes:
+
+```bash
+python reinforcement_learning/multi_agent_training.py -n 150
+```
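+
+Checkpoints are written to the `checkpoints/` directory at every checkpoint interval (100 episodes by default). If `--save_replay_buffer` is enabled, replay buffers are written to `replay_buffers/`.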
+
+The multi-agent policy training can be tuned using command-line arguments:
+
+```console 
+usage: multi_agent_training.py [-h] [-n N_EPISODES] [-t TRAINING_ENV_CONFIG]
+                               [-e EVALUATION_ENV_CONFIG]
+                               [--n_evaluation_episodes N_EVALUATION_EPISODES]
+                               [--checkpoint_interval CHECKPOINT_INTERVAL]
+                               [--eps_start EPS_START] [--eps_end EPS_END]
+                               [--eps_decay EPS_DECAY]
+                               [--buffer_size BUFFER_SIZE]
+                               [--buffer_min_size BUFFER_MIN_SIZE]
+                               [--restore_replay_buffer RESTORE_REPLAY_BUFFER]
+                               [--save_replay_buffer SAVE_REPLAY_BUFFER]
+                               [--batch_size BATCH_SIZE] [--gamma GAMMA]
+                               [--tau TAU] [--learning_rate LEARNING_RATE]
+                               [--hidden_size HIDDEN_SIZE]
+                               [--update_every UPDATE_EVERY]
+                               [--use_gpu USE_GPU] [--num_threads NUM_THREADS]
+                               [--render RENDER]
+
+optional arguments:
+  -h, --help            show this help message and exit
+  -n N_EPISODES, --n_episodes N_EPISODES
+                        number of episodes to run
+  -t TRAINING_ENV_CONFIG, --training_env_config TRAINING_ENV_CONFIG
+                        training config id (eg 0 for Test_0)
+  -e EVALUATION_ENV_CONFIG, --evaluation_env_config EVALUATION_ENV_CONFIG
+                        evaluation config id (eg 0 for Test_0)
+  --n_evaluation_episodes N_EVALUATION_EPISODES
+                        number of evaluation episodes
+  --checkpoint_interval CHECKPOINT_INTERVAL
+                        checkpoint interval
+  --eps_start EPS_START
+                        max exploration
+  --eps_end EPS_END     min exploration
+  --eps_decay EPS_DECAY
+                        exploration decay
+  --buffer_size BUFFER_SIZE
+                        replay buffer size
+  --buffer_min_size BUFFER_MIN_SIZE
+                        min buffer size to start training
+  --restore_replay_buffer RESTORE_REPLAY_BUFFER
+                        replay buffer to restore
+  --save_replay_buffer SAVE_REPLAY_BUFFER
+                        save replay buffer at each evaluation interval
+  --batch_size BATCH_SIZE
+                        minibatch size
+  --gamma GAMMA         discount factor
+  --tau TAU             soft update of target parameters
+  --learning_rate LEARNING_RATE
+                        learning rate
+  --hidden_size HIDDEN_SIZE
+                        hidden size (2 fc layers)
+  --update_every UPDATE_EVERY
+                        how often to update the network
+  --use_gpu USE_GPU     use GPU if available
+  --num_threads NUM_THREADS
+                        number of threads PyTorch can use
+  --render RENDER       render 1 episode in 100
+```
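+
+For example, to train a larger network with a smaller replay buffer and more frequent checkpoints (illustrative values, not tuned recommendations):
+
+```bash
+python reinforcement_learning/multi_agent_training.py -n 150 --hidden_size 256 --buffer_size 50000 --checkpoint_interval 50
+```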
+
+[**📈 Performance with various hyper-parameters**](https://app.wandb.ai/masterscrat/flatland-examples-reinforcement_learning/reports/Flatland-Examples--VmlldzoxNDI2MTA)
+
+[![](https://i.imgur.com/Lqrq5GE.png)](https://app.wandb.ai/masterscrat/flatland-examples-reinforcement_learning/reports/Flatland-Examples--VmlldzoxNDI2MTA) 
 
 Main links
 ---
@@ -18,9 +100,4 @@ Communication
 
 * [Discord Channel](https://discord.com/invite/hCR3CZG)
 * [Discussion Forum](https://discourse.aicrowd.com/c/neurips-2020-flatland-challenge)
-* [Issue Tracker](https://gitlab.aicrowd.com/flatland/flatland/issues/)
-
-Author
----
-
-- **[Sharada Mohanty](https://twitter.com/MeMohanty)**
+* [Issue Tracker](https://gitlab.aicrowd.com/flatland/flatland/issues/)
\ No newline at end of file
diff --git a/apt.txt b/apt.txt
index 5bf24df89ee3441d844eee250d29256307809d8b..834e45c4c9ed5f5432186e3e3751f6dd7e4dc6e4 100644
--- a/apt.txt
+++ b/apt.txt
@@ -2,4 +2,4 @@ curl
 git
 vim
 ssh
-gcc
+gcc
\ No newline at end of file
diff --git a/checkpoints/.gitkeep b/checkpoints/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/checkpoints/sample-checkpoint.pth b/checkpoints/sample-checkpoint.pth
new file mode 100644
index 0000000000000000000000000000000000000000..3fd7a50d88963dc4aa657825757fbfbfa51d508a
Binary files /dev/null and b/checkpoints/sample-checkpoint.pth differ
diff --git a/environment.yml b/environment.yml
index 626815eb2e8b8bf83a989a60f189f902a340bed3..471be17f8222df708f74d8e1d1a4bf78cc7b5937 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,111 +1,14 @@
 name: flatland-rl
 channels:
-  - anaconda
+  - pytorch
   - conda-forge
   - defaults
 dependencies:
-  - tk=8.6.8
-  - cairo=1.16.0
-  - cairocffi=1.1.0
-  - cairosvg=2.4.2
-  - cffi=1.12.3
-  - cssselect2=0.2.1
-  - defusedxml=0.6.0
-  - fontconfig=2.13.1
-  - freetype=2.10.0
-  - gettext=0.19.8.1
-  - glib=2.58.3
-  - icu=64.2
-  - jpeg=9c
-  - libiconv=1.15
-  - libpng=1.6.37
-  - libtiff=4.0.10
-  - libuuid=2.32.1
-  - libxcb=1.13
-  - libxml2=2.9.9
-  - lz4-c=1.8.3
-  - olefile=0.46
-  - pcre=8.41
-  - pillow=5.3.0
-  - pixman=0.38.0
-  - pthread-stubs=0.4
-  - pycairo=1.18.1
-  - pycparser=2.19
-  - tinycss2=1.0.2
-  - webencodings=0.5.1
-  - xorg-kbproto=1.0.7
-  - xorg-libice=1.0.10
-  - xorg-libsm=1.2.3
-  - xorg-libx11=1.6.8
-  - xorg-libxau=1.0.9
-  - xorg-libxdmcp=1.1.3
-  - xorg-libxext=1.3.4
-  - xorg-libxrender=0.9.10
-  - xorg-renderproto=0.11.1
-  - xorg-xextproto=7.3.0
-  - xorg-xproto=7.0.31
-  - zstd=1.4.0
-  - _libgcc_mutex=0.1
-  - ca-certificates=2019.5.15
-  - certifi=2019.6.16
-  - libedit=3.1.20181209
-  - libffi=3.2.1
-  - ncurses=6.1
-  - openssl=1.1.1c
-  - pip=19.1.1
-  - python=3.6.8
-  - readline=7.0
-  - setuptools=41.0.1
-  - sqlite=3.29.0
-  - wheel=0.33.4
-  - xz=5.2.4
-  - zlib=1.2.11
+  - psutil==5.7.2
+  - pytorch==1.6.0
+  - pip==20.2.3
+  - python==3.6.8
   - pip:
-    - atomicwrites==1.3.0
-    - importlib-metadata==0.19
-    - importlib-resources==1.0.2
-    - attrs==19.1.0
-    - chardet==3.0.4
-    - click==7.0
-    - cloudpickle==1.2.2
-    - crowdai-api==0.1.21
-    - cycler==0.10.0
-    - filelock==3.0.12
-    - flatland-rl==2.2.1
-    - future==0.17.1
-    - gym==0.14.0
-    - idna==2.8
-    - kiwisolver==1.1.0
-    - lxml==4.4.0
-    - matplotlib==3.1.1
-    - more-itertools==7.2.0
-    - msgpack==0.6.1
-    - msgpack-numpy==0.4.4.3
-    - numpy==1.17.0
-    - packaging==19.0
-    - pandas==0.25.0
-    - pluggy==0.12.0
-    - py==1.8.0
-    - pyarrow==0.14.1
-    - pyglet==1.3.2
-    - pyparsing==2.4.1.1
-    - pytest==5.0.1
-    - pytest-runner==5.1
-    - python-dateutil==2.8.0
-    - python-gitlab==1.10.0
-    - pytz==2019.1
-    - recordtype==1.3
-    - redis==3.3.2
-    - requests==2.22.0
-    - scipy==1.3.1
-    - six==1.12.0
-    - svgutils==0.3.1
-    - timeout-decorator==0.4.1
-    - toml==0.10.0
-    - tox==3.13.2
-    - urllib3==1.25.3
-    - ushlex==0.99.1
-    - virtualenv==16.7.2
-    - wcwidth==0.1.7
-    - xarray==0.12.3
-    - zipp==0.5.2
+      - flatland-rl==2.2.2
+      - tensorboard==2.3.0
+      - tensorboardx==2.1
\ No newline at end of file
diff --git a/my_observation_builder.py b/my_observation_builder.py
deleted file mode 100644
index 915ff839cb6cad922fe6ca7513465a4a9edde705..0000000000000000000000000000000000000000
--- a/my_observation_builder.py
+++ /dev/null
@@ -1,101 +0,0 @@
-#!/usr/bin/env python 
-
-import collections
-from typing import Optional, List, Dict, Tuple
-
-import numpy as np
-
-from flatland.core.env import Environment
-from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.env_prediction_builder import PredictionBuilder
-from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
-
-
-class CustomObservationBuilder(ObservationBuilder):
-    """
-    Template for building a custom observation builder for the RailEnv class
-
-    The observation in this case composed of the following elements:
-
-        - transition map array with dimensions (env.height, env.width),\
-          where the value at X,Y will represent the 16 bits encoding of transition-map at that point.
-        
-        - the individual agent object (with position, direction, target information available)
-
-    """
-    def __init__(self):
-        super(CustomObservationBuilder, self).__init__()
-
-    def set_env(self, env: Environment):
-        super().set_env(env)
-        # Note :
-        # The instantiations which depend on parameters of the Env object should be 
-        # done here, as it is only here that the updated self.env instance is available
-        self.rail_obs = np.zeros((self.env.height, self.env.width))
-
-    def reset(self):
-        """
-        Called internally on every env.reset() call, 
-        to reset any observation specific variables that are being used
-        """
-        self.rail_obs[:] = 0        
-        for _x in range(self.env.width):
-            for _y in range(self.env.height):
-                # Get the transition map value at location _x, _y
-                transition_value = self.env.rail.get_full_transitions(_y, _x)
-                self.rail_obs[_y, _x] = transition_value
-
-    def get(self, handle: int = 0):
-        """
-        Returns the built observation for a single agent with handle : handle
-
-        In this particular case, we return 
-        - the global transition_map of the RailEnv,
-        - a tuple containing, the current agent's:
-            - state
-            - position
-            - direction
-            - initial_position
-            - target
-        """
-
-        agent = self.env.agents[handle]
-        """
-        Available information for each agent object : 
-
-        - agent.status : [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
-        - agent.position : Current position of the agent
-        - agent.direction : Current direction of the agent
-        - agent.initial_position : Initial Position of the agent
-        - agent.target : Target position of the agent
-        """
-
-        status = agent.status
-        position = agent.position
-        direction = agent.direction
-        initial_position = agent.initial_position
-        target = agent.target
-
-        
-        """
-        You can also optionally access the states of the rest of the agents by 
-        using something similar to 
-
-        for i in range(len(self.env.agents)):
-            other_agent: EnvAgent = self.env.agents[i]
-
-            # ignore other agents not in the grid any more
-            if other_agent.status == RailAgentStatus.DONE_REMOVED:
-                continue
-
-            ## Gather other agent specific params 
-            other_agent_status = other_agent.status
-            other_agent_position = other_agent.position
-            other_agent_direction = other_agent.direction
-            other_agent_initial_position = other_agent.initial_position
-            other_agent_target = other_agent.target
-
-            ## Do something nice here if you wish
-        """
-        return self.rail_obs, (status, position, direction, initial_position, target)
-
diff --git a/reinforcement_learning/__init__.py b/reinforcement_learning/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ff47b920f770a61369b46d6bbf0690af1cbc81d
--- /dev/null
+++ b/reinforcement_learning/dddqn_policy.py
@@ -0,0 +1,190 @@
+import copy
+import os
+import pickle
+import random
+from collections import namedtuple, deque
+from collections.abc import Iterable
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+
+from reinforcement_learning.model import DuelingQNetwork
+from reinforcement_learning.policy import Policy
+
+
+class DDDQNPolicy(Policy):
+    """Dueling Double DQN policy"""
+
+    def __init__(self, state_size, action_size, parameters, evaluation_mode=False):
+        self.evaluation_mode = evaluation_mode
+
+        self.state_size = state_size
+        self.action_size = action_size
+        self.double_dqn = True
+        self.hidsize = 1
+
+        if not evaluation_mode:
+            self.hidsize = parameters.hidden_size
+            self.buffer_size = parameters.buffer_size
+            self.batch_size = parameters.batch_size
+            self.update_every = parameters.update_every
+            self.learning_rate = parameters.learning_rate
+            self.tau = parameters.tau
+            self.gamma = parameters.gamma
+            self.buffer_min_size = parameters.buffer_min_size
+
+        # Device
+        if parameters.use_gpu and torch.cuda.is_available():
+            self.device = torch.device("cuda:0")
+            # print("🐇 Using GPU")
+        else:
+            self.device = torch.device("cpu")
+            # print("🐢 Using CPU")
+
+        # Q-Network
+        self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to(self.device)
+
+        if not evaluation_mode:
+            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+            self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.learning_rate)
+            self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, self.device)
+
+            self.t_step = 0
+            self.loss = 0.0
+
+    def act(self, state, eps=0.):
+        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
+        self.qnetwork_local.eval()
+        with torch.no_grad():
+            action_values = self.qnetwork_local(state)
+        self.qnetwork_local.train()
+
+        # Epsilon-greedy action selection
+        if random.random() > eps:
+            return np.argmax(action_values.cpu().data.numpy())
+        else:
+            return random.choice(np.arange(self.action_size))
+
+    def step(self, state, action, reward, next_state, done):
+        assert not self.evaluation_mode, "Policy has been initialized for evaluation only."
+
+        # Save experience in replay memory
+        self.memory.add(state, action, reward, next_state, done)
+
+        # Learn every UPDATE_EVERY time steps.
+        self.t_step = (self.t_step + 1) % self.update_every
+        if self.t_step == 0:
+            # If enough samples are available in memory, get random subset and learn
+            if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
+                self._learn()
+
+    def _learn(self):
+        experiences = self.memory.sample()
+        states, actions, rewards, next_states, dones = experiences
+
+        # Get expected Q values from local model
+        q_expected = self.qnetwork_local(states).gather(1, actions)
+
+        if self.double_dqn:
+            # Double DQN
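+            # The local network selects the greedy next action and the target network
+            # evaluates it, which reduces the Q-value overestimation of vanilla DQN.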
+            q_best_action = self.qnetwork_local(next_states).max(1)[1]
+            q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1))
+        else:
+            # DQN
+            q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1)
+
+        # Compute Q targets for current states
+        q_targets = rewards + (self.gamma * q_targets_next * (1 - dones))
+
+        # Compute loss
+        self.loss = F.mse_loss(q_expected, q_targets)
+
+        # Minimize the loss
+        self.optimizer.zero_grad()
+        self.loss.backward()
+        self.optimizer.step()
+
+        # Update target network
+        self._soft_update(self.qnetwork_local, self.qnetwork_target, self.tau)
+
+    def _soft_update(self, local_model, target_model, tau):
+        # Soft update model parameters.
+        # θ_target = τ*θ_local + (1 - τ)*θ_target
+        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
+            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
+
+    def save(self, filename):
+        torch.save(self.qnetwork_local.state_dict(), filename + ".local")
+        torch.save(self.qnetwork_target.state_dict(), filename + ".target")
+
+    def load(self, filename):
+        if os.path.exists(filename + ".local"):
+            self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
+        if os.path.exists(filename + ".target"):
+            self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+
+    def save_replay_buffer(self, filename):
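+        # Only the most recent 500k experiences are persisted, to bound the pickle size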
+        memory = self.memory.memory
+        with open(filename, 'wb') as f:
+            pickle.dump(list(memory)[-500000:], f)
+
+    def load_replay_buffer(self, filename):
+        with open(filename, 'rb') as f:
+            self.memory.memory = pickle.load(f)
+
+    def test(self):
+        self.act(np.array([[0] * self.state_size]))
+        self._learn()
+
+
+Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
+
+
+class ReplayBuffer:
+    """Fixed-size buffer to store experience tuples."""
+
+    def __init__(self, action_size, buffer_size, batch_size, device):
+        """Initialize a ReplayBuffer object.
+
+        Params
+        ======
+            action_size (int): dimension of each action
+            buffer_size (int): maximum size of buffer
+            batch_size (int): size of each training batch
+        """
+        self.action_size = action_size
+        self.memory = deque(maxlen=buffer_size)
+        self.batch_size = batch_size
+        self.device = device
+
+    def add(self, state, action, reward, next_state, done):
+        """Add a new experience to memory."""
+        e = Experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)
+        self.memory.append(e)
+
+    def sample(self):
+        """Randomly sample a batch of experiences from memory."""
+        experiences = random.sample(self.memory, k=self.batch_size)
+
+        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
+            .float().to(self.device)
+        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
+            .long().to(self.device)
+        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
+            .float().to(self.device)
+        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
+            .float().to(self.device)
+        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
+            .float().to(self.device)
+
+        return states, actions, rewards, next_states, dones
+
+    def __len__(self):
+        """Return the current size of internal memory."""
+        return len(self.memory)
+
+    def __v_stack_impr(self, states):
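+        """Stack a list of per-experience fields into a 2-D array of shape (batch_size, sub_dim)."""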
+        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
+        np_states = np.reshape(np.array(states), (len(states), sub_dim))
+        return np_states
diff --git a/reinforcement_learning/evaluate_agent.py b/reinforcement_learning/evaluate_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..80d42987c40bc1c590152b213ab8abaf6f9a91a6
--- /dev/null
+++ b/reinforcement_learning/evaluate_agent.py
@@ -0,0 +1,376 @@
+import math
+import multiprocessing
+import os
+import sys
+from argparse import ArgumentParser, Namespace
+from multiprocessing import Pool
+from pathlib import Path
+from pprint import pprint
+
+import numpy as np
+import torch
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.utils.rendertools import RenderTool
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from utils.deadlock_check import check_if_all_blocked
+from utils.timer import Timer
+from utils.observation_utils import normalize_observation
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+
+
+def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render, allow_skipping, allow_caching):
+    # Evaluation is faster on CPU (unless you use a really large policy network)
+    parameters = {
+        'use_gpu': False
+    }
+
+    policy = DDDQNPolicy(state_size, action_size, Namespace(**parameters), evaluation_mode=True)
+    policy.qnetwork_local = torch.load(checkpoint)
+
+    env_params = Namespace(**env_params)
+
+    # Environment parameters
+    n_agents = env_params.n_agents
+    x_dim = env_params.x_dim
+    y_dim = env_params.y_dim
+    n_cities = env_params.n_cities
+    max_rails_between_cities = env_params.max_rails_between_cities
+    max_rails_in_city = env_params.max_rails_in_city
+
+    # Malfunction and speed profiles
+    # TODO pass these parameters properly from main!
+    malfunction_parameters = MalfunctionParameters(
+        malfunction_rate=1. / 2000,  # Rate of malfunctions
+        min_duration=20,  # Minimal duration
+        max_duration=50  # Max duration
+    )
+
+    # Only fast trains in Round 1
+    speed_profiles = {
+        1.: 1.0,  # Fast passenger train
+        1. / 2.: 0.0,  # Fast freight train
+        1. / 3.: 0.0,  # Slow commuter train
+        1. / 4.: 0.0  # Slow freight train
+    }
+
+    # Observation parameters
+    observation_tree_depth = env_params.observation_tree_depth
+    observation_radius = env_params.observation_radius
+    observation_max_path_depth = env_params.observation_max_path_depth
+
+    # Observation builder
+    predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
+    tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+
+    # Setup the environment
+    env = RailEnv(
+        width=x_dim, height=y_dim,
+        rail_generator=sparse_rail_generator(
+            max_num_cities=n_cities,
+            grid_mode=False,
+            max_rails_between_cities=max_rails_between_cities,
+            max_rails_in_city=max_rails_in_city,
+        ),
+        schedule_generator=sparse_schedule_generator(speed_profiles),
+        number_of_agents=n_agents,
+        malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters),
+        obs_builder_object=tree_observation
+    )
+
+    if render:
+        env_renderer = RenderTool(env, gl="PGL")
+
+    action_dict = dict()
+    scores = []
+    completions = []
+    nb_steps = []
+    inference_times = []
+    preproc_times = []
+    agent_times = []
+    step_times = []
+
+    for episode_idx in range(n_eval_episodes):
+        seed += 1
+
+        inference_timer = Timer()
+        preproc_timer = Timer()
+        agent_timer = Timer()
+        step_timer = Timer()
+
+        step_timer.start()
+        obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True, random_seed=seed)
+        step_timer.end()
+
+        agent_obs = [None] * env.get_num_agents()
+        score = 0.0
+
+        if render:
+            env_renderer.set_new_rail()
+
+        final_step = 0
+        skipped = 0
+
+        nb_hit = 0
+        agent_last_obs = {}
+        agent_last_action = {}
+
+        for step in range(max_steps - 1):
+            if allow_skipping and check_if_all_blocked(env):
+                # FIXME why -1? bug where all agents are "done" after max_steps!
+                skipped = max_steps - step - 1
+                final_step = max_steps - 2
+                n_unfinished_agents = sum(not done[idx] for idx in env.get_agent_handles())
+                score -= skipped * n_unfinished_agents
+                break
+
+            agent_timer.start()
+            for agent in env.get_agent_handles():
+                if obs[agent] and info['action_required'][agent]:
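+                    # When caching is enabled and the observation is unchanged since the
+                    # last step, reuse the cached action instead of re-running the policy.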
+                    if agent in agent_last_obs and np.all(agent_last_obs[agent] == obs[agent]):
+                        nb_hit += 1
+                        action = agent_last_action[agent]
+
+                    else:
+                        preproc_timer.start()
+                        norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
+                        preproc_timer.end()
+
+                        inference_timer.start()
+                        action = policy.act(norm_obs, eps=0.0)
+                        inference_timer.end()
+
+                    action_dict.update({agent: action})
+
+                    if allow_caching:
+                        agent_last_obs[agent] = obs[agent]
+                        agent_last_action[agent] = action
+            agent_timer.end()
+
+            step_timer.start()
+            obs, all_rewards, done, info = env.step(action_dict)
+            step_timer.end()
+
+            if render:
+                env_renderer.render_env(
+                    show=True,
+                    frames=False,
+                    show_observations=False,
+                    show_predictions=False
+                )
+
+                if step % 100 == 0:
+                    print("{}/{}".format(step, max_steps - 1))
+
+            for agent in env.get_agent_handles():
+                score += all_rewards[agent]
+
+            final_step = step
+
+            if done['__all__']:
+                break
+
+        normalized_score = score / (max_steps * env.get_num_agents())
+        scores.append(normalized_score)
+
+        tasks_finished = sum(done[idx] for idx in env.get_agent_handles())
+        completion = tasks_finished / max(1, env.get_num_agents())
+        completions.append(completion)
+
+        nb_steps.append(final_step)
+
+        inference_times.append(inference_timer.get())
+        preproc_times.append(preproc_timer.get())
+        agent_times.append(agent_timer.get())
+        step_times.append(step_timer.get())
+
+        skipped_text = ""
+        if skipped > 0:
+            skipped_text = "\t⚡ Skipped {}".format(skipped)
+
+        hit_text = ""
+        if nb_hit > 0:
+            hit_text = "\t⚡ Hit {} ({:.1f}%)".format(nb_hit, (100 * nb_hit) / (n_agents * final_step))
+
+        print(
+            "☑️  Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} "
+            "\t🍭 Seed: {}"
+            "\t🚉 Env: {:.3f}s  "
+            "\t🤖 Agent: {:.3f}s (per step: {:.3f}s) \t[preproc: {:.3f}s \tinfer: {:.3f}s]"
+            "{}{}".format(
+                normalized_score,
+                completion * 100.0,
+                final_step,
+                seed,
+                step_timer.get(),
+                agent_timer.get(),
+                agent_timer.get() / final_step,
+                preproc_timer.get(),
+                inference_timer.get(),
+                skipped_text,
+                hit_text
+            )
+        )
+
+    return scores, completions, nb_steps, agent_times, step_times
+
+
+def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping, allow_caching):
+    nb_threads = 1
+    eval_per_thread = n_evaluation_episodes
+
+    if not render:
+        nb_threads = multiprocessing.cpu_count()
+        eval_per_thread = max(1, math.ceil(n_evaluation_episodes / nb_threads))
+
+    total_nb_eval = eval_per_thread * nb_threads
+    print("Will evaluate policy {} over {} episodes on {} threads.".format(file, total_nb_eval, nb_threads))
+
+    if total_nb_eval != n_evaluation_episodes:
+        print("(Rounding up from {} to fill all cores)".format(n_evaluation_episodes))
+
+    # Observation parameters need to match the ones used during training!
+
+    # small_v0
+    small_v0_params = {
+        # sample configuration
+        "n_agents": 5,
+        "x_dim": 25,
+        "y_dim": 25,
+        "n_cities": 4,
+        "max_rails_between_cities": 2,
+        "max_rails_in_city": 3,
+
+        # observations
+        "observation_tree_depth": 2,
+        "observation_radius": 10,
+        "observation_max_path_depth": 20
+    }
+
+    # Test_0
+    test0_params = {
+        # sample configuration
+        "n_agents": 5,
+        "x_dim": 25,
+        "y_dim": 25,
+        "n_cities": 2,
+        "max_rails_between_cities": 2,
+        "max_rails_in_city": 3,
+
+        # observations
+        "observation_tree_depth": 2,
+        "observation_radius": 10,
+        "observation_max_path_depth": 20
+    }
+
+    # Test_1
+    test1_params = {
+        # environment
+        "n_agents": 10,
+        "x_dim": 30,
+        "y_dim": 30,
+        "n_cities": 2,
+        "max_rails_between_cities": 2,
+        "max_rails_in_city": 3,
+
+        # observations
+        "observation_tree_depth": 2,
+        "observation_radius": 10,
+        "observation_max_path_depth": 10
+    }
+
+    # Test_5
+    test5_params = {
+        # environment
+        "n_agents": 80,
+        "x_dim": 35,
+        "y_dim": 35,
+        "n_cities": 5,
+        "max_rails_between_cities": 2,
+        "max_rails_in_city": 4,
+
+        # observations
+        "observation_tree_depth": 2,
+        "observation_radius": 10,
+        "observation_max_path_depth": 20
+    }
+
+    params = small_v0_params
+    env_params = Namespace(**params)
+
+    print("Environment parameters:")
+    pprint(params)
+
+    # Calculate space dimensions and max steps
+    max_steps = int(4 * 2 * (env_params.x_dim + env_params.y_dim + (env_params.n_agents / env_params.n_cities)))
+    action_size = 5
+    tree_observation = TreeObsForRailEnv(max_depth=env_params.observation_tree_depth)
+    tree_depth = env_params.observation_tree_depth
+    num_features_per_node = tree_observation.observation_dim
+    n_nodes = sum([np.power(4, i) for i in range(tree_depth + 1)])
+    state_size = num_features_per_node * n_nodes
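+    # With tree depth 2 (used by all configurations above), this is (1 + 4 + 16) = 21 nodes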
+
+    results = []
+    if render:
+        results.append(eval_policy(params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping, allow_caching))
+
+    else:
+        with Pool() as p:
+            results = p.starmap(eval_policy,
+                                [(params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, allow_skipping, allow_caching)
+                                 for seed in
+                                 range(total_nb_eval)])
+
+    scores = []
+    completions = []
+    nb_steps = []
+    times = []
+    step_times = []
+    for s, c, n, t, st in results:
+        scores.append(s)
+        completions.append(c)
+        nb_steps.append(n)
+        times.append(t)
+        step_times.append(st)
+
+    print("-" * 200)
+
+    print("✅ Score: {:.3f} \tDone: {:.1f}% \tNb steps: {:.3f} \tAgent total: {:.3f}s (per step: {:.3f}s)".format(
+        np.mean(scores),
+        np.mean(completions) * 100.0,
+        np.mean(nb_steps),
+        np.mean(times),
+        np.mean(times) / np.mean(nb_steps)
+    ))
+
+    print("⏲️  Agent sum: {:.3f}s \tEnv sum: {:.3f}s \tTotal sum: {:.3f}s".format(
+        np.sum(times),
+        np.sum(step_times),
+        np.sum(times) + np.sum(step_times)
+    ))
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("-f", "--file", help="checkpoint to load", required=True, type=str)
+    parser.add_argument("-n", "--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
+
+    # TODO
+    # parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, type=int)
+
+    parser.add_argument("--use_gpu", dest="use_gpu", help="use GPU if available", action='store_true')
+    parser.add_argument("--render", help="render a single episode", action='store_true')
+    parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked", action='store_true')
+    parser.add_argument("--allow_caching", help="caches the last observation-action pair", action='store_true')
+    args = parser.parse_args()
+
+    os.environ["OMP_NUM_THREADS"] = str(1)
+    evaluate_agents(file=args.file, n_evaluation_episodes=args.n_evaluation_episodes, use_gpu=args.use_gpu, render=args.render,
+                    allow_skipping=args.allow_skipping, allow_caching=args.allow_caching)
diff --git a/reinforcement_learning/model.py b/reinforcement_learning/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc6c8a98db876e7f3489b0db29377640ce41d176
--- /dev/null
+++ b/reinforcement_learning/model.py
@@ -0,0 +1,31 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DuelingQNetwork(nn.Module):
+    """Dueling Q-network (https://arxiv.org/abs/1511.06581)"""
+
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
+        super(DuelingQNetwork, self).__init__()
+
+        # value network
+        self.fc1_val = nn.Linear(state_size, hidsize1)
+        self.fc2_val = nn.Linear(hidsize1, hidsize2)
+        self.fc4_val = nn.Linear(hidsize2, 1)
+
+        # advantage network
+        self.fc1_adv = nn.Linear(state_size, hidsize1)
+        self.fc2_adv = nn.Linear(hidsize1, hidsize2)
+        self.fc4_adv = nn.Linear(hidsize2, action_size)
+
+    def forward(self, x):
+        val = F.relu(self.fc1_val(x))
+        val = F.relu(self.fc2_val(val))
+        val = self.fc4_val(val)
+
+        # advantage calculation
+        adv = F.relu(self.fc1_adv(x))
+        adv = F.relu(self.fc2_adv(adv))
+        adv = self.fc4_adv(adv)
+
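+        # Combine the streams: subtracting the mean advantage keeps the
+        # value/advantage decomposition identifiable (dueling architecture).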
+        return val + adv - adv.mean()
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..7118a3a47855a5f77814e294f9067fd6feb16ce1
--- /dev/null
+++ b/reinforcement_learning/multi_agent_training.py
@@ -0,0 +1,512 @@
+from datetime import datetime
+import os
+import random
+import sys
+from argparse import ArgumentParser, Namespace
+from pathlib import Path
+from pprint import pprint
+
+import psutil
+from flatland.utils.rendertools import RenderTool
+from torch.utils.tensorboard import SummaryWriter
+import numpy as np
+import torch
+
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.envs.observations import TreeObsForRailEnv
+
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from utils.timer import Timer
+from utils.observation_utils import normalize_observation
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+
+try:
+    import wandb
+
+    wandb.init(sync_tensorboard=True)
+except ImportError:
+    print("Install wandb to log to Weights & Biases")
+
+"""
+This file shows how to train multiple agents using a reinforcement learning approach.
+After training an agent, you can submit it straight away to the NeurIPS 2020 Flatland challenge!
+
+Agent documentation: https://flatland.aicrowd.com/getting-started/rl/multi-agent.html
+Submission documentation: https://flatland.aicrowd.com/getting-started/first-submission.html
+"""
+
+
+def create_rail_env(env_params, tree_observation):
+    n_agents = env_params.n_agents
+    x_dim = env_params.x_dim
+    y_dim = env_params.y_dim
+    n_cities = env_params.n_cities
+    max_rails_between_cities = env_params.max_rails_between_cities
+    max_rails_in_city = env_params.max_rails_in_city
+    seed = env_params.seed
+
+    # Break agents from time to time
+    malfunction_parameters = MalfunctionParameters(
+        malfunction_rate=env_params.malfunction_rate,
+        min_duration=20,
+        max_duration=50
+    )
+
+    return RailEnv(
+        width=x_dim, height=y_dim,
+        rail_generator=sparse_rail_generator(
+            max_num_cities=n_cities,
+            grid_mode=False,
+            max_rails_between_cities=max_rails_between_cities,
+            max_rails_in_city=max_rails_in_city
+        ),
+        schedule_generator=sparse_schedule_generator(),
+        number_of_agents=n_agents,
+        malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters),
+        obs_builder_object=tree_observation,
+        random_seed=seed
+    )
+
+
+def train_agent(train_params, train_env_params, eval_env_params, obs_params):
+    # Environment parameters
+    n_agents = train_env_params.n_agents
+    x_dim = train_env_params.x_dim
+    y_dim = train_env_params.y_dim
+    n_cities = train_env_params.n_cities
+    max_rails_between_cities = train_env_params.max_rails_between_cities
+    max_rails_in_city = train_env_params.max_rails_in_city
+    seed = train_env_params.seed
+
+    # Unique ID for this training
+    now = datetime.now()
+    training_id = now.strftime('%y%m%d%H%M%S')
+
+    # Observation parameters
+    observation_tree_depth = obs_params.observation_tree_depth
+    observation_radius = obs_params.observation_radius
+    observation_max_path_depth = obs_params.observation_max_path_depth
+
+    # Training parameters
+    eps_start = train_params.eps_start
+    eps_end = train_params.eps_end
+    eps_decay = train_params.eps_decay
+    n_episodes = train_params.n_episodes
+    checkpoint_interval = train_params.checkpoint_interval
+    n_eval_episodes = train_params.n_evaluation_episodes
+    restore_replay_buffer = train_params.restore_replay_buffer
+    save_replay_buffer = train_params.save_replay_buffer
+
+    # Set the seeds
+    random.seed(seed)
+    np.random.seed(seed)
+
+    # Observation builder
+    predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
+    tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+
+    # Setup the environments
+    train_env = create_rail_env(train_env_params, tree_observation)
+    train_env.reset(regenerate_schedule=True, regenerate_rail=True)
+    eval_env = create_rail_env(eval_env_params, tree_observation)
+    eval_env.reset(regenerate_schedule=True, regenerate_rail=True)
+
+    # Setup renderer
+    if train_params.render:
+        env_renderer = RenderTool(train_env, gl="PGL")
+
+    # Calculate the state size given the depth of the tree observation and the number of features
+    n_features_per_node = train_env.obs_builder.observation_dim
+    n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
+    state_size = n_features_per_node * n_nodes
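+    # With the default tree depth of 2, this is (1 + 4 + 16) = 21 nodes, each described by observation_dim features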
+
+    # The action space of flatland is 5 discrete actions
+    action_size = 5
+
+    # Max number of steps per episode
+    # This is the official formula used during evaluations
+    # See details in flatland.envs.schedule_generators.sparse_schedule_generator
+    # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
+    max_steps = train_env._max_episode_steps
+
+    action_count = [0] * action_size
+    action_dict = dict()
+    agent_obs = [None] * n_agents
+    agent_prev_obs = [None] * n_agents
+    agent_prev_action = [2] * n_agents
+    update_values = [False] * n_agents
+
+    # Smoothed values used as target for hyperparameter tuning
+    smoothed_normalized_score = -1.0
+    smoothed_eval_normalized_score = -1.0
+    smoothed_completion = 0.0
+    smoothed_eval_completion = 0.0
+
+    # Double Dueling DQN policy
+    policy = DDDQNPolicy(state_size, action_size, train_params)
+
+    # Loads existing replay buffer
+    if restore_replay_buffer:
+        try:
+            policy.load_replay_buffer(restore_replay_buffer)
+            policy.test()
+        except RuntimeError as e:
+            print("\n🛑 Couldn't load replay buffer. Were the experiences generated using the same tree depth?")
+            print(e)
+            exit(1)
+
+    print("\n💾 Replay buffer status: {}/{} experiences".format(len(policy.memory.memory), train_params.buffer_size))
+
+    hdd = psutil.disk_usage('/')
+    if save_replay_buffer and (hdd.free / (2 ** 30)) < 500.0:
+        print("⚠️  Careful! Saving replay buffers will quickly consume a lot of disk space. You have {:.2f} GB left.".format(hdd.free / (2 ** 30)))
+
+    # TensorBoard writer
+    writer = SummaryWriter()
+    writer.add_hparams(vars(train_params), {})
+    writer.add_hparams(vars(train_env_params), {})
+    writer.add_hparams(vars(obs_params), {})
+
+    training_timer = Timer()
+    training_timer.start()
+
+    print("\n🚉 Training {} trains on {}x{} grid for {} episodes, evaluating on {} episodes every {} episodes. Training id '{}'.\n".format(
+        train_env.get_num_agents(),
+        x_dim, y_dim,
+        n_episodes,
+        n_eval_episodes,
+        checkpoint_interval,
+        training_id
+    ))
+
+    for episode_idx in range(n_episodes + 1):
+        step_timer = Timer()
+        reset_timer = Timer()
+        learn_timer = Timer()
+        preproc_timer = Timer()
+        inference_timer = Timer()
+
+        # Reset environment
+        reset_timer.start()
+        obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
+        reset_timer.end()
+
+        if train_params.render:
+            env_renderer.set_new_rail()
+
+        score = 0
+        nb_steps = 0
+        actions_taken = []
+
+        # Build initial agent-specific observations
+        for agent in train_env.get_agent_handles():
+            if obs[agent]:
+                agent_obs[agent] = normalize_observation(obs[agent], observation_tree_depth, observation_radius=observation_radius)
+                agent_prev_obs[agent] = agent_obs[agent].copy()
+
+        # Run episode
+        for step in range(max_steps - 1):
+            inference_timer.start()
+            for agent in train_env.get_agent_handles():
+                if info['action_required'][agent]:
+                    update_values[agent] = True
+                    action = policy.act(agent_obs[agent], eps=eps_start)
+
+                    action_count[action] += 1
+                    actions_taken.append(action)
+                else:
+                    # An action is not required if the train hasn't joined the railway network,
+                    # if it has already reached its target, or if it is currently malfunctioning.
+                    update_values[agent] = False
+                    action = 0
+                action_dict.update({agent: action})
+            inference_timer.end()
+
+            # Environment step
+            step_timer.start()
+            next_obs, all_rewards, done, info = train_env.step(action_dict)
+            step_timer.end()
+
+            # Render an episode at some interval
+            if train_params.render and episode_idx % checkpoint_interval == 0:
+                env_renderer.render_env(
+                    show=True,
+                    frames=False,
+                    show_observations=False,
+                    show_predictions=False
+                )
+
+            # Update replay buffer and train agent
+            for agent in train_env.get_agent_handles():
+                if update_values[agent] or done['__all__']:
+                    # Only learn from timesteps where something happened
+                    learn_timer.start()
+                    policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], agent_obs[agent], done[agent])
+                    learn_timer.end()
+
+                    agent_prev_obs[agent] = agent_obs[agent].copy()
+                    agent_prev_action[agent] = action_dict[agent]
+
+                # Preprocess the new observations
+                if next_obs[agent]:
+                    preproc_timer.start()
+                    agent_obs[agent] = normalize_observation(next_obs[agent], observation_tree_depth, observation_radius=observation_radius)
+                    preproc_timer.end()
+
+                score += all_rewards[agent]
+
+            nb_steps = step
+
+            if done['__all__']:
+                break
+
+        # Epsilon decay
+        eps_start = max(eps_end, eps_decay * eps_start)
+
+        # Collect information about training
+        tasks_finished = sum(done[idx] for idx in train_env.get_agent_handles())
+        completion = tasks_finished / max(1, train_env.get_num_agents())
+        normalized_score = score / (max_steps * train_env.get_num_agents())
+        action_probs = action_count / np.sum(action_count)
+        action_count = [1] * action_size
+
+        smoothing = 0.99
+        smoothed_normalized_score = smoothed_normalized_score * smoothing + normalized_score * (1.0 - smoothing)
+        smoothed_completion = smoothed_completion * smoothing + completion * (1.0 - smoothing)
+
+        # Print logs
+        if episode_idx % checkpoint_interval == 0:
+            torch.save(policy.qnetwork_local, './checkpoints/' + training_id + '-' + str(episode_idx) + '.pth')
+
+            if save_replay_buffer:
+                policy.save_replay_buffer('./replay_buffers/' + training_id + '-' + str(episode_idx) + '.pkl')
+
+            if train_params.render:
+                env_renderer.close_window()
+
+        print(
+            '\r🚂 Episode {}'
+            '\t 🏆 Score: {:.3f}'
+            ' Avg: {:.3f}'
+            '\t 💯 Done: {:.2f}%'
+            ' Avg: {:.2f}%'
+            '\t 🎲 Epsilon: {:.3f} '
+            '\t 🔀 Action Probs: {}'.format(
+                episode_idx,
+                normalized_score,
+                smoothed_normalized_score,
+                100 * completion,
+                100 * smoothed_completion,
+                eps_start,
+                format_action_prob(action_probs)
+            ), end=" ")
+
+        # Evaluate policy and log results at some interval
+        if episode_idx % checkpoint_interval == 0 and n_eval_episodes > 0:
+            scores, completions, nb_steps_eval = eval_policy(eval_env, policy, train_params, obs_params)
+
+            writer.add_scalar("evaluation/scores_min", np.min(scores), episode_idx)
+            writer.add_scalar("evaluation/scores_max", np.max(scores), episode_idx)
+            writer.add_scalar("evaluation/scores_mean", np.mean(scores), episode_idx)
+            writer.add_scalar("evaluation/scores_std", np.std(scores), episode_idx)
+            writer.add_histogram("evaluation/scores", np.array(scores), episode_idx)
+            writer.add_scalar("evaluation/completions_min", np.min(completions), episode_idx)
+            writer.add_scalar("evaluation/completions_max", np.max(completions), episode_idx)
+            writer.add_scalar("evaluation/completions_mean", np.mean(completions), episode_idx)
+            writer.add_scalar("evaluation/completions_std", np.std(completions), episode_idx)
+            writer.add_histogram("evaluation/completions", np.array(completions), episode_idx)
+            writer.add_scalar("evaluation/nb_steps_min", np.min(nb_steps_eval), episode_idx)
+            writer.add_scalar("evaluation/nb_steps_max", np.max(nb_steps_eval), episode_idx)
+            writer.add_scalar("evaluation/nb_steps_mean", np.mean(nb_steps_eval), episode_idx)
+            writer.add_scalar("evaluation/nb_steps_std", np.std(nb_steps_eval), episode_idx)
+            writer.add_histogram("evaluation/nb_steps", np.array(nb_steps_eval), episode_idx)
+
+            smoothing = 0.9
+            smoothed_eval_normalized_score = smoothed_eval_normalized_score * smoothing + np.mean(scores) * (1.0 - smoothing)
+            smoothed_eval_completion = smoothed_eval_completion * smoothing + np.mean(completions) * (1.0 - smoothing)
+            writer.add_scalar("evaluation/smoothed_score", smoothed_eval_normalized_score, episode_idx)
+            writer.add_scalar("evaluation/smoothed_completion", smoothed_eval_completion, episode_idx)
+
+        # Save logs to tensorboard
+        writer.add_scalar("training/score", normalized_score, episode_idx)
+        writer.add_scalar("training/smoothed_score", smoothed_normalized_score, episode_idx)
+        writer.add_scalar("training/completion", np.mean(completion), episode_idx)
+        writer.add_scalar("training/smoothed_completion", np.mean(smoothed_completion), episode_idx)
+        writer.add_scalar("training/nb_steps", nb_steps, episode_idx)
+        writer.add_histogram("actions/distribution", np.array(actions_taken), episode_idx)
+        writer.add_scalar("actions/nothing", action_probs[RailEnvActions.DO_NOTHING], episode_idx)
+        writer.add_scalar("actions/left", action_probs[RailEnvActions.MOVE_LEFT], episode_idx)
+        writer.add_scalar("actions/forward", action_probs[RailEnvActions.MOVE_FORWARD], episode_idx)
+        writer.add_scalar("actions/right", action_probs[RailEnvActions.MOVE_RIGHT], episode_idx)
+        writer.add_scalar("actions/stop", action_probs[RailEnvActions.STOP_MOVING], episode_idx)
+        writer.add_scalar("training/epsilon", eps_start, episode_idx)
+        writer.add_scalar("training/buffer_size", len(policy.memory), episode_idx)
+        writer.add_scalar("training/loss", policy.loss, episode_idx)
+        writer.add_scalar("timer/reset", reset_timer.get(), episode_idx)
+        writer.add_scalar("timer/step", step_timer.get(), episode_idx)
+        writer.add_scalar("timer/learn", learn_timer.get(), episode_idx)
+        writer.add_scalar("timer/preproc", preproc_timer.get(), episode_idx)
+        writer.add_scalar("timer/total", training_timer.get_current(), episode_idx)
+
+
+def format_action_prob(action_probs):
+    action_probs = np.round(action_probs, 3)
+    actions = ["↻", "←", "↑", "→", "◼"]
+
+    buffer = ""
+    for action, action_prob in zip(actions, action_probs):
+        buffer += action + " " + "{:.3f}".format(action_prob) + " "
+
+    return buffer
+
+
+def eval_policy(env, policy, train_params, obs_params):
+    n_eval_episodes = train_params.n_evaluation_episodes
+    max_steps = env._max_episode_steps
+    tree_depth = obs_params.observation_tree_depth
+    observation_radius = obs_params.observation_radius
+
+    action_dict = dict()
+    scores = []
+    completions = []
+    nb_steps = []
+
+    for episode_idx in range(n_eval_episodes):
+        agent_obs = [None] * env.get_num_agents()
+        score = 0.0
+
+        obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
+
+        final_step = 0
+
+        for step in range(max_steps - 1):
+            for agent in env.get_agent_handles():
+                if obs[agent]:
+                    agent_obs[agent] = normalize_observation(obs[agent], tree_depth=tree_depth, observation_radius=observation_radius)
+
+                action = 0
+                if info['action_required'][agent]:
+                    action = policy.act(agent_obs[agent], eps=0.0)
+                action_dict.update({agent: action})
+
+            obs, all_rewards, done, info = env.step(action_dict)
+
+            for agent in env.get_agent_handles():
+                score += all_rewards[agent]
+
+            final_step = step
+
+            if done['__all__']:
+                break
+
+        normalized_score = score / (max_steps * env.get_num_agents())
+        scores.append(normalized_score)
+
+        tasks_finished = sum(done[idx] for idx in env.get_agent_handles())
+        completion = tasks_finished / max(1, env.get_num_agents())
+        completions.append(completion)
+
+        nb_steps.append(final_step)
+
+    print("\t✅ Eval: score {:.3f} done {:.1f}%".format(np.mean(scores), np.mean(completions) * 100.0))
+
+    return scores, completions, nb_steps
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=2500, type=int)
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=0, type=int)
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0, type=int)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=25, type=int)
+    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
+    parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
+    parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
+    parser.add_argument("--eps_decay", help="exploration decay", default=0.99, type=float)
+    parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e5), type=int)
+    parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
+    parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
+    parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False, type=bool)
+    parser.add_argument("--batch_size", help="minibatch size", default=128, type=int)
+    parser.add_argument("--gamma", help="discount factor", default=0.99, type=float)
+    parser.add_argument("--tau", help="soft update of target parameters", default=1e-3, type=float)
+    parser.add_argument("--learning_rate", help="learning rate", default=0.5e-4, type=float)
+    parser.add_argument("--hidden_size", help="hidden size (2 fc layers)", default=128, type=int)
+    parser.add_argument("--update_every", help="how often to update the network", default=8, type=int)
+    parser.add_argument("--use_gpu", help="use GPU if available", default=False, type=bool)
+    parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
+    parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
+    training_params = parser.parse_args()
+
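+    # Environment configurations of increasing difficulty, loosely modeled on the first levels of the challenge test set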
+    env_params = [
+        {
+            # Test_0
+            "n_agents": 5,
+            "x_dim": 25,
+            "y_dim": 25,
+            "n_cities": 2,
+            "max_rails_between_cities": 2,
+            "max_rails_in_city": 3,
+            "malfunction_rate": 1 / 50,
+            "seed": 0
+        },
+        {
+            # Test_1
+            "n_agents": 10,
+            "x_dim": 30,
+            "y_dim": 30,
+            "n_cities": 2,
+            "max_rails_between_cities": 2,
+            "max_rails_in_city": 3,
+            "malfunction_rate": 1 / 100,
+            "seed": 0
+        },
+        {
+            # Test_2
+            "n_agents": 20,
+            "x_dim": 30,
+            "y_dim": 30,
+            "n_cities": 3,
+            "max_rails_between_cities": 2,
+            "max_rails_in_city": 3,
+            "malfunction_rate": 1 / 200,
+            "seed": 0
+        },
+    ]
+
+    obs_params = {
+        "observation_tree_depth": 2,
+        "observation_radius": 10,
+        "observation_max_path_depth": 30
+    }
+
+    def check_env_config(config_id):
+        if config_id >= len(env_params) or config_id < 0:
+            print("\n🛑 Invalid environment configuration, only Test_0 to Test_{} are supported.".format(len(env_params) - 1))
+            exit(1)
+
+
+    check_env_config(training_params.training_env_config)
+    check_env_config(training_params.evaluation_env_config)
+
+    training_env_params = env_params[training_params.training_env_config]
+    evaluation_env_params = env_params[training_params.evaluation_env_config]
+
+    print("\nTraining parameters:")
+    pprint(vars(training_params))
+    print("\nTraining environment parameters (Test_{}):".format(training_params.training_env_config))
+    pprint(training_env_params)
+    print("\nEvaluation environment parameters (Test_{}):".format(training_params.evaluation_env_config))
+    pprint(evaluation_env_params)
+    print("\nObservation parameters:")
+    pprint(obs_params)
+
+    os.environ["OMP_NUM_THREADS"] = str(training_params.num_threads)
+    train_agent(training_params, Namespace(**training_env_params), Namespace(**evaluation_env_params), Namespace(**obs_params))
diff --git a/reinforcement_learning/ordered_policy.py b/reinforcement_learning/ordered_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..daf6639d33052eedc5b69481e84413edea552eee
--- /dev/null
+++ b/reinforcement_learning/ordered_policy.py
@@ -0,0 +1,34 @@
+import sys
+from pathlib import Path
+
+import numpy as np
+
+from reinforcement_learning.policy import Policy
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from utils.observation_utils import split_tree_into_feature_groups, min_gt
+
+
+class OrderedPolicy(Policy):
+    def __init__(self):
+        self.action_size = 5
+
+    def act(self, state, eps=0.):
+        _, distance, _ = split_tree_into_feature_groups(state, 1)
+        distance = distance[1:]
+        min_dist = min_gt(distance, 0)
+        min_direction = np.where(distance == min_dist)
+        if len(min_direction[0]) > 1:
+            return min_direction[0][-1] + 1
+        return min_direction[0][0] + 1
+
+    def step(self, state, action, reward, next_state, done):
+        return
+
+    def save(self, filename):
+        return
+
+    def load(self, filename):
+        return
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5c6a3c23dacbb182dceda52617a0be12d1acf7b
--- /dev/null
+++ b/reinforcement_learning/policy.py
@@ -0,0 +1,12 @@
+class Policy:
+    def step(self, state, action, reward, next_state, done):
+        raise NotImplementedError
+
+    def act(self, state, eps=0.):
+        raise NotImplementedError
+
+    def save(self, filename):
+        raise NotImplementedError
+
+    def load(self, filename):
+        raise NotImplementedError
diff --git a/reinforcement_learning/sequential_agent.py b/reinforcement_learning/sequential_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb5a73cdc42f33a5e771eeaf530cf4af9742be8
--- /dev/null
+++ b/reinforcement_learning/sequential_agent.py
@@ -0,0 +1,85 @@
+import sys
+import numpy as np
+
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
+from flatland.utils.rendertools import RenderTool
+from pathlib import Path
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from reinforcement_learning.ordered_policy import OrderedPolicy
+
+"""
+This file shows how to move agents in a sequential way: it moves the trains one by one, following a shortest path strategy.
+This is obviously very slow, but it's a good way to get familiar with the different Flatland components: RailEnv, TreeObsForRailEnv, etc...
+
+multi_agent_training.py is a better starting point to train your own solution!
+"""
+
+np.random.seed(2)
+
+x_dim = np.random.randint(8, 20)
+y_dim = np.random.randint(8, 20)
+n_agents = np.random.randint(3, 8)
+n_goals = n_agents + np.random.randint(0, 3)
+min_dist = int(0.75 * min(x_dim, y_dim))
+
+env = RailEnv(
+    width=x_dim,
+    height=y_dim,
+    rail_generator=complex_rail_generator(
+        nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
+        max_dist=99999,
+        seed=0
+    ),
+    schedule_generator=complex_schedule_generator(),
+    obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()),
+    number_of_agents=n_agents)
+env.reset(True, True)
+
+tree_depth = 1
+observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
+env_renderer = RenderTool(env, gl="PGL")
+handle = env.get_agent_handles()
+n_episodes = 10
+max_steps = 100 * (env.height + env.width)
+record_images = False
+policy = OrderedPolicy()
+action_dict = dict()
+
+for trials in range(1, n_episodes + 1):
+    # Reset environment
+    obs, info = env.reset(True, True)
+    done = env.dones
+    env_renderer.reset()
+    frame_step = 0
+
+    # Run episode
+    for step in range(max_steps):
+        env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
+
+        if record_images:
+            env_renderer.gl.save_image("./Images/flatland_frame_{:04d}.bmp".format(frame_step))
+            frame_step += 1
+
+        # Action
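+        # Only one agent acts at a time: the first agent that is not done yet follows the shortest-path policy,
+        # while every other agent receives STOP_MOVING (action 4)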
+        acting_agent = 0
+        for a in range(env.get_num_agents()):
+            if done[a]:
+                acting_agent += 1
+            if a == acting_agent:
+                action = policy.act(obs[a])
+            else:
+                action = 4
+            action_dict.update({a: action})
+
+        # Environment step
+        obs, all_rewards, done, _ = env.step(action_dict)
+
+        if done['__all__']:
+            break
diff --git a/reinforcement_learning/single_agent_training.py b/reinforcement_learning/single_agent_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..236d1a76cbbff9be26612e81bf24265886acbab8
--- /dev/null
+++ b/reinforcement_learning/single_agent_training.py
@@ -0,0 +1,203 @@
+import random
+import sys
+from argparse import ArgumentParser, Namespace
+from collections import deque
+from pathlib import Path
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from utils.observation_utils import normalize_observation
+from flatland.envs.observations import TreeObsForRailEnv
+
+"""
+This file shows how to train a single agent using a reinforcement learning approach.
+Documentation: https://flatland.aicrowd.com/getting-started/rl/single-agent.html
+
+This is a simple method used for demonstration purposes.
+multi_agent_training.py is a better starting point to train your own solution!
+"""
+
+
+def train_agent(n_episodes):
+    # Environment parameters
+    n_agents = 1
+    x_dim = 25
+    y_dim = 25
+    n_cities = 4
+    max_rails_between_cities = 2
+    max_rails_in_city = 3
+    seed = 42
+
+    # Observation parameters
+    observation_tree_depth = 2
+    observation_radius = 10
+
+    # Exploration parameters
+    eps_start = 1.0
+    eps_end = 0.01
+    eps_decay = 0.997  # for 2500ts
+
+    # Set the seeds
+    random.seed(seed)
+    np.random.seed(seed)
+
+    # Observation builder
+    tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth)
+
+    # Setup the environment
+    env = RailEnv(
+        width=x_dim,
+        height=y_dim,
+        rail_generator=sparse_rail_generator(
+            max_num_cities=n_cities,
+            seed=seed,
+            grid_mode=False,
+            max_rails_between_cities=max_rails_between_cities,
+            max_rails_in_city=max_rails_in_city
+        ),
+        schedule_generator=sparse_schedule_generator(),
+        number_of_agents=n_agents,
+        obs_builder_object=tree_observation
+    )
+
+    env.reset(True, True)
+
+    # Calculate the state size given the depth of the tree observation and the number of features
+    n_features_per_node = env.obs_builder.observation_dim
+    n_nodes = 0
+    for i in range(observation_tree_depth + 1):
+        n_nodes += np.power(4, i)
+    state_size = n_features_per_node * n_nodes
+
+    # The action space of flatland is 5 discrete actions
+    action_size = 5
+
+    # Max number of steps per episode
+    # This is the official formula used during evaluations
+    max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
+
+    action_dict = dict()
+
+    # And some variables to keep track of the progress
+    scores_window = deque(maxlen=100)  # todo smooth when rendering instead
+    completion_window = deque(maxlen=100)
+    scores = []
+    completion = []
+    action_count = [0] * action_size
+    agent_obs = [None] * env.get_num_agents()
+    agent_prev_obs = [None] * env.get_num_agents()
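+    # Previous actions default to 2 (MOVE_FORWARD)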
+    agent_prev_action = [2] * env.get_num_agents()
+    update_values = False
+
+    # Training parameters
+    training_parameters = {
+        'buffer_size': int(1e5),
+        'batch_size': 32,
+        'update_every': 8,
+        'learning_rate': 0.5e-4,
+        'tau': 1e-3,
+        'gamma': 0.99,
+        'buffer_min_size': 0,
+        'hidden_size': 256,
+        'use_gpu': False
+    }
+
+    # Double Dueling DQN policy
+    policy = DDDQNPolicy(state_size, action_size, Namespace(**training_parameters))
+
+    for episode_idx in range(n_episodes):
+        score = 0
+
+        # Reset environment
+        obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
+
+        # Build agent specific observations
+        for agent in env.get_agent_handles():
+            if obs[agent]:
+                agent_obs[agent] = normalize_observation(obs[agent], observation_tree_depth, observation_radius=observation_radius)
+                agent_prev_obs[agent] = agent_obs[agent].copy()
+
+        # Run episode
+        for step in range(max_steps - 1):
+            for agent in env.get_agent_handles():
+                if info['action_required'][agent]:
+                    # If an action is required, we want to store the obs at that step as well as the action
+                    update_values = True
+                    action = policy.act(agent_obs[agent], eps=eps_start)
+                    action_count[action] += 1
+                else:
+                    update_values = False
+                    action = 0
+                action_dict.update({agent: action})
+
+            # Environment step
+            next_obs, all_rewards, done, info = env.step(action_dict)
+
+            # Update replay buffer and train agent
+            for agent in range(env.get_num_agents()):
+                # Only update the values when we are done or when an action was taken and thus relevant information is present
+                if update_values or done[agent]:
+                    policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], agent_obs[agent], done[agent])
+
+                    agent_prev_obs[agent] = agent_obs[agent].copy()
+                    agent_prev_action[agent] = action_dict[agent]
+
+                if next_obs[agent]:
+                    agent_obs[agent] = normalize_observation(next_obs[agent], observation_tree_depth, observation_radius=observation_radius)
+
+                score += all_rewards[agent]
+
+            if done['__all__']:
+                break
+
+        # Epsilon decay
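+        # (epsilon is multiplied by eps_decay after every episode and never drops below eps_end)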
+        eps_start = max(eps_end, eps_decay * eps_start)
+
+        # Collect information about training
+        tasks_finished = np.sum([int(done[idx]) for idx in env.get_agent_handles()])
+        completion_window.append(tasks_finished / max(1, env.get_num_agents()))
+        scores_window.append(score / (max_steps * env.get_num_agents()))
+        completion.append(np.mean(completion_window))
+        scores.append(np.mean(scores_window))
+        action_probs = action_count / np.sum(action_count)
+
+        if episode_idx % 100 == 0:
+            end = "\n"
+            torch.save(policy.qnetwork_local, './checkpoints/single-' + str(episode_idx) + '.pth')
+            action_count = [1] * action_size
+        else:
+            end = " "
+
+        print('\rTraining {} agents on {}x{}\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
+            env.get_num_agents(),
+            x_dim, y_dim,
+            episode_idx,
+            np.mean(scores_window),
+            100 * np.mean(completion_window),
+            eps_start,
+            action_probs
+        ), end=end)
+
+    # Plot overall training progress at the end
+    plt.plot(scores)
+    plt.show()
+
+    plt.plot(completion)
+    plt.show()
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("-n", "--n_episodes", dest="n_episodes", help="number of episodes to run", default=500, type=int)
+    args = parser.parse_args()
+
+    train_agent(args.n_episodes)
diff --git a/replay_buffers/.gitkeep b/replay_buffers/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/run.py b/run.py
index 5c5bb9a020297404cc25e31dcf0325d052723db0..b5968d318bb0fda55b0f87a3d320a692734dfdb5 100644
--- a/run.py
+++ b/run.py
@@ -1,160 +1,205 @@
-from flatland.evaluators.client import FlatlandRemoteClient
-from flatland.core.env_observation_builder import DummyObservationBuilder
-from my_observation_builder import CustomObservationBuilder
+import os
+import sys
+from argparse import Namespace
+from pathlib import Path
+
 import numpy as np
 import time
 
+import torch
+from flatland.core.env_observation_builder import DummyObservationBuilder
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.evaluators.client import FlatlandRemoteClient
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.evaluators.client import TimeoutException
 
+from utils.deadlock_check import check_if_all_blocked
+
+base_dir = Path(__file__).resolve().parent
+sys.path.append(str(base_dir))
+
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+from utils.observation_utils import normalize_observation
+
+####################################################
+# EVALUATION PARAMETERS
+
+# Print per-step logs
+VERBOSE = True
+
+# Checkpoint to use (remember to push it!)
+checkpoint = ""
+
+# Use last action cache
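+# (when enabled, an agent whose observation is identical to its previous one reuses its last action instead of re-running inference)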
+USE_ACTION_CACHE = True
+
+# Observation parameters (must match training parameters!)
+observation_tree_depth = 2
+observation_radius = 10
+observation_max_path_depth = 30
+
+####################################################
 
-#####################################################################
-# Instantiate a Remote Client
-#####################################################################
 remote_client = FlatlandRemoteClient()
 
-#####################################################################
-# Define your custom controller
-#
-# which can take an observation, and the number of agents and 
-# compute the necessary action for this step for all (or even some)
-# of the agents
-#####################################################################
-def my_controller(obs, number_of_agents):
-    _action = {}
-    for _idx in range(number_of_agents):
-        _action[_idx] = np.random.randint(0, 5)
-    return _action
+# Observation builder
+predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
+tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
 
-#####################################################################
-# Instantiate your custom Observation Builder
-# 
-# You can build your own Observation Builder by following 
-# the example here : 
-# https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
-#####################################################################
-my_observation_builder = CustomObservationBuilder()
+# Calculates state and action sizes
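+# A tree observation of depth d contains sum(4^i for i in 0..d) nodes, each with observation_dim features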
+n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
+state_size = tree_observation.observation_dim * n_nodes
+action_size = 5
 
-# Or if you want to use your own approach to build the observation from the env_step, 
-# please feel free to pass a DummyObservationBuilder() object as mentioned below,
-# and that will just return a placeholder True for all observation, and you 
-# can build your own Observation for all the agents as your please.
-# my_observation_builder = DummyObservationBuilder()
+# Creates the policy. No GPU on evaluation server.
+policy = DDDQNPolicy(state_size, action_size, Namespace(**{'use_gpu': False}), evaluation_mode=True)
 
+if os.path.isfile(checkpoint):
+    policy.qnetwork_local = torch.load(checkpoint)
+else:
+    print("Checkpoint not found, using untrained policy! (path: {})".format(checkpoint))
 
 #####################################################################
 # Main evaluation loop
-#
-# This iterates over an arbitrary number of env evaluations
 #####################################################################
 evaluation_number = 0
-while True:
 
+while True:
     evaluation_number += 1
-    # Switch to a new evaluation environemnt
-    # 
-    # a remote_client.env_create is similar to instantiating a 
-    # RailEnv and then doing a env.reset()
-    # hence it returns the first observation from the 
-    # env.reset()
-    # 
-    # You can also pass your custom observation_builder object
-    # to allow you to have as much control as you wish 
-    # over the observation of your choice.
+
+    # We use a dummy observation builder and call TreeObsForRailEnv ourselves when needed.
+    # This way we decide when to compute the observations, instead of having them
+    # computed on every env step.
     time_start = time.time()
     observation, info = remote_client.env_create(
-                    obs_builder_object=my_observation_builder
-                )
+        obs_builder_object=DummyObservationBuilder()
+    )
     env_creation_time = time.time() - time_start
+
     if not observation:
-        #
         # If the remote_client returns False on a `env_create` call,
-        # then it basically means that your agent has already been 
+        # then it basically means that your agent has already been
         # evaluated on all the required evaluation environments,
-        # and hence its safe to break out of the main evaluation loop
+        # and hence it's safe to break out of the main evaluation loop.
         break
-    
-    print("Evaluation Number : {}".format(evaluation_number))
-
-    #####################################################################
-    # Access to a local copy of the environment
-    # 
-    #####################################################################
-    # Note: You can access a local copy of the environment 
-    # by using : 
-    #       remote_client.env 
-    # 
-    # But please ensure to not make any changes (or perform any action) on 
-    # the local copy of the env, as then it will diverge from 
-    # the state of the remote copy of the env, and the observations and 
-    # rewards, etc will behave unexpectedly
-    # 
-    # You can however probe the local_env instance to get any information
-    # you need from the environment. It is a valid RailEnv instance.
+
+    print("Env Path : ", remote_client.current_env_path)
+    print("Env Creation Time : ", env_creation_time)
+
     local_env = remote_client.env
-    number_of_agents = len(local_env.agents)
+    nb_agents = len(local_env.agents)
+    max_nb_steps = local_env._max_episode_steps
 
-    # Now we enter into another infinite loop where we 
+    tree_observation.set_env(local_env)
+    tree_observation.reset()
+    observation = tree_observation.get_many(list(range(nb_agents)))
+
+    print("Evaluation {}: {} agents in {}x{}".format(evaluation_number, nb_agents, local_env.width, local_env.height))
+
+    # Now we enter into another infinite loop where we
     # compute the actions for all the individual steps in this episode
     # until the episode is `done`
-    # 
-    # An episode is considered done when either all the agents have 
-    # reached their target destination
-    # or when the number of time steps has exceed max_time_steps, which 
-    # is defined by : 
-    #
-    # max_time_steps = int(4 * 2 * (env.width + env.height + 20))
-    #
+    steps = 0
+
+    # Bookkeeping
     time_taken_by_controller = []
     time_taken_per_step = []
-    steps = 0
+
+    # Action cache: keep track of the last observation to avoid running the same inference multiple times.
+    # This only makes sense for deterministic policies.
+    agent_last_obs = {}
+    agent_last_action = {}
+    nb_hit = 0
+
     while True:
-        #####################################################################
-        # Evaluation of a single episode
-        #
-        #####################################################################
-        # Compute the action for this step by using the previously 
-        # defined controller
-        time_start = time.time()
-        action = my_controller(observation, number_of_agents)
-        time_taken = time.time() - time_start
-        time_taken_by_controller.append(time_taken)
-
-        # Perform the chosen action on the environment.
-        # The action gets applied to both the local and the remote copy 
-        # of the environment instance, and the observation is what is 
-        # returned by the local copy of the env, and the rewards, and done and info
-        # are returned by the remote copy of the env
-        time_start = time.time()
-        observation, all_rewards, done, info = remote_client.env_step(action)
-        steps += 1
-        time_taken = time.time() - time_start
-        time_taken_per_step.append(time_taken)
-
-        if done['__all__']:
-            print("Reward : ", sum(list(all_rewards.values())))
-            #
-            # When done['__all__'] == True, then the evaluation of this 
-            # particular Env instantiation is complete, and we can break out 
-            # of this loop, and move onto the next Env evaluation
+        try:
+            #####################################################################
+            # Evaluation of a single episode
+            #####################################################################
+            steps += 1
+            obs_time, agent_time, step_time = 0.0, 0.0, 0.0
+            no_ops_mode = False
+
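+            # If at least one agent can still move, run the policy normally;
+            # otherwise send an empty action dict (no-ops) until the episode ends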
+            if not check_if_all_blocked(env=local_env):
+                time_start = time.time()
+                action_dict = {}
+                for agent in range(nb_agents):
+                    if observation[agent] and info['action_required'][agent]:
+                        if agent in agent_last_obs and np.all(agent_last_obs[agent] == observation[agent]):
+                            # cache hit
+                            action = agent_last_action[agent]
+                            nb_hit += 1
+                        else:
+                            # otherwise, run normalization and inference
+                            norm_obs = normalize_observation(observation[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
+                            action = policy.act(norm_obs, eps=0.0)
+
+                        action_dict[agent] = action
+
+                        if USE_ACTION_CACHE:
+                            agent_last_obs[agent] = observation[agent]
+                            agent_last_action[agent] = action
+                agent_time = time.time() - time_start
+                time_taken_by_controller.append(agent_time)
+
+                time_start = time.time()
+                _, all_rewards, done, info = remote_client.env_step(action_dict)
+                step_time = time.time() - time_start
+                time_taken_per_step.append(step_time)
+
+                time_start = time.time()
+                observation = tree_observation.get_many(list(range(nb_agents)))
+                obs_time = time.time() - time_start
+
+            else:
+                # Fully deadlocked: perform no-ops
+                no_ops_mode = True
+
+                time_start = time.time()
+                _, all_rewards, done, info = remote_client.env_step({})
+                step_time = time.time() - time_start
+                time_taken_per_step.append(step_time)
+
+            nb_agents_done = sum(done[idx] for idx in local_env.get_agent_handles())
+
+            if VERBOSE or done['__all__']:
+                print("Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? {}".format(
+                    str(steps).zfill(4),
+                    max_nb_steps,
+                    nb_agents_done,
+                    obs_time,
+                    agent_time,
+                    step_time,
+                    nb_hit,
+                    no_ops_mode
+                ), end="\r")
+
+            if done['__all__']:
+                # When done['__all__'] == True, then the evaluation of this
+                # particular Env instantiation is complete, and we can break out
+                # of this loop, and move onto the next Env evaluation
+                print()
+                break
+
+        except TimeoutException as err:
+            # A timeout occurred: we won't get any reward for this episode :-(
+            # Skip to next episode as further actions in this one will be ignored.
+            # The whole evaluation will be stopped if there are 10 consecutive timeouts.
+            print("Timeout! Will skip this episode and go to the next.", err)
             break
-    
+
     np_time_taken_by_controller = np.array(time_taken_by_controller)
     np_time_taken_per_step = np.array(time_taken_per_step)
-    print("="*100)
-    print("="*100)
-    print("Evaluation Number : ", evaluation_number)
-    print("Current Env Path : ", remote_client.current_env_path)
-    print("Env Creation Time : ", env_creation_time)
-    print("Number of Steps : ", steps)
     print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), np_time_taken_by_controller.std())
     print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
-    print("="*100)
+    print("=" * 100)
 
-print("Evaluation of all environments complete...")
+print("Evaluation of all environments complete!")
 ########################################################################
 # Submit your Results
-# 
-# Please do not forget to include this call, as this triggers the 
+#
+# Please do not forget to include this call, as this triggers the
 # final computation of the score statistics, video generation, etc
-# and is necesaary to have your submission marked as successfully evaluated
+# and is necessary to have your submission marked as successfully evaluated
 ########################################################################
 print(remote_client.submit())
diff --git a/sweep.yaml b/sweep.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bd9299caaf8a6c0bf2ee07c2d5292731cc2f2a28
--- /dev/null
+++ b/sweep.yaml
@@ -0,0 +1,21 @@
+# This sweep file can be used to run hyper-parameter searches using the Weights & Biases tools
+# See: https://docs.wandb.com/sweeps
+program: reinforcement_learning/multi_agent_training.py
+method: bayes
+metric:
+    name: evaluation/smoothed_score
+    goal: maximize
+parameters:
+    n_episodes:
+        values: [2000]
+    hidden_size:
+        # default: 256
+        values: [128, 256, 512]
+    buffer_size:
+        # default: 50000
+        values: [50000, 100000, 500000, 1000000]
+    batch_size:
+        # default: 32
+        values: [16, 32, 64, 128]
+    training_env_config:
+        values: [0, 1, 2]
\ No newline at end of file
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/deadlock_check.py b/utils/deadlock_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d414fac7b0baf4efd9c43951ad92ace0ae79d5d
--- /dev/null
+++ b/utils/deadlock_check.py
@@ -0,0 +1,42 @@
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
+
+
+def check_if_all_blocked(env):
+    """
+    Checks whether all the agents are blocked (full deadlock situation).
+    In that case it is pointless to keep running inference as no agent will be able to move.
+    :param env: current environment
+    :return: True if all agents are blocked (no agent can move), False otherwise
+    """
+
+    # First build a map of agents in each position
+    location_has_agent = {}
+    for agent in env.agents:
+        if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and agent.position:
+            location_has_agent[tuple(agent.position)] = 1
+
+    # Looks for any agent that can still move
+    for handle in env.get_agent_handles():
+        agent = env.agents[handle]
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            agent_virtual_position = agent.target
+        else:
+            continue
+
+        possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
+        orientation = agent.direction
+
+        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
+            if possible_transitions[branch_direction]:
+                new_position = get_new_position(agent_virtual_position, branch_direction)
+
+                if new_position not in location_has_agent:
+                    return False
+
+    # No agent can move at all: full deadlock!
+    return True
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa18aefe8ba2aeeb75ff0287b53a001db174db42
--- /dev/null
+++ b/utils/observation_utils.py
@@ -0,0 +1,124 @@
+import numpy as np
+from flatland.envs.observations import TreeObsForRailEnv
+
+def max_lt(seq, val):
+    """
+    Return greatest item in seq for which item < val applies.
+    None is returned if seq was empty or all items in seq were >= val.
+    """
+    max = 0
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
+            max = seq[idx]
+        idx -= 1
+    return max
+
+
+def min_gt(seq, val):
+    """
+    Return smallest item in seq for which item > val applies.
+    None is returned if seq was empty or all items in seq were >= val.
+    """
+    min = np.inf
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] >= val and seq[idx] < min:
+            min = seq[idx]
+        idx -= 1
+    return min
+
+
+def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False):
+    """
+    This function returns the difference between min and max value of an observation
+    :param obs: Observation that should be normalized
+    :param clip_min: min value where observation will be clipped
+    :param clip_max: max value where observation will be clipped
+    :return: returnes normalized and clipped observatoin
+    """
+    if fixed_radius > 0:
+        max_obs = fixed_radius
+    else:
+        max_obs = max(1, max_lt(obs, 1000)) + 1
+
+    min_obs = 0  # min(max_obs, min_gt(obs, 0))
+    if normalize_to_range:
+        min_obs = min_gt(obs, 0)
+    if min_obs > max_obs:
+        min_obs = max_obs
+    if max_obs == min_obs:
+        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
+    norm = np.abs(max_obs - min_obs)
+    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
+
+
+def _split_node_into_feature_groups(node) -> (np.ndarray, np.ndarray, np.ndarray):
+    data = np.zeros(6)
+    distance = np.zeros(1)
+    agent_data = np.zeros(4)
+
+    data[0] = node.dist_own_target_encountered
+    data[1] = node.dist_other_target_encountered
+    data[2] = node.dist_other_agent_encountered
+    data[3] = node.dist_potential_conflict
+    data[4] = node.dist_unusable_switch
+    data[5] = node.dist_to_next_branch
+
+    distance[0] = node.dist_min_to_target
+
+    agent_data[0] = node.num_agents_same_direction
+    agent_data[1] = node.num_agents_opposite_direction
+    agent_data[2] = node.num_agents_malfunctioning
+    agent_data[3] = node.speed_min_fractional
+
+    return data, distance, agent_data
+
+
+def _split_subtree_into_feature_groups(node, current_tree_depth: int, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
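+    # Unexplored branches are encoded as -inf in the tree observation; pad them with -inf placeholders
+    # so that the flattened feature vectors keep a fixed size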
+    if node == -np.inf:
+        remaining_depth = max_tree_depth - current_tree_depth
+        # reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
+        num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1))
+        return [-np.inf] * num_remaining_nodes * 6, [-np.inf] * num_remaining_nodes, [-np.inf] * num_remaining_nodes * 4
+
+    data, distance, agent_data = _split_node_into_feature_groups(node)
+
+    if not node.childs:
+        return data, distance, agent_data
+
+    for direction in TreeObsForRailEnv.tree_explored_actions_char:
+        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(node.childs[direction], current_tree_depth + 1, max_tree_depth)
+        data = np.concatenate((data, sub_data))
+        distance = np.concatenate((distance, sub_distance))
+        agent_data = np.concatenate((agent_data, sub_agent_data))
+
+    return data, distance, agent_data
+
+
+def split_tree_into_feature_groups(tree, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
+    """
+    This function splits the tree into three different arrays: node data, distances and agent data
+    """
+    data, distance, agent_data = _split_node_into_feature_groups(tree)
+
+    for direction in TreeObsForRailEnv.tree_explored_actions_char:
+        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(tree.childs[direction], 1, max_tree_depth)
+        data = np.concatenate((data, sub_data))
+        distance = np.concatenate((distance, sub_distance))
+        agent_data = np.concatenate((agent_data, sub_agent_data))
+
+    return data, distance, agent_data
+
+
+def normalize_observation(observation, tree_depth: int, observation_radius=0):
+    """
+    This function normalizes the tree observation used by the RL algorithm: each feature group is clipped and rescaled separately, then concatenated into a single flat vector
+    """
+    data, distance, agent_data = split_tree_into_feature_groups(observation, tree_depth)
+
+    data = norm_obs_clip(data, fixed_radius=observation_radius)
+    distance = norm_obs_clip(distance, normalize_to_range=True)
+    agent_data = np.clip(agent_data, -1, 1)
+    normalized_obs = np.concatenate((np.concatenate((data, distance)), agent_data))
+    return normalized_obs
diff --git a/utils/timer.py b/utils/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa02e9f18bb01fc9730c484c079aa96556027f2b
--- /dev/null
+++ b/utils/timer.py
@@ -0,0 +1,33 @@
+from timeit import default_timer
+
+
+class Timer(object):
+    """
+    Utility to measure times.
+
+    TODO:
+    - add "lap" method to make it easier to measure average time (+std) when measuring the same thing multiple times.
+    """
+
+    def __init__(self):
+        self.total_time = 0.0
+        self.start_time = 0.0
+        self.end_time = 0.0
+
+    def start(self):
+        self.start_time = default_timer()
+
+    def end(self):
+        self.total_time += default_timer() - self.start_time
+
+    def get(self):
+        return self.total_time
+
+    def get_current(self):
+        return default_timer() - self.start_time
+
+    def reset(self):
+        self.__init__()
+
+    def __repr__(self):
+        return str(self.get())
\ No newline at end of file