diff --git a/.gitignore b/.gitignore
index 54ef0cba10a309cb13c6077b720208c5e142257b..981e6a55487ba3d561923a03601cdec88b214c3b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,127 +1,132 @@
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-pip-wheel-metadata/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-#  Usually these files are written by a python script from a template
-#  before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-.hypothesis/
-.pytest_cache/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-.python-version
-
-# pipenv
-#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-#   However, in case of collaboration, if having platform-specific dependencies or dependencies
-#   having no cross-platform support, pipenv may install dependencies that don't work, or not
-#   install all needed dependencies.
-#Pipfile.lock
-
-# celery beat schedule file
-celerybeat-schedule
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-scratch/test-envs/
-scratch/
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+scratch/test-envs/
+scratch/
+
+# Checkpoints and replay buffers
+!checkpoints/.gitkeep
+replay_buffers/*
+!replay_buffers/.gitkeep
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..58d76d29beaa073412d53bbeab08ebcddc7151b9
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Flatland
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 8232bb1e569f62e96028faf3bc31adc12c9c201e..60170570eeee4b7c6fa00583293059f03abc11bd 100644
--- a/README.md
+++ b/README.md
@@ -1,45 +1,105 @@
-![AIcrowd-Logo](https://raw.githubusercontent.com/AIcrowd/AIcrowd/master/app/assets/images/misc/aicrowd-horizontal.png)
-
-# Flatland Challenge Starter Kit
-
-**[Follow these instructions to submit your solutions!](http://flatland.aicrowd.com/getting-started/first-submission.html)**
-
-
-![flatland](https://i.imgur.com/0rnbSLY.gif)
-
-
-# Round 1 - 3rd best RL solution 
-
-## Used agent 
-* [PPO Agent -> Mitchell Goff](https://github.com/mitchellgoffpc/flatland-training)
-
-## LICENCE for the Observation EXTRA.py  
-
-The observation can be used freely and reused for further submissions. Only the author needs to be referred to
-/mentioned in any submissions - if the entire observation or parts, or the main idea is used.
-
-Author: Adrian Egli (adrian.egli@gmail.com)
-
-[Linkedin](https://www.researchgate.net/profile/Adrian_Egli2)
-[Researchgate](https://www.linkedin.com/in/adrian-egli-733a9544/)
-
-
-
-
-Main links
----
-* [Submit in 10 minutes](https://flatland.aicrowd.com/getting-started/first-submission.html?_ga=2.175036450.1456714032.1596434204-43124944.1552486604)
-* [Flatland documentation](https://flatland.aicrowd.com/)
-* [NeurIPS 2020 Challenge](https://www.aicrowd.com/challenges/neurips-2020-flatland-challenge/)
-
-Communication
----
-
-* [Discord Channel](https://discord.com/invite/hCR3CZG)
-* [Discussion Forum](https://discourse.aicrowd.com/c/neurips-2020-flatland-challenge)
-* [Issue Tracker](https://gitlab.aicrowd.com/flatland/flatland/issues/)
-
-Author
----
-
-- **[Sharada Mohanty](https://twitter.com/MeMohanty)**
+🚂 Starter Kit - NeurIPS 2020 Flatland Challenge
+===
+
+This starter kit contains 2 example policies to get started with this challenge: 
+- a simple single-agent DQN method
+- a more robust multi-agent DQN method that you can submit out of the box to the challenge 🚀
+
+**🔗 [Train the single-agent DQN policy](https://flatland.aicrowd.com/getting-started/rl/single-agent.html)**
+
+**🔗 [Train the multi-agent DQN policy](https://flatland.aicrowd.com/getting-started/rl/multi-agent.html)**
+
+**🔗 [Submit a trained policy](https://flatland.aicrowd.com/getting-started/first-submission.html)**
+
+The single-agent example is meant as a minimal example of how to use DQN. The multi-agent is a better starting point to create your own solution.
+
+You can fully train the multi-agent policy in Colab for free! [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1GbPwZNQU7KJIJtilcGBTtpOAD3EabAzJ?usp=sharing)
+
+Sample training usage
+---
+
+Train the multi-agent policy for 150 episodes:
+
+```bash
+python reinforcement_learning/multi_agent_training.py -n 150
+```
+
+The multi-agent policy training can be tuned using command-line arguments:
+
+```console 
+usage: multi_agent_training.py [-h] [-n N_EPISODES] [-t TRAINING_ENV_CONFIG]
+                               [-e EVALUATION_ENV_CONFIG]
+                               [--n_evaluation_episodes N_EVALUATION_EPISODES]
+                               [--checkpoint_interval CHECKPOINT_INTERVAL]
+                               [--eps_start EPS_START] [--eps_end EPS_END]
+                               [--eps_decay EPS_DECAY]
+                               [--buffer_size BUFFER_SIZE]
+                               [--buffer_min_size BUFFER_MIN_SIZE]
+                               [--restore_replay_buffer RESTORE_REPLAY_BUFFER]
+                               [--save_replay_buffer SAVE_REPLAY_BUFFER]
+                               [--batch_size BATCH_SIZE] [--gamma GAMMA]
+                               [--tau TAU] [--learning_rate LEARNING_RATE]
+                               [--hidden_size HIDDEN_SIZE]
+                               [--update_every UPDATE_EVERY]
+                               [--use_gpu USE_GPU] [--num_threads NUM_THREADS]
+                               [--render RENDER]
+
+optional arguments:
+  -h, --help            show this help message and exit
+  -n N_EPISODES, --n_episodes N_EPISODES
+                        number of episodes to run
+  -t TRAINING_ENV_CONFIG, --training_env_config TRAINING_ENV_CONFIG
+                        training config id (eg 0 for Test_0)
+  -e EVALUATION_ENV_CONFIG, --evaluation_env_config EVALUATION_ENV_CONFIG
+                        evaluation config id (eg 0 for Test_0)
+  --n_evaluation_episodes N_EVALUATION_EPISODES
+                        number of evaluation episodes
+  --checkpoint_interval CHECKPOINT_INTERVAL
+                        checkpoint interval
+  --eps_start EPS_START
+                        max exploration
+  --eps_end EPS_END     min exploration
+  --eps_decay EPS_DECAY
+                        exploration decay
+  --buffer_size BUFFER_SIZE
+                        replay buffer size
+  --buffer_min_size BUFFER_MIN_SIZE
+                        min buffer size to start training
+  --restore_replay_buffer RESTORE_REPLAY_BUFFER
+                        replay buffer to restore
+  --save_replay_buffer SAVE_REPLAY_BUFFER
+                        save replay buffer at each evaluation interval
+  --batch_size BATCH_SIZE
+                        minibatch size
+  --gamma GAMMA         discount factor
+  --tau TAU             soft update of target parameters
+  --learning_rate LEARNING_RATE
+                        learning rate
+  --hidden_size HIDDEN_SIZE
+                        hidden size (2 fc layers)
+  --update_every UPDATE_EVERY
+                        how often to update the network
+  --use_gpu USE_GPU     use GPU if available
+  --num_threads NUM_THREADS
+                        number of threads PyTorch can use
+  --render RENDER       render 1 episode in 100
+```
+
+[**📈 Performance training in environments of various sizes**](https://wandb.ai/masterscrat/flatland-examples-reinforcement_learning/reports/Flatland-Starter-Kit-Training-in-environments-of-various-sizes--VmlldzoxNjgxMTk)
+
+[**📈 Performance with various hyper-parameters**](https://app.wandb.ai/masterscrat/flatland-examples-reinforcement_learning/reports/Flatland-Examples--VmlldzoxNDI2MTA)
+
+[![](https://i.imgur.com/Lqrq5GE.png)](https://app.wandb.ai/masterscrat/flatland-examples-reinforcement_learning/reports/Flatland-Examples--VmlldzoxNDI2MTA) 
+
+Main links
+---
+
+* [Flatland documentation](https://flatland.aicrowd.com/)
+* [NeurIPS 2020 Challenge](https://www.aicrowd.com/challenges/neurips-2020-flatland-challenge/)
+
+Communication
+---
+
+* [Discord Channel](https://discord.com/invite/hCR3CZG)
+* [Discussion Forum](https://discourse.aicrowd.com/c/neurips-2020-flatland-challenge)
+* [Issue Tracker](https://gitlab.aicrowd.com/flatland/flatland/issues/)
\ No newline at end of file
diff --git a/aicrowd.json b/aicrowd.json
index 976e3fde7acb926a434bbf6580aba66c359802c4..68c76af4fd222127604c5e5e3252429f9795fa4c 100644
--- a/aicrowd.json
+++ b/aicrowd.json
@@ -1,7 +1,7 @@
-{
-  "challenge_id": "neurips-2020-flatland-challenge",
-  "grader_id": "neurips-2020-flatland-challenge",
-  "debug": false,
-  "tags": ["other"]
-}
-
+{
+  "challenge_id": "neurips-2020-flatland-challenge",
+  "grader_id": "neurips-2020-flatland-challenge",
+  "debug": false,
+  "tags": ["RL"]
+}
+
diff --git a/apt.txt b/apt.txt
index 881f321972f314b2ebdd57b611507a89ffd9536e..d593bcc792fa1aa73b3ca7ab57a93e91dd738d8e 100644
--- a/apt.txt
+++ b/apt.txt
@@ -3,3 +3,4 @@ git
 vim
 ssh
 gcc
+build-essential
\ No newline at end of file
diff --git a/checkpoints/.gitkeep b/checkpoints/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/checkpoints/210122120236-3000.pth.local b/checkpoints/210122120236-3000.pth.local
new file mode 100644
index 0000000000000000000000000000000000000000..fc041a1fb19c59269b14012b7f4a99cdf059f19f
Binary files /dev/null and b/checkpoints/210122120236-3000.pth.local differ
diff --git a/checkpoints/210122120236-3000.pth.target b/checkpoints/210122120236-3000.pth.target
new file mode 100644
index 0000000000000000000000000000000000000000..c68bab3e29942319e053d9a4022fcf93042fa6e9
Binary files /dev/null and b/checkpoints/210122120236-3000.pth.target differ
diff --git a/checkpoints/210122165109-5000.pth.local b/checkpoints/210122165109-5000.pth.local
new file mode 100644
index 0000000000000000000000000000000000000000..b5b0ca5f919d54edb98387f1d50b186cadb3439d
Binary files /dev/null and b/checkpoints/210122165109-5000.pth.local differ
diff --git a/checkpoints/210122165109-5000.pth.target b/checkpoints/210122165109-5000.pth.target
new file mode 100644
index 0000000000000000000000000000000000000000..e12b79aff9be0c9e3fb4b917665ed1d883da2998
Binary files /dev/null and b/checkpoints/210122165109-5000.pth.target differ
diff --git a/checkpoints/210122235754-5000.pth.actor b/checkpoints/210122235754-5000.pth.actor
new file mode 100644
index 0000000000000000000000000000000000000000..07661e685ff2f34e9320d6a83f6cac3f9629807e
Binary files /dev/null and b/checkpoints/210122235754-5000.pth.actor differ
diff --git a/checkpoints/210122235754-5000.pth.optimizer b/checkpoints/210122235754-5000.pth.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..d581eb40d2a7e6dccc55697a4676883cc01d85c2
Binary files /dev/null and b/checkpoints/210122235754-5000.pth.optimizer differ
diff --git a/checkpoints/210122235754-5000.pth.value b/checkpoints/210122235754-5000.pth.value
new file mode 100644
index 0000000000000000000000000000000000000000..c323fa9a74a8e6c1ba62b0140fbebaa6e85b824b
Binary files /dev/null and b/checkpoints/210122235754-5000.pth.value differ
diff --git a/checkpoints/sample-checkpoint.pth b/checkpoints/sample-checkpoint.pth
new file mode 100644
index 0000000000000000000000000000000000000000..3fd7a50d88963dc4aa657825757fbfbfa51d508a
Binary files /dev/null and b/checkpoints/sample-checkpoint.pth differ
diff --git a/dump.rdb b/dump.rdb
deleted file mode 100644
index d719ed7cce7a692fb2775ab881f5020877164240..0000000000000000000000000000000000000000
Binary files a/dump.rdb and /dev/null differ
diff --git a/environment.yml b/environment.yml
index e79148acd6e4f97031ea83a2fc01f5f941a8f1e0..19f1c10223fb9ba9c15809cb7789af09aae5f5f7 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,112 +1,13 @@
-name: flatland-rl
-channels:
-  - anaconda
-  - conda-forge
-  - defaults
-dependencies:
-  - tk=8.6.8
-  - cairo=1.16.0
-  - cairocffi=1.1.0
-  - cairosvg=2.4.2
-  - cffi=1.12.3
-  - cssselect2=0.2.1
-  - defusedxml=0.6.0
-  - fontconfig=2.13.1
-  - freetype=2.10.0
-  - gettext=0.19.8.1
-  - glib=2.58.3
-  - icu=64.2
-  - jpeg=9c
-  - libiconv=1.15
-  - libpng=1.6.37
-  - libtiff=4.0.10
-  - libuuid=2.32.1
-  - libxcb=1.13
-  - libxml2=2.9.9
-  - lz4-c=1.8.3
-  - olefile=0.46
-  - pcre=8.41
-  - pillow=5.3.0
-  - pixman=0.38.0
-  - pthread-stubs=0.4
-  - pycairo=1.18.1
-  - pycparser=2.19
-  - tinycss2=1.0.2
-  - webencodings=0.5.1
-  - xorg-kbproto=1.0.7
-  - xorg-libice=1.0.10
-  - xorg-libsm=1.2.3
-  - xorg-libx11=1.6.8
-  - xorg-libxau=1.0.9
-  - xorg-libxdmcp=1.1.3
-  - xorg-libxext=1.3.4
-  - xorg-libxrender=0.9.10
-  - xorg-renderproto=0.11.1
-  - xorg-xextproto=7.3.0
-  - xorg-xproto=7.0.31
-  - zstd=1.4.0
-  - _libgcc_mutex=0.1
-  - ca-certificates=2019.5.15
-  - certifi=2019.6.16
-  - libedit=3.1.20181209
-  - libffi=3.2.1
-  - ncurses=6.1
-  - openssl=1.1.1c
-  - pip=19.1.1
-  - python=3.6.8
-  - readline=7.0
-  - setuptools=41.0.1
-  - sqlite=3.29.0
-  - wheel=0.33.4
-  - xz=5.2.4
-  - zlib=1.2.11
-  - pip:
-    - atomicwrites==1.3.0
-    - importlib-metadata==0.19
-    - importlib-resources==1.0.2
-    - attrs==19.1.0
-    - chardet==3.0.4
-    - click==7.0
-    - cloudpickle==1.2.2
-    - crowdai-api==0.1.21
-    - cycler==0.10.0
-    - filelock==3.0.12
-    - flatland-rl==2.2.1
-    - future==0.17.1
-    - gym==0.14.0
-    - idna==2.8
-    - kiwisolver==1.1.0
-    - lxml==4.4.0
-    - matplotlib==3.1.1
-    - more-itertools==7.2.0
-    - msgpack==0.6.1
-    - msgpack-numpy==0.4.4.3
-    - numpy==1.17.0
-    - packaging==19.0
-    - pandas==0.25.0
-    - pluggy==0.12.0
-    - py==1.8.0
-    - pyarrow==0.14.1
-    - pyglet==1.3.2
-    - pyparsing==2.4.1.1
-    - pytest==5.0.1
-    - pytest-runner==5.1
-    - python-dateutil==2.8.0
-    - python-gitlab==1.10.0
-    - pytz==2019.1
-    - torch==1.5.0
-    - recordtype==1.3
-    - redis==3.3.2
-    - requests==2.22.0
-    - scipy==1.3.1
-    - six==1.12.0
-    - svgutils==0.3.1
-    - timeout-decorator==0.4.1
-    - toml==0.10.0
-    - tox==3.13.2
-    - urllib3==1.25.3
-    - ushlex==0.99.1
-    - virtualenv==16.7.2
-    - wcwidth==0.1.7
-    - xarray==0.12.3
-    - zipp==0.5.2
+name: flatland-rl
+channels:
+  - pytorch
+  - conda-forge
+  - defaults
+dependencies:
+  - psutil==5.7.2
+  - pytorch==1.6.0
+  - pip==20.2.3
+  - python==3.6.8
+  - pip:
+      - tensorboard==2.3.0
+      - tensorboardx==2.1
\ No newline at end of file
diff --git a/my_observation_builder.py b/my_observation_builder.py
deleted file mode 100644
index 482eecfc36c53e9390b1600b7d0ab9a02f032d9a..0000000000000000000000000000000000000000
--- a/my_observation_builder.py
+++ /dev/null
@@ -1,101 +0,0 @@
-#!/usr/bin/env python 
-
-import collections
-from typing import Optional, List, Dict, Tuple
-
-import numpy as np
-
-from flatland.core.env import Environment
-from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.env_prediction_builder import PredictionBuilder
-from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
-
-
-class CustomObservationBuilder(ObservationBuilder):
-    """
-    Template for building a custom observation builder for the RailEnv class
-
-    The observation in this case composed of the following elements:
-
-        - transition map array with dimensions (env.height, env.width),\
-          where the value at X,Y will represent the 16 bits encoding of transition-map at that point.
-        
-        - the individual agent object (with position, direction, target information available)
-
-    """
-    def __init__(self):
-        super(CustomObservationBuilder, self).__init__()
-
-    def set_env(self, env: Environment):
-        super().set_env(env)
-        # Note :
-        # The instantiations which depend on parameters of the Env object should be 
-        # done here, as it is only here that the updated self.env instance is available
-        self.rail_obs = np.zeros((self.env.height, self.env.width))
-
-    def reset(self):
-        """
-        Called internally on every env.reset() call, 
-        to reset any observation specific variables that are being used
-        """
-        self.rail_obs[:] = 0        
-        for _x in range(self.env.width):
-            for _y in range(self.env.height):
-                # Get the transition map value at location _x, _y
-                transition_value = self.env.rail.get_full_transitions(_y, _x)
-                self.rail_obs[_y, _x] = transition_value
-
-    def get(self, handle: int = 0):
-        """
-        Returns the built observation for a single agent with handle : handle
-
-        In this particular case, we return 
-        - the global transition_map of the RailEnv,
-        - a tuple containing, the current agent's:
-            - state
-            - position
-            - direction
-            - initial_position
-            - target
-        """
-
-        agent = self.env.agents[handle]
-        """
-        Available information for each agent object : 
-
-        - agent.status : [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
-        - agent.position : Current position of the agent
-        - agent.direction : Current direction of the agent
-        - agent.initial_position : Initial Position of the agent
-        - agent.target : Target position of the agent
-        """
-
-        status = agent.status
-        position = agent.position
-        direction = agent.direction
-        initial_position = agent.initial_position
-        target = agent.target
-
-        
-        """
-        You can also optionally access the states of the rest of the agents by 
-        using something similar to 
-
-        for i in range(len(self.env.agents)):
-            other_agent: EnvAgent = self.env.agents[i]
-
-            # ignore other agents not in the grid any more
-            if other_agent.status == RailAgentStatus.DONE_REMOVED:
-                continue
-
-            ## Gather other agent specific params 
-            other_agent_status = other_agent.status
-            other_agent_position = other_agent.position
-            other_agent_direction = other_agent.direction
-            other_agent_initial_position = other_agent.initial_position
-            other_agent_target = other_agent.target
-
-            ## Do something nice here if you wish
-        """
-        return self.rail_obs, (status, position, direction, initial_position, target)
-
diff --git a/reinforcement_learning/dddqn_policy.py b/reinforcement_learning/dddqn_policy.py
index 95c343ec6e87a200e5c2fa5565596b8852155e13..864c6a78dd293b16aedee6fb97b7de422f8134d4 100644
--- a/reinforcement_learning/dddqn_policy.py
+++ b/reinforcement_learning/dddqn_policy.py
@@ -2,7 +2,6 @@ import copy
 import os
 import pickle
 import random
-from collections import namedtuple, deque, Iterable
 
 import numpy as np
 import torch
@@ -10,32 +9,37 @@ import torch.nn.functional as F
 import torch.optim as optim
 
 from reinforcement_learning.model import DuelingQNetwork
-from reinforcement_learning.policy import Policy
+from reinforcement_learning.policy import Policy, LearningPolicy
+from reinforcement_learning.replay_buffer import ReplayBuffer
 
 
-class DDDQNPolicy(Policy):
+class DDDQNPolicy(LearningPolicy):
     """Dueling Double DQN policy"""
 
-    def __init__(self, state_size, action_size, parameters, evaluation_mode=False):
+    def __init__(self, state_size, action_size, in_parameters, evaluation_mode=False):
+        print(">> DDDQNPolicy")
+        super(Policy, self).__init__()
+
+        self.ddqn_parameters = in_parameters
         self.evaluation_mode = evaluation_mode
 
         self.state_size = state_size
         self.action_size = action_size
         self.double_dqn = True
-        self.hidsize = 1
+        self.hidsize = 128
 
         if not evaluation_mode:
-            self.hidsize = parameters.hidden_size
-            self.buffer_size = parameters.buffer_size
-            self.batch_size = parameters.batch_size
-            self.update_every = parameters.update_every
-            self.learning_rate = parameters.learning_rate
-            self.tau = parameters.tau
-            self.gamma = parameters.gamma
-            self.buffer_min_size = parameters.buffer_min_size
-
-        # Device
-        if parameters.use_gpu and torch.cuda.is_available():
+            self.hidsize = self.ddqn_parameters.hidden_size
+            self.buffer_size = self.ddqn_parameters.buffer_size
+            self.batch_size = self.ddqn_parameters.batch_size
+            self.update_every = self.ddqn_parameters.update_every
+            self.learning_rate = self.ddqn_parameters.learning_rate
+            self.tau = self.ddqn_parameters.tau
+            self.gamma = self.ddqn_parameters.gamma
+            self.buffer_min_size = self.ddqn_parameters.buffer_min_size
+
+            # Device
+        if self.ddqn_parameters.use_gpu and torch.cuda.is_available():
             self.device = torch.device("cuda:0")
             # print("🐇 Using GPU")
         else:
@@ -43,26 +47,31 @@ class DDDQNPolicy(Policy):
             # print("🐢 Using CPU")
 
         # Q-Network
-        self.qnetwork_local = DuelingQNetwork(state_size, action_size, hidsize1=self.hidsize, hidsize2=self.hidsize).to(
-            self.device)
+        self.qnetwork_local = DuelingQNetwork(state_size,
+                                              action_size,
+                                              hidsize1=self.hidsize,
+                                              hidsize2=self.hidsize).to(self.device)
 
         if not evaluation_mode:
             self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
             self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.learning_rate)
             self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, self.device)
-
             self.t_step = 0
             self.loss = 0.0
+        else:
+            self.memory = ReplayBuffer(action_size, 1, 1, self.device)
+            self.loss = 0.0
 
     def act(self, handle, state, eps=0.):
         state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
         self.qnetwork_local.eval()
         with torch.no_grad():
             action_values = self.qnetwork_local(state)
+
         self.qnetwork_local.train()
 
         # Epsilon-greedy action selection
-        if random.random() > eps:
+        if random.random() >= eps:
             return np.argmax(action_values.cpu().data.numpy())
         else:
             return random.choice(np.arange(self.action_size))
@@ -80,23 +89,9 @@ class DDDQNPolicy(Policy):
             if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
                 self._learn()
 
-    def _clip_gradient(self, model, clip):
-        """Computes a gradient clipping coefficient based on gradient norm."""
-        totalnorm = 0
-        for p in model.parameters():
-            if p.grad is not None:
-                modulenorm = p.grad.data.norm()
-                totalnorm += modulenorm ** 2
-        totalnorm = np.sqrt(totalnorm)
-        coeff = min(1, clip / (totalnorm + 1e-6))
-
-        for p in model.parameters():
-            if p.grad is not None:
-                p.grad.mul_(coeff)
-
     def _learn(self):
         experiences = self.memory.sample()
-        states, actions, rewards, next_states, dones = experiences
+        states, actions, rewards, next_states, dones, _ = experiences
 
         # Get expected Q values from local model
         q_expected = self.qnetwork_local(states).gather(1, actions)
@@ -118,10 +113,6 @@ class DDDQNPolicy(Policy):
         # Minimize the loss
         self.optimizer.zero_grad()
         self.loss.backward()
-        # for param in self.qnetwork_local.parameters():
-        #   param.grad.data.clamp_(-1.0, 1.0)
-        self._clip_gradient(self.qnetwork_local, 1.0)
-
         self.optimizer.step()
 
         # Update target network
@@ -138,13 +129,20 @@ class DDDQNPolicy(Policy):
         torch.save(self.qnetwork_target.state_dict(), filename + ".target")
 
     def load(self, filename):
-        print("load policy from file", filename)
-        if os.path.exists(filename + ".local"):
-            print(' >> ', filename + ".local")
-            self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
-        if os.path.exists(filename + ".target"):
-            print(' >> ', filename + ".target")
-            self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+        try:
+            if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
+                self.qnetwork_local.load_state_dict(torch.load(filename + ".local", map_location=self.device))
+                print("qnetwork_local loaded ('{}')".format(filename + ".local"))
+                if not self.evaluation_mode:
+                    self.qnetwork_target.load_state_dict(torch.load(filename + ".target", map_location=self.device))
+                    print("qnetwork_target loaded ('{}' )".format(filename + ".target"))
+            else:
+                print(">> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local",
+                                                                                             filename + ".target"))
+        except Exception as exc:
+            print(exc)
+            print("Couldn't load policy from, using untrained policy! ('{}', '{}')".format(filename + ".local",
+                                                                                           filename + ".target"))
 
     def save_replay_buffer(self, filename):
         memory = self.memory.memory
@@ -156,57 +154,11 @@ class DDDQNPolicy(Policy):
             self.memory.memory = pickle.load(f)
 
     def test(self):
-        self.act(np.array([[0] * self.state_size]))
+        self.act(0, np.array([[0] * self.state_size]))
         self._learn()
 
-
-Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
-
-
-class ReplayBuffer:
-    """Fixed-size buffer to store experience tuples."""
-
-    def __init__(self, action_size, buffer_size, batch_size, device):
-        """Initialize a ReplayBuffer object.
-
-        Params
-        ======
-            action_size (int): dimension of each action
-            buffer_size (int): maximum size of buffer
-            batch_size (int): size of each training batch
-        """
-        self.action_size = action_size
-        self.memory = deque(maxlen=buffer_size)
-        self.batch_size = batch_size
-        self.device = device
-
-    def add(self, state, action, reward, next_state, done):
-        """Add a new experience to memory."""
-        e = Experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)
-        self.memory.append(e)
-
-    def sample(self):
-        """Randomly sample a batch of experiences from memory."""
-        experiences = random.sample(self.memory, k=self.batch_size)
-
-        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
-            .float().to(self.device)
-        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
-            .long().to(self.device)
-        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
-            .float().to(self.device)
-        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
-            .float().to(self.device)
-        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
-            .float().to(self.device)
-
-        return states, actions, rewards, next_states, dones
-
-    def __len__(self):
-        """Return the current size of internal memory."""
-        return len(self.memory)
-
-    def __v_stack_impr(self, states):
-        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
-        np_states = np.reshape(np.array(states), (len(states), sub_dim))
-        return np_states
+    def clone(self):
+        me = DDDQNPolicy(self.state_size, self.action_size, self.ddqn_parameters, evaluation_mode=True)
+        me.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+        me.qnetwork_target = copy.deepcopy(self.qnetwork_target)
+        return me
diff --git a/reinforcement_learning/deadlockavoidance_with_decision_agent.py b/reinforcement_learning/deadlockavoidance_with_decision_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..550e73e0c9793ed5b0e1775c47568989fdc83b3a
--- /dev/null
+++ b/reinforcement_learning/deadlockavoidance_with_decision_agent.py
@@ -0,0 +1,85 @@
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+
+from reinforcement_learning.policy import HybridPolicy
+from reinforcement_learning.ppo_agent import PPOPolicy
+from utils.agent_action_config import map_rail_env_action
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
+
+
+class DeadLockAvoidanceWithDecisionAgent(HybridPolicy):
+
+    def __init__(self, env: RailEnv, state_size, action_size, learning_agent):
+        print(">> DeadLockAvoidanceWithDecisionAgent")
+        super(DeadLockAvoidanceWithDecisionAgent, self).__init__()
+        self.env = env
+        self.state_size = state_size
+        self.action_size = action_size
+        self.learning_agent = learning_agent
+        self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, action_size, False)
+        self.policy_selector = PPOPolicy(state_size, 2)
+
+        self.memory = self.learning_agent.memory
+        self.loss = self.learning_agent.loss
+
+    def step(self, handle, state, action, reward, next_state, done):
+        select = self.policy_selector.act(handle, state, 0.0)
+        self.policy_selector.step(handle, state, select, reward, next_state, done)
+        self.dead_lock_avoidance_agent.step(handle, state, action, reward, next_state, done)
+        self.learning_agent.step(handle, state, action, reward, next_state, done)
+        self.loss = self.learning_agent.loss
+
+    def act(self, handle, state, eps=0.):
+        select = self.policy_selector.act(handle, state, eps)
+        if select == 0:
+            return self.learning_agent.act(handle, state, eps)
+        return self.dead_lock_avoidance_agent.act(handle, state, -1.0)
+
+    def save(self, filename):
+        self.dead_lock_avoidance_agent.save(filename)
+        self.learning_agent.save(filename)
+        self.policy_selector.save(filename + '.selector')
+
+    def load(self, filename):
+        self.dead_lock_avoidance_agent.load(filename)
+        self.learning_agent.load(filename)
+        self.policy_selector.load(filename + '.selector')
+
+    def start_step(self, train):
+        self.dead_lock_avoidance_agent.start_step(train)
+        self.learning_agent.start_step(train)
+        self.policy_selector.start_step(train)
+
+    def end_step(self, train):
+        self.dead_lock_avoidance_agent.end_step(train)
+        self.learning_agent.end_step(train)
+        self.policy_selector.end_step(train)
+
+    def start_episode(self, train):
+        self.dead_lock_avoidance_agent.start_episode(train)
+        self.learning_agent.start_episode(train)
+        self.policy_selector.start_episode(train)
+
+    def end_episode(self, train):
+        self.dead_lock_avoidance_agent.end_episode(train)
+        self.learning_agent.end_episode(train)
+        self.policy_selector.end_episode(train)
+
+    def load_replay_buffer(self, filename):
+        self.dead_lock_avoidance_agent.load_replay_buffer(filename)
+        self.learning_agent.load_replay_buffer(filename)
+        self.policy_selector.load_replay_buffer(filename + ".selector")
+
+    def test(self):
+        self.dead_lock_avoidance_agent.test()
+        self.learning_agent.test()
+        self.policy_selector.test()
+
+    def reset(self, env: RailEnv):
+        self.env = env
+        self.dead_lock_avoidance_agent.reset(env)
+        self.learning_agent.reset(env)
+        self.policy_selector.reset(env)
+
+    def clone(self):
+        return self
diff --git a/reinforcement_learning/evaluate_agent.py b/reinforcement_learning/evaluate_agent.py
index 2adb14377803d681d44b17b9a1135203e0af59e7..5488f81eae52753a071ef18142a5514579dd4c5c 100644
--- a/reinforcement_learning/evaluate_agent.py
+++ b/reinforcement_learning/evaluate_agent.py
@@ -26,8 +26,9 @@ from utils.observation_utils import normalize_observation
 from reinforcement_learning.dddqn_policy import DDDQNPolicy
 
 
-def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render, allow_skipping, allow_caching):
-    # Evaluation is faster on CPU (except if you use a really huge)
+def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size, state_size, seed, render,
+                allow_skipping, allow_caching):
+    # Evaluation is faster on CPU (except if you use a really huge policy)
     parameters = {
         'use_gpu': False
     }
@@ -140,11 +141,12 @@ def eval_policy(env_params, checkpoint, n_eval_episodes, max_steps, action_size,
 
                     else:
                         preproc_timer.start()
-                        norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth, observation_radius=observation_radius)
+                        norm_obs = normalize_observation(obs[agent], tree_depth=observation_tree_depth,
+                                                         observation_radius=observation_radius)
                         preproc_timer.end()
 
                         inference_timer.start()
-                        action = policy.act(norm_obs, eps=0.0)
+                        action = policy.act(agent, norm_obs, eps=0.0)
                         inference_timer.end()
 
                     action_dict.update({agent: action})
@@ -319,12 +321,15 @@ def evaluate_agents(file, n_evaluation_episodes, use_gpu, render, allow_skipping
 
     results = []
     if render:
-        results.append(eval_policy(params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping, allow_caching))
+        results.append(
+            eval_policy(params, file, eval_per_thread, max_steps, action_size, state_size, 0, render, allow_skipping,
+                        allow_caching))
 
     else:
         with Pool() as p:
             results = p.starmap(eval_policy,
-                                [(params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render, allow_skipping, allow_caching)
+                                [(params, file, 1, max_steps, action_size, state_size, seed * nb_threads, render,
+                                  allow_skipping, allow_caching)
                                  for seed in
                                  range(total_nb_eval)])
 
@@ -367,10 +372,12 @@ if __name__ == "__main__":
 
     parser.add_argument("--use_gpu", dest="use_gpu", help="use GPU if available", action='store_true')
     parser.add_argument("--render", help="render a single episode", action='store_true')
-    parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked", action='store_true')
+    parser.add_argument("--allow_skipping", help="skips to the end of the episode if all agents are deadlocked",
+                        action='store_true')
     parser.add_argument("--allow_caching", help="caches the last observation-action pair", action='store_true')
     args = parser.parse_args()
 
     os.environ["OMP_NUM_THREADS"] = str(1)
-    evaluate_agents(file=args.file, n_evaluation_episodes=args.n_evaluation_episodes, use_gpu=args.use_gpu, render=args.render,
+    evaluate_agents(file=args.file, n_evaluation_episodes=args.n_evaluation_episodes, use_gpu=args.use_gpu,
+                    render=args.render,
                     allow_skipping=args.allow_skipping, allow_caching=args.allow_caching)
diff --git a/reinforcement_learning/model.py b/reinforcement_learning/model.py
index 223f2f7707d973a1a9821d1ab44d8e2ef63b438f..fc6c8a98db876e7f3489b0db29377640ce41d176 100644
--- a/reinforcement_learning/model.py
+++ b/reinforcement_learning/model.py
@@ -1,31 +1,31 @@
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class DuelingQNetwork(nn.Module):
-    """Dueling Q-network (https://arxiv.org/abs/1511.06581)"""
-
-    def __init__(self, state_size, action_size, hidsize1=64, hidsize2=64):
-        super(DuelingQNetwork, self).__init__()
-
-        # value network
-        self.fc1_val = nn.Linear(state_size, hidsize1)
-        self.fc2_val = nn.Linear(hidsize1, hidsize2)
-        self.fc4_val = nn.Linear(hidsize2, 1)
-
-        # advantage network
-        self.fc1_adv = nn.Linear(state_size, hidsize1)
-        self.fc2_adv = nn.Linear(hidsize1, hidsize2)
-        self.fc4_adv = nn.Linear(hidsize2, action_size)
-
-    def forward(self, x):
-        val = F.relu(self.fc1_val(x))
-        val = F.relu(self.fc2_val(val))
-        val = self.fc4_val(val)
-
-        # advantage calculation
-        adv = F.relu(self.fc1_adv(x))
-        adv = F.relu(self.fc2_adv(adv))
-        adv = self.fc4_adv(adv)
-
-        return val + adv - adv.mean()
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DuelingQNetwork(nn.Module):
+    """Dueling Q-network (https://arxiv.org/abs/1511.06581)"""
+
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128):
+        super(DuelingQNetwork, self).__init__()
+
+        # value network
+        self.fc1_val = nn.Linear(state_size, hidsize1)
+        self.fc2_val = nn.Linear(hidsize1, hidsize2)
+        self.fc4_val = nn.Linear(hidsize2, 1)
+
+        # advantage network
+        self.fc1_adv = nn.Linear(state_size, hidsize1)
+        self.fc2_adv = nn.Linear(hidsize1, hidsize2)
+        self.fc4_adv = nn.Linear(hidsize2, action_size)
+
+    def forward(self, x):
+        val = F.relu(self.fc1_val(x))
+        val = F.relu(self.fc2_val(val))
+        val = self.fc4_val(val)
+
+        # advantage calculation
+        adv = F.relu(self.fc1_adv(x))
+        adv = F.relu(self.fc2_adv(adv))
+        adv = self.fc4_adv(adv)
+
+        return val + adv - adv.mean()
diff --git a/reinforcement_learning/multi_agent_training.py b/reinforcement_learning/multi_agent_training.py
old mode 100644
new mode 100755
index 542f587b1c3bea557c5d9e5f90e6210415fcf4a7..68bf90bf6865726d05351d3e217f17e7ebe3a05e
--- a/reinforcement_learning/multi_agent_training.py
+++ b/reinforcement_learning/multi_agent_training.py
@@ -9,32 +9,29 @@ from pprint import pprint
 
 import numpy as np
 import psutil
-from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
-from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+from flatland.utils.rendertools import RenderTool
 from torch.utils.tensorboard import SummaryWriter
 
-from utils.deadlock_check import check_if_all_blocked
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent
+from reinforcement_learning.multi_decision_agent import MultiDecisionAgent
+from reinforcement_learning.ppo_agent import PPOPolicy
+from utils.agent_action_config import get_flatland_full_action_size, get_action_size, map_actions, map_action, \
+    set_action_size_reduced, set_action_size_full, map_action_policy
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
 
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
 
 from utils.timer import Timer
 from utils.observation_utils import normalize_observation
-from reinforcement_learning.dddqn_policy import DDDQNPolicy
-from reinforcement_learning.ppo.ppo_agent import PPOAgent
-
-from utils.extra import Extra, ExtraPolicy
-from reinforcement_learning.multi_policy import MultiPolicy
-
-from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
-
-# https://github.com/dongminlee94/deep_rl
+from utils.fast_tree_obs import FastTreeObs
 
 try:
     import wandb
@@ -45,13 +42,14 @@ except ImportError:
 
 """
 This file shows how to train multiple agents using a reinforcement learning approach.
+After training an agent, you can submit it straight away to the NeurIPS 2020 Flatland challenge!
 
-Documentation: https://flatland.aicrowd.com/getting-started/rl/multi-agent.html
-Results: https://app.wandb.ai/masterscrat/flatland-examples-reinforcement_learning/reports/Flatland-Examples--VmlldzoxNDI2MTA
+Agent documentation: https://flatland.aicrowd.com/getting-started/rl/multi-agent.html
+Submission documentation: https://flatland.aicrowd.com/getting-started/first-submission.html
 """
 
 
-def create_rail_env(env_params, tree_observation, close_following):
+def create_rail_env(env_params, tree_observation):
     n_agents = env_params.n_agents
     x_dim = env_params.x_dim
     y_dim = env_params.y_dim
@@ -79,12 +77,11 @@ def create_rail_env(env_params, tree_observation, close_following):
         number_of_agents=n_agents,
         malfunction_generator_and_process_data=malfunction_from_params(malfunction_parameters),
         obs_builder_object=tree_observation,
-        random_seed=seed,
-        close_following=close_following
+        random_seed=seed
     )
 
 
-def train_agent(train_params, train_env_params, eval_env_params, obs_params, close_following):
+def train_agent(train_params, train_env_params, eval_env_params, obs_params):
     # Environment parameters
     n_agents = train_env_params.n_agents
     x_dim = train_env_params.x_dim
@@ -109,7 +106,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
     eps_decay = train_params.eps_decay
     n_episodes = train_params.n_episodes
     checkpoint_interval = train_params.checkpoint_interval
-    render_interval = 1  # checkpoint_interval
     n_eval_episodes = train_params.n_evaluation_episodes
     restore_replay_buffer = train_params.restore_replay_buffer
     save_replay_buffer = train_params.save_replay_buffer
@@ -120,8 +116,8 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
 
     # Observation builder
     predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
-    if not train_params.use_extra_observation:
-        print("Create TreeObsForRailEnv")
+    if not train_params.use_fast_tree_observation:
+        print("\nUsing standard TreeObs")
 
         def check_is_observation_valid(observation):
             return observation
@@ -133,7 +129,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         tree_observation.check_is_observation_valid = check_is_observation_valid
         tree_observation.get_normalized_observation = get_normalized_observation
     else:
-        print("Create Extra-Observation")
+        print("\nUsing FastTreeObs")
 
         def check_is_observation_valid(observation):
             return True
@@ -141,17 +137,17 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         def get_normalized_observation(observation, tree_depth: int, observation_radius=0):
             return observation
 
-        tree_observation = Extra(max_depth=observation_tree_depth)
+        tree_observation = FastTreeObs(max_depth=observation_tree_depth)
         tree_observation.check_is_observation_valid = check_is_observation_valid
         tree_observation.get_normalized_observation = get_normalized_observation
 
     # Setup the environments
-    train_env = create_rail_env(train_env_params, tree_observation, close_following)
+    train_env = create_rail_env(train_env_params, tree_observation)
     train_env.reset(regenerate_schedule=True, regenerate_rail=True)
-    eval_env = create_rail_env(eval_env_params, tree_observation, close_following)
+    eval_env = create_rail_env(eval_env_params, tree_observation)
     eval_env.reset(regenerate_schedule=True, regenerate_rail=True)
 
-    if not train_params.use_extra_observation:
+    if not train_params.use_fast_tree_observation:
         # Calculate the state size given the depth of the tree observation and the number of features
         n_features_per_node = train_env.obs_builder.observation_dim
         n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
@@ -160,20 +156,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         # Calculate the state size given the depth of the tree observation and the number of features
         state_size = tree_observation.observation_dim
 
-    # Setup renderer
-    if train_params.render:
-        env_renderer = RenderTool(train_env, gl="PGL", agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS)
-
-    # The action space of flatland is 5 discrete actions
-    action_size = 5
-
-    # Max number of steps per episode
-    # This is the official formula used during evaluations
-    # See details in flatland.envs.schedule_generators.sparse_schedule_generator
-    # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
-    max_steps = train_env._max_episode_steps
-
-    action_count = [0] * action_size
+    action_count = [0] * get_flatland_full_action_size()
     action_dict = dict()
     agent_obs = [None] * n_agents
     agent_prev_obs = [None] * n_agents
@@ -181,25 +164,39 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
     update_values = [False] * n_agents
 
     # Smoothed values used as target for hyperparameter tuning
-    smoothed_normalized_score = -1.0
     smoothed_eval_normalized_score = -1.0
-    smoothed_completion = 0.0
     smoothed_eval_completion = 0.0
 
     scores_window = deque(maxlen=checkpoint_interval)  # todo smooth when rendering instead
     completion_window = deque(maxlen=checkpoint_interval)
 
+    if train_params.action_size == "reduced":
+        set_action_size_reduced()
+    else:
+        set_action_size_full()
+
     # Double Dueling DQN policy
-    policy = DDDQNPolicy(state_size, action_size, train_params)
-    if False:
-        policy = ExtraPolicy(state_size, action_size)
-    if False:
-        policy = PPOAgent(state_size, action_size, n_agents, train_env)
-    if False:
-        policy = MultiPolicy(state_size, action_size, n_agents, train_env)
+    if train_params.policy == "DDDQN":
+        policy = DDDQNPolicy(state_size, get_action_size(), train_params)
+    elif train_params.policy == "PPO":
+        policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
+    elif train_params.policy == "DeadLockAvoidance":
+        policy = DeadLockAvoidanceAgent(train_env, get_action_size(), enable_eps=False)
+    elif train_params.policy == "DeadLockAvoidanceWithDecision":
+        # inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
+        inter_policy = DDDQNPolicy(state_size, get_action_size(), train_params)
+        policy = DeadLockAvoidanceWithDecisionAgent(train_env, state_size, get_action_size(), inter_policy)
+    elif train_params.policy == "MultiDecision":
+        policy = MultiDecisionAgent(state_size, get_action_size(), train_params)
+    else:
+        policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
+
+    # make sure that at least one policy is set
+    if policy is None:
+        policy = DDDQNPolicy(state_size, get_action_size(), train_params)
 
     # Load existing policy
-    if train_params.load_policy is not None:
+    if train_params.load_policy != "":
         policy.load(train_params.load_policy)
 
     # Loads existing replay buffer
@@ -212,11 +209,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
             print(e)
             exit(1)
 
-    try:
-        print(
-            "\n💾 Replay buffer status: {}/{} experiences".format(len(policy.memory.memory), train_params.buffer_size))
-    except:
-        print("\n💾 Don't have a Replay buffer")
+    print("\n💾 Replay buffer status: {}/{} experiences".format(len(policy.memory.memory), train_params.buffer_size))
 
     hdd = psutil.disk_usage('/')
     if save_replay_buffer and (hdd.free / (2 ** 30)) < 500.0:
@@ -225,10 +218,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
                 hdd.free / (2 ** 30)))
 
     # TensorBoard writer
-    writer = SummaryWriter()
-    writer.add_hparams(vars(train_params), {})  # FIXME
-    writer.add_hparams(vars(train_env_params), {})
-    writer.add_hparams(vars(obs_params), {})
+    writer = SummaryWriter(comment="_" + train_params.policy + "_" + train_params.action_size)
 
     training_timer = Timer()
     training_timer.start()
@@ -243,7 +233,6 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
             training_id
         ))
 
-    rl_policy = policy
     for episode_idx in range(n_episodes + 1):
         step_timer = Timer()
         reset_timer = Timer()
@@ -253,23 +242,21 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
 
         # Reset environment
         reset_timer.start()
-        obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
-
-        # train different number of agents : 1,2,3,... n_agents
-        for handle in range(train_env.get_num_agents()):
-            if (episode_idx % n_agents) < handle:
-                train_env.agents[handle].status = RailAgentStatus.DONE_REMOVED
-
-        # start with simple deadlock avoidance agent policy (imitation learning?)
-        if episode_idx < 500:
-            policy = DeadLockAvoidanceAgent(train_env, state_size, action_size)
+        if train_params.n_agent_fixed:
+            number_of_agents = n_agents
+            train_env_params.n_agents = n_agents
         else:
-            policy = rl_policy
+            number_of_agents = int(min(n_agents, 1 + np.floor(episode_idx / 200)))
+            train_env_params.n_agents = episode_idx % number_of_agents + 1
 
-        policy.reset()
+        train_env = create_rail_env(train_env_params, tree_observation)
+        obs, info = train_env.reset(regenerate_rail=True, regenerate_schedule=True)
+        policy.reset(train_env)
         reset_timer.end()
 
         if train_params.render:
+            # Setup renderer
+            env_renderer = RenderTool(train_env, gl="PGL")
             env_renderer.set_new_rail()
 
         score = 0
@@ -277,39 +264,47 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         actions_taken = []
 
         # Build initial agent-specific observations
-        for agent in train_env.get_agent_handles():
-            if tree_observation.check_is_observation_valid(obs[agent]):
-                agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], observation_tree_depth,
-                                                                               observation_radius=observation_radius)
-                agent_prev_obs[agent] = agent_obs[agent].copy()
+        for agent_handle in train_env.get_agent_handles():
+            if tree_observation.check_is_observation_valid(obs[agent_handle]):
+                agent_obs[agent_handle] = tree_observation.get_normalized_observation(obs[agent_handle],
+                                                                                      observation_tree_depth,
+                                                                                      observation_radius=observation_radius)
+                agent_prev_obs[agent_handle] = agent_obs[agent_handle].copy()
+
+        # Max number of steps per episode
+        # This is the official formula used during evaluations
+        # See details in flatland.envs.schedule_generators.sparse_schedule_generator
+        # max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
+        max_steps = train_env._max_episode_steps
 
         # Run episode
+        policy.start_episode(train=True)
         for step in range(max_steps - 1):
             inference_timer.start()
-            policy.start_step()
-            for agent in train_env.get_agent_handles():
-                update_values[agent] = False
-                if info['action_required'][agent]:
-                    update_values[agent] = True
-                    action = policy.act(agent, agent_obs[agent], eps=eps_start)
-                    action_count[action] += 1
-                    actions_taken.append(action)
+            policy.start_step(train=True)
+            for agent_handle in train_env.get_agent_handles():
+                agent = train_env.agents[agent_handle]
+                if info['action_required'][agent_handle]:
+                    update_values[agent_handle] = True
+                    action = policy.act(agent_handle, agent_obs[agent_handle], eps=eps_start)
+                    action_count[map_action(action)] += 1
+                    actions_taken.append(map_action(action))
                 else:
                     # An action is not required if the train hasn't joined the railway network,
                     # if it already reached its target, or if is currently malfunctioning.
+                    update_values[agent_handle] = False
                     action = 0
-                action_dict.update({agent: action})
-            policy.end_step()
-
+                action_dict.update({agent_handle: action})
+            policy.end_step(train=True)
             inference_timer.end()
 
             # Environment step
             step_timer.start()
-            next_obs, all_rewards, done, info = train_env.step(action_dict)
+            next_obs, all_rewards, done, info = train_env.step(map_actions(action_dict))
             step_timer.end()
 
             # Render an episode at some interval
-            if train_params.render and episode_idx % render_interval == 0:
+            if train_params.render:
                 env_renderer.render_env(
                     show=True,
                     frames=False,
@@ -318,36 +313,37 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
                 )
 
             # Update replay buffer and train agent
-            for agent in train_env.get_agent_handles():
-                if update_values[agent] or done['__all__']:
+            for agent_handle in train_env.get_agent_handles():
+                if update_values[agent_handle] or done['__all__']:
                     # Only learn from timesteps where somethings happened
                     learn_timer.start()
-                    policy.step(agent,
-                                agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
-                                agent_obs[agent],
-                                done[agent] and done['__all__'])
+                    policy.step(agent_handle,
+                                agent_prev_obs[agent_handle],
+                                map_action_policy(agent_prev_action[agent_handle]),
+                                all_rewards[agent_handle],
+                                agent_obs[agent_handle],
+                                done[agent_handle])
                     learn_timer.end()
 
-                    agent_prev_obs[agent] = agent_obs[agent].copy()
-                    agent_prev_action[agent] = action_dict[agent]
+                    agent_prev_obs[agent_handle] = agent_obs[agent_handle].copy()
+                    agent_prev_action[agent_handle] = action_dict[agent_handle]
 
                 # Preprocess the new observations
-                if tree_observation.check_is_observation_valid(next_obs[agent]):
+                if tree_observation.check_is_observation_valid(next_obs[agent_handle]):
                     preproc_timer.start()
-                    agent_obs[agent] = tree_observation.get_normalized_observation(next_obs[agent],
-                                                                                   observation_tree_depth,
-                                                                                   observation_radius=observation_radius)
+                    agent_obs[agent_handle] = tree_observation.get_normalized_observation(next_obs[agent_handle],
+                                                                                          observation_tree_depth,
+                                                                                          observation_radius=observation_radius)
                     preproc_timer.end()
 
-                score += all_rewards[agent]
+                score += all_rewards[agent_handle]
 
             nb_steps = step
 
             if done['__all__']:
                 break
 
-            if check_if_all_blocked(train_env):
-                break
+        policy.end_episode(train=True)
         # Epsilon decay
         eps_start = max(eps_end, eps_decay * eps_start)
 
@@ -362,28 +358,30 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         smoothed_normalized_score = np.mean(scores_window)
         smoothed_completion = np.mean(completion_window)
 
+        if train_params.render:
+            env_renderer.close_window()
+
         # Print logs
-        if episode_idx % checkpoint_interval == 0:
+        if episode_idx % checkpoint_interval == 0 and episode_idx > 0:
             policy.save('./checkpoints/' + training_id + '-' + str(episode_idx) + '.pth')
 
             if save_replay_buffer:
                 policy.save_replay_buffer('./replay_buffers/' + training_id + '-' + str(episode_idx) + '.pkl')
 
-            if train_params.render and False:
-                env_renderer.close_window()
-
             # reset action count
-            action_count = [0] * action_size
+            action_count = [0] * get_flatland_full_action_size()
 
         print(
-            '\r🚂 Episode {:7}'
-            '\t 🏆 Score: {:7.3f}'
+            '\r🚂 Episode {}'
+            '\t 🚉 nAgents {:2}/{:2}'
+            ' 🏆 Score: {:7.3f}'
             ' Avg: {:7.3f}'
             '\t 💯 Done: {:6.2f}%'
             ' Avg: {:6.2f}%'
             '\t 🎲 Epsilon: {:.3f} '
             '\t 🔀 Action Probs: {}'.format(
                 episode_idx,
+                train_env_params.n_agents, number_of_agents,
                 normalized_score,
                 smoothed_normalized_score,
                 100 * completion,
@@ -393,7 +391,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
             ), end=" ")
 
         # Evaluate policy and log results at some interval
-        if episode_idx % checkpoint_interval == 0 and n_eval_episodes > 0 and episode_idx > 0:
+        if episode_idx % checkpoint_interval == 0 and n_eval_episodes > 0:
             scores, completions, nb_steps_eval = eval_policy(eval_env,
                                                              tree_observation,
                                                              policy,
@@ -429,6 +427,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         writer.add_scalar("training/completion", np.mean(completion), episode_idx)
         writer.add_scalar("training/smoothed_completion", np.mean(smoothed_completion), episode_idx)
         writer.add_scalar("training/nb_steps", nb_steps, episode_idx)
+        writer.add_scalar("training/n_agents", train_env_params.n_agents, episode_idx)
         writer.add_histogram("actions/distribution", np.array(actions_taken), episode_idx)
         writer.add_scalar("actions/nothing", action_probs[RailEnvActions.DO_NOTHING], episode_idx)
         writer.add_scalar("actions/left", action_probs[RailEnvActions.MOVE_LEFT], episode_idx)
@@ -443,6 +442,7 @@ def train_agent(train_params, train_env_params, eval_env_params, obs_params, clo
         writer.add_scalar("timer/learn", learn_timer.get(), episode_idx)
         writer.add_scalar("timer/preproc", preproc_timer.get(), episode_idx)
         writer.add_scalar("timer/total", training_timer.get_current(), episode_idx)
+        writer.flush()
 
 
 def format_action_prob(action_probs):
@@ -472,22 +472,24 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
         score = 0.0
 
         obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
-
+        policy.reset(env)
         final_step = 0
 
+        policy.start_episode(train=False)
         for step in range(max_steps - 1):
+            policy.start_step(train=False)
             for agent in env.get_agent_handles():
                 if tree_observation.check_is_observation_valid(agent_obs[agent]):
                     agent_obs[agent] = tree_observation.get_normalized_observation(obs[agent], tree_depth=tree_depth,
                                                                                    observation_radius=observation_radius)
 
-                action = RailEnvActions.DO_NOTHING
+                action = 0
                 if info['action_required'][agent]:
                     if tree_observation.check_is_observation_valid(agent_obs[agent]):
                         action = policy.act(agent, agent_obs[agent], eps=0.0)
                 action_dict.update({agent: action})
-
-            obs, all_rewards, done, info = env.step(action_dict)
+            policy.end_step(train=False)
+            obs, all_rewards, done, info = env.step(map_actions(action_dict))
 
             for agent in env.get_agent_handles():
                 score += all_rewards[agent]
@@ -496,10 +498,7 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 
             if done['__all__']:
                 break
-
-            if check_if_all_blocked(env):
-                break
-
+        policy.end_episode(train=False)
         normalized_score = score / (max_steps * env.get_num_agents())
         scores.append(normalized_score)
 
@@ -509,47 +508,52 @@ def eval_policy(env, tree_observation, policy, train_params, obs_params):
 
         nb_steps.append(final_step)
 
-    print("\t✅ Eval: score {:7.3f} done {:6.2f}%".format(np.mean(scores), np.mean(completions) * 100.0))
+    print(" ✅ Eval: score {:.3f} done {:.1f}%".format(np.mean(scores), np.mean(completions) * 100.0))
 
     return scores, completions, nb_steps
 
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=200000, type=int)
-    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=0,
+    parser.add_argument("-n", "--n_episodes", help="number of episodes to run", default=5000, type=int)
+    parser.add_argument("--n_agent_fixed", help="hold the number of agent fixed", action='store_true')
+    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=1,
+                        type=int)
+    parser.add_argument("-e", "--evaluation_env_config", help="evaluation config id (eg 0 for Test_0)", default=1,
                         type=int)
-    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=5, type=int)
-    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=200, type=int)
-    parser.add_argument("--eps_start", help="max exploration", default=0.5, type=float)
+    parser.add_argument("--n_evaluation_episodes", help="number of evaluation episodes", default=10, type=int)
+    parser.add_argument("--checkpoint_interval", help="checkpoint interval", default=100, type=int)
+    parser.add_argument("--eps_start", help="max exploration", default=1.0, type=float)
     parser.add_argument("--eps_end", help="min exploration", default=0.01, type=float)
-    parser.add_argument("--eps_decay", help="exploration decay", default=0.9985, type=float)
-    parser.add_argument("--buffer_size", help="replay buffer size", default=int(1e6), type=int)
+    parser.add_argument("--eps_decay", help="exploration decay", default=0.9975, type=float)
+    parser.add_argument("--buffer_size", help="replay buffer size", default=int(32_000), type=int)
     parser.add_argument("--buffer_min_size", help="min buffer size to start training", default=0, type=int)
     parser.add_argument("--restore_replay_buffer", help="replay buffer to restore", default="", type=str)
     parser.add_argument("--save_replay_buffer", help="save replay buffer at each evaluation interval", default=False,
                         type=bool)
     parser.add_argument("--batch_size", help="minibatch size", default=128, type=int)
-    parser.add_argument("--gamma", help="discount factor", default=0.99, type=float)
-    parser.add_argument("--tau", help="soft update of target parameters", default=1e-3, type=float)
+    parser.add_argument("--gamma", help="discount factor", default=0.97, type=float)
+    parser.add_argument("--tau", help="soft update of target parameters", default=0.5e-3, type=float)
     parser.add_argument("--learning_rate", help="learning rate", default=0.5e-4, type=float)
     parser.add_argument("--hidden_size", help="hidden size (2 fc layers)", default=128, type=int)
-    parser.add_argument("--update_every", help="how often to update the network", default=8, type=int)
+    parser.add_argument("--update_every", help="how often to update the network", default=10, type=int)
     parser.add_argument("--use_gpu", help="use GPU if available", default=False, type=bool)
-    parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=1, type=int)
+    parser.add_argument("--num_threads", help="number of threads PyTorch can use", default=4, type=int)
+    parser.add_argument("--render", help="render 1 episode in 100", action='store_true')
     parser.add_argument("--load_policy", help="policy filename (reference) to load", default="", type=str)
-    parser.add_argument("--use_extra_observation", help="extra observation", default=True, type=bool)
-    parser.add_argument("--close_following", help="enable close following feature", default=True, type=bool)
+    parser.add_argument("--use_fast_tree_observation", help="use FastTreeObs instead of stock TreeObs",
+                        action='store_true')
     parser.add_argument("--max_depth", help="max depth", default=2, type=int)
-    parser.add_argument("-t", "--training_env_config", help="training config id (eg 0 for Test_0)", default=2,
-                        type=int)
-    parser.add_argument("--render", help="render 1 episode in 100", default=False, type=bool)
+    parser.add_argument("--policy",
+                        help="policy name [DDDQN, PPO, DeadLockAvoidance, DeadLockAvoidanceWithDecision, MultiDecision]",
+                        default="DeadLockAvoidance")
+    parser.add_argument("--action_size", help="define the action size [reduced,full]", default="full", type=str)
 
     training_params = parser.parse_args()
     env_params = [
         {
             # Test_0
-            "n_agents": 5,
+            "n_agents": 1,
             "x_dim": 25,
             "y_dim": 25,
             "n_cities": 2,
@@ -560,6 +564,17 @@ if __name__ == "__main__":
         },
         {
             # Test_1
+            "n_agents": 5,
+            "x_dim": 25,
+            "y_dim": 25,
+            "n_cities": 2,
+            "max_rails_between_cities": 2,
+            "max_rails_in_city": 3,
+            "malfunction_rate": 1 / 50,
+            "seed": 0
+        },
+        {
+            # Test_2
             "n_agents": 10,
             "x_dim": 30,
             "y_dim": 30,
@@ -570,10 +585,10 @@ if __name__ == "__main__":
             "seed": 0
         },
         {
-            # Test_2
+            # Test_3
             "n_agents": 20,
-            "x_dim": 30,
-            "y_dim": 30,
+            "x_dim": 35,
+            "y_dim": 35,
             "n_cities": 3,
             "max_rails_between_cities": 2,
             "max_rails_in_city": 3,
@@ -582,25 +597,23 @@ if __name__ == "__main__":
         },
         {
             # Test_3
-            "n_agents": 106,
-            "x_dim": 50,
-            "y_dim": 50,
-            "n_cities": 12,
+            "n_agents": 58,
+            "x_dim": 40,
+            "y_dim": 40,
+            "n_cities": 5,
             "max_rails_between_cities": 2,
-            "max_rails_in_city": 4,
-            "malfunction_rate": 1 / 50000,
+            "max_rails_in_city": 3,
+            "malfunction_rate": 1 / 200,
             "seed": 0
         },
     ]
 
     obs_params = {
-        "observation_tree_depth": training_params.max_depth,  # FIXME
+        "observation_tree_depth": training_params.max_depth,
         "observation_radius": 10,
         "observation_max_path_depth": 30
     }
 
-    print("close_following: ", training_params.close_following)
-
 
     def check_env_config(id):
         if id >= len(env_params) or id < 0:
@@ -615,6 +628,10 @@ if __name__ == "__main__":
     training_env_params = env_params[training_params.training_env_config]
     evaluation_env_params = env_params[training_params.evaluation_env_config]
 
+    # FIXME hard-coded for sweep search
+    # see https://wb-forum.slack.com/archives/CL4V2QE59/p1602931982236600 to implement properly
+    # training_params.use_fast_tree_observation = True
+
     print("\nTraining parameters:")
     pprint(vars(training_params))
     print("\nTraining environment parameters (Test_{}):".format(training_params.training_env_config))
@@ -626,4 +643,4 @@ if __name__ == "__main__":
 
     os.environ["OMP_NUM_THREADS"] = str(training_params.num_threads)
     train_agent(training_params, Namespace(**training_env_params), Namespace(**evaluation_env_params),
-                Namespace(**obs_params), training_params.close_following)
+                Namespace(**obs_params))
diff --git a/reinforcement_learning/multi_decision_agent.py b/reinforcement_learning/multi_decision_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..5047bcd1b52cf1b424ff172d6dcd2b7fba97957f
--- /dev/null
+++ b/reinforcement_learning/multi_decision_agent.py
@@ -0,0 +1,90 @@
+from flatland.envs.rail_env import RailEnv
+
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+from reinforcement_learning.policy import LearningPolicy, DummyMemory
+from reinforcement_learning.ppo_agent import PPOPolicy
+
+
+class MultiDecisionAgent(LearningPolicy):
+
+    def __init__(self, state_size, action_size, in_parameters=None):
+        print(">> MultiDecisionAgent")
+        super(MultiDecisionAgent, self).__init__()
+        self.state_size = state_size
+        self.action_size = action_size
+        self.in_parameters = in_parameters
+        self.memory = DummyMemory()
+        self.loss = 0
+
+        self.ppo_policy = PPOPolicy(state_size, action_size, use_replay_buffer=False, in_parameters=in_parameters)
+        self.dddqn_policy = DDDQNPolicy(state_size, action_size, in_parameters)
+        self.policy_selector = PPOPolicy(state_size, 2)
+
+
+    def step(self, handle, state, action, reward, next_state, done):
+        self.ppo_policy.step(handle, state, action, reward, next_state, done)
+        self.dddqn_policy.step(handle, state, action, reward, next_state, done)
+        select = self.policy_selector.act(handle, state, 0.0)
+        self.policy_selector.step(handle, state, select, reward, next_state, done)
+
+    def act(self, handle, state, eps=0.):
+        select = self.policy_selector.act(handle, state, eps)
+        if select == 0:
+            return self.dddqn_policy.act(handle, state, eps)
+        return self.policy_selector.act(handle, state, eps)
+
+    def save(self, filename):
+        self.ppo_policy.save(filename)
+        self.dddqn_policy.save(filename)
+        self.policy_selector.save(filename)
+
+    def load(self, filename):
+        self.ppo_policy.load(filename)
+        self.dddqn_policy.load(filename)
+        self.policy_selector.load(filename)
+
+    def start_step(self, train):
+        self.ppo_policy.start_step(train)
+        self.dddqn_policy.start_step(train)
+        self.policy_selector.start_step(train)
+
+    def end_step(self, train):
+        self.ppo_policy.end_step(train)
+        self.dddqn_policy.end_step(train)
+        self.policy_selector.end_step(train)
+
+    def start_episode(self, train):
+        self.ppo_policy.start_episode(train)
+        self.dddqn_policy.start_episode(train)
+        self.policy_selector.start_episode(train)
+
+    def end_episode(self, train):
+        self.ppo_policy.end_episode(train)
+        self.dddqn_policy.end_episode(train)
+        self.policy_selector.end_episode(train)
+
+    def load_replay_buffer(self, filename):
+        self.ppo_policy.load_replay_buffer(filename)
+        self.dddqn_policy.load_replay_buffer(filename)
+        self.policy_selector.load_replay_buffer(filename)
+
+    def test(self):
+        self.ppo_policy.test()
+        self.dddqn_policy.test()
+        self.policy_selector.test()
+
+    def reset(self, env: RailEnv):
+        self.ppo_policy.reset(env)
+        self.dddqn_policy.reset(env)
+        self.policy_selector.reset(env)
+
+    def clone(self):
+        multi_descision_agent = MultiDecisionAgent(
+            self.state_size,
+            self.action_size,
+            self.in_parameters
+        )
+        multi_descision_agent.ppo_policy = self.ppo_policy.clone()
+        multi_descision_agent.dddqn_policy = self.dddqn_policy.clone()
+        multi_descision_agent.policy_selector = self.policy_selector.clone()
+        return multi_descision_agent
diff --git a/reinforcement_learning/multi_policy.py b/reinforcement_learning/multi_policy.py
index 765bcf599f681f8b7a3dca311f223c5eed85e42d..5ee8cb40b97d8ddce1d5a6b4d77aed5fb689b263 100644
--- a/reinforcement_learning/multi_policy.py
+++ b/reinforcement_learning/multi_policy.py
@@ -1,10 +1,9 @@
 import numpy as np
-from flatland.envs.rail_env import RailEnvActions
+from flatland.envs.rail_env import RailEnv
 
 from reinforcement_learning.policy import Policy
-from reinforcement_learning.ppo.ppo_agent import PPOAgent
+from reinforcement_learning.ppo_agent import PPOPolicy
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
-from utils.extra import ExtraPolicy
 
 
 class MultiPolicy(Policy):
@@ -13,20 +12,20 @@ class MultiPolicy(Policy):
         self.action_size = action_size
         self.memory = []
         self.loss = 0
-        self.extra_policy = ExtraPolicy(state_size, action_size)
-        self.ppo_policy = PPOAgent(state_size + action_size, action_size, n_agents, env)
+        self.deadlock_avoidance_policy = DeadLockAvoidanceAgent(env, action_size, False)
+        self.ppo_policy = PPOPolicy(state_size + action_size, action_size)
 
     def load(self, filename):
         self.ppo_policy.load(filename)
-        self.extra_policy.load(filename)
+        self.deadlock_avoidance_policy.load(filename)
 
     def save(self, filename):
         self.ppo_policy.save(filename)
-        self.extra_policy.save(filename)
+        self.deadlock_avoidance_policy.save(filename)
 
     def step(self, handle, state, action, reward, next_state, done):
-        action_extra_state = self.extra_policy.act(handle, state, 0.0)
-        action_extra_next_state = self.extra_policy.act(handle, next_state, 0.0)
+        action_extra_state = self.deadlock_avoidance_policy.act(handle, state, 0.0)
+        action_extra_next_state = self.deadlock_avoidance_policy.act(handle, next_state, 0.0)
 
         extended_state = np.copy(state)
         for action_itr in np.arange(self.action_size):
@@ -35,11 +34,11 @@ class MultiPolicy(Policy):
         for action_itr in np.arange(self.action_size):
             extended_next_state = np.append(extended_next_state, [int(action_extra_next_state == action_itr)])
 
-        self.extra_policy.step(handle, state, action, reward, next_state, done)
+        self.deadlock_avoidance_policy.step(handle, state, action, reward, next_state, done)
         self.ppo_policy.step(handle, extended_state, action, reward, extended_next_state, done)
 
     def act(self, handle, state, eps=0.):
-        action_extra_state = self.extra_policy.act(handle, state, 0.0)
+        action_extra_state = self.deadlock_avoidance_policy.act(handle, state, 0.0)
         extended_state = np.copy(state)
         for action_itr in np.arange(self.action_size):
             extended_state = np.append(extended_state, [int(action_extra_state == action_itr)])
@@ -47,18 +46,18 @@ class MultiPolicy(Policy):
         self.loss = self.ppo_policy.loss
         return action_ppo
 
-    def reset(self):
-        self.ppo_policy.reset()
-        self.extra_policy.reset()
+    def reset(self, env: RailEnv):
+        self.ppo_policy.reset(env)
+        self.deadlock_avoidance_policy.reset(env)
 
     def test(self):
         self.ppo_policy.test()
-        self.extra_policy.test()
+        self.deadlock_avoidance_policy.test()
 
-    def start_step(self):
-        self.extra_policy.start_step()
-        self.ppo_policy.start_step()
+    def start_step(self, train):
+        self.deadlock_avoidance_policy.start_step(train)
+        self.ppo_policy.start_step(train)
 
-    def end_step(self):
-        self.extra_policy.end_step()
-        self.ppo_policy.end_step()
+    def end_step(self, train):
+        self.deadlock_avoidance_policy.end_step(train)
+        self.ppo_policy.end_step(train)
diff --git a/reinforcement_learning/ordered_policy.py b/reinforcement_learning/ordered_policy.py
index 3dc55ee13e489a526b1346a01bd8737652c77e9f..2db171d2e1429a085488b02f9818ba75c57b2694 100644
--- a/reinforcement_learning/ordered_policy.py
+++ b/reinforcement_learning/ordered_policy.py
@@ -1,34 +1,34 @@
-import sys
-from pathlib import Path
-
-import numpy as np
-
-from reinforcement_learning.policy import Policy
-
-base_dir = Path(__file__).resolve().parent.parent
-sys.path.append(str(base_dir))
-
-from utils.observation_utils import split_tree_into_feature_groups, min_gt
-
-
-class OrderedPolicy(Policy):
-    def __init__(self):
-        self.action_size = 5
-
-    def act(self, state, eps=0.):
-        _, distance, _ = split_tree_into_feature_groups(state, 1)
-        distance = distance[1:]
-        min_dist = min_gt(distance, 0)
-        min_direction = np.where(distance == min_dist)
-        if len(min_direction[0]) > 1:
-            return min_direction[0][-1] + 1
-        return min_direction[0] + 1
-
-    def step(self, state, action, reward, next_state, done):
-        return
-
-    def save(self, filename):
-        return
-
-    def load(self, filename):
-        return
+import sys
+from pathlib import Path
+
+import numpy as np
+
+from reinforcement_learning.policy import Policy
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from utils.observation_utils import split_tree_into_feature_groups, min_gt
+
+
+class OrderedPolicy(Policy):
+    def __init__(self):
+        self.action_size = 5
+
+    def act(self, handle, state, eps=0.):
+        _, distance, _ = split_tree_into_feature_groups(state, 1)
+        distance = distance[1:]
+        min_dist = min_gt(distance, 0)
+        min_direction = np.where(distance == min_dist)
+        if len(min_direction[0]) > 1:
+            return min_direction[0][-1] + 1
+        return min_direction[0] + 1
+
+    def step(self, state, action, reward, next_state, done):
+        return
+
+    def save(self, filename):
+        return
+
+    def load(self, filename):
+        return
diff --git a/reinforcement_learning/policy.py b/reinforcement_learning/policy.py
index 9c77845e46540df8c2a0beff658728b1604cbd89..fe28cbc561e72f646209aafff84ed4f2145996f0 100644
--- a/reinforcement_learning/policy.py
+++ b/reinforcement_learning/policy.py
@@ -1,27 +1,62 @@
+from flatland.envs.rail_env import RailEnv
+
+
+class DummyMemory:
+    def __init__(self):
+        self.memory = []
+
+    def __len__(self):
+        return 0
+
+
 class Policy:
     def step(self, handle, state, action, reward, next_state, done):
         raise NotImplementedError
 
-    def act(self, state, eps=0.):
+    def act(self, handle, state, eps=0.):
         raise NotImplementedError
 
     def save(self, filename):
-        pass
+        raise NotImplementedError
 
     def load(self, filename):
+        raise NotImplementedError
+
+    def start_step(self, train):
         pass
 
-    def test(self):
+    def end_step(self, train):
         pass
 
-    def save_replay_buffer(self):
+    def start_episode(self, train):
         pass
 
-    def reset(self):
+    def end_episode(self, train):
         pass
 
-    def start_step(self):
+    def load_replay_buffer(self, filename):
         pass
 
-    def end_step(self):
+    def test(self):
+        pass
+
+    def reset(self, env: RailEnv):
         pass
+
+    def clone(self):
+        return self
+
+
+class HeuristicPolicy(Policy):
+    def __init__(self):
+        super(HeuristicPolicy).__init__()
+
+
+class LearningPolicy(Policy):
+    def __init__(self):
+        super(LearningPolicy).__init__()
+
+
+class HybridPolicy(Policy):
+    def __init__(self):
+        super(HybridPolicy).__init__()
diff --git a/reinforcement_learning/ppo/model.py b/reinforcement_learning/ppo/model.py
deleted file mode 100644
index 51b86ff16691c03f6a754405352bb4cf48e4b914..0000000000000000000000000000000000000000
--- a/reinforcement_learning/ppo/model.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class PolicyNetwork(nn.Module):
-    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32):
-        super().__init__()
-        self.fc1 = nn.Linear(state_size, hidsize1)
-        self.fc2 = nn.Linear(hidsize1, hidsize2)
-        # self.fc3 = nn.Linear(hidsize2, hidsize3)
-        self.output = nn.Linear(hidsize2, action_size)
-        self.softmax = nn.Softmax(dim=1)
-        self.bn0 = nn.BatchNorm1d(state_size, affine=False)
-
-    def forward(self, inputs):
-        x = self.bn0(inputs.float())
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        # x = F.relu(self.fc3(x))
-        return self.softmax(self.output(x))
diff --git a/reinforcement_learning/ppo/ppo_agent.py b/reinforcement_learning/ppo/ppo_agent.py
deleted file mode 100644
index ec904e4598ae1a15fcff8f0bd1aa3b4cd2f5f3e9..0000000000000000000000000000000000000000
--- a/reinforcement_learning/ppo/ppo_agent.py
+++ /dev/null
@@ -1,141 +0,0 @@
-import os
-import random
-
-import numpy as np
-import torch
-from torch.distributions.categorical import Categorical
-
-from reinforcement_learning.policy import Policy
-from reinforcement_learning.ppo.model import PolicyNetwork
-from reinforcement_learning.ppo.replay_memory import Episode, ReplayBuffer
-
-BUFFER_SIZE = 128_000
-BATCH_SIZE = 8192
-GAMMA = 0.95
-LR = 0.5e-4
-CLIP_FACTOR = .005
-UPDATE_EVERY = 30
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-
-class PPOAgent(Policy):
-    def __init__(self, state_size, action_size, num_agents, env):
-        self.action_size = action_size
-        self.state_size = state_size
-        self.num_agents = num_agents
-        self.policy = PolicyNetwork(state_size, action_size).to(device)
-        self.old_policy = PolicyNetwork(state_size, action_size).to(device)
-        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR)
-        self.episodes = [Episode() for _ in range(num_agents)]
-        self.memory = ReplayBuffer(BUFFER_SIZE)
-        self.t_step = 0
-        self.loss = 0
-        self.env = env
-
-    def reset(self):
-        self.finished = [False] * len(self.episodes)
-        self.tot_reward = [0] * self.num_agents
-
-    # Decide on an action to take in the environment
-
-    def act(self, handle, state, eps=None):
-        if True:
-            self.policy.eval()
-            with torch.no_grad():
-                output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
-                return Categorical(output).sample().item()
-
-        # Epsilon-greedy action selection
-        if random.random() > eps:
-            self.policy.eval()
-            with torch.no_grad():
-                output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
-                return Categorical(output).sample().item()
-        else:
-            return random.choice(np.arange(self.action_size))
-
-    # Record the results of the agent's action and update the model
-    def step(self, handle, state, action, reward, next_state, done):
-        if not self.finished[handle]:
-            # Push experience into Episode memory
-            self.tot_reward[handle] += reward
-            if done == 1:
-                reward = 1  # self.tot_reward[handle]
-            else:
-                reward = 0
-
-            self.episodes[handle].push(state, action, reward, next_state, done)
-
-            # When we finish the episode, discount rewards and push the experience into replay memory
-            if done:
-                self.episodes[handle].discount_rewards(GAMMA)
-                self.memory.push_episode(self.episodes[handle])
-                self.episodes[handle].reset()
-                self.finished[handle] = True
-
-        # Perform a gradient update every UPDATE_EVERY time steps
-        self.t_step = (self.t_step + 1) % UPDATE_EVERY
-        if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4:
-            self._learn(*self.memory.sample(BATCH_SIZE, device))
-
-    def _clip_gradient(self, model, clip):
-
-        for p in model.parameters():
-            p.grad.data.clamp_(-clip, clip)
-        return
-
-        """Computes a gradient clipping coefficient based on gradient norm."""
-        totalnorm = 0
-        for p in model.parameters():
-            if p.grad is not None:
-                modulenorm = p.grad.data.norm()
-                totalnorm += modulenorm ** 2
-        totalnorm = np.sqrt(totalnorm)
-        coeff = min(1, clip / (totalnorm + 1e-6))
-
-        for p in model.parameters():
-            if p.grad is not None:
-                p.grad.mul_(coeff)
-
-    def _learn(self, states, actions, rewards, next_state, done):
-        self.policy.train()
-
-        responsible_outputs = torch.gather(self.policy(states), 1, actions)
-        old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach()
-
-        # rewards = rewards - rewards.mean()
-        ratio = responsible_outputs / (old_responsible_outputs + 1e-5)
-        clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR)
-        loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean()
-        self.loss = loss
-
-        # Compute loss and perform a gradient step
-        self.old_policy.load_state_dict(self.policy.state_dict())
-        self.optimizer.zero_grad()
-        loss.backward()
-        # self._clip_gradient(self.policy, 1.0)
-        self.optimizer.step()
-
-    # Checkpointing methods
-    def save(self, filename):
-        # print("Saving model from checkpoint:", filename)
-        torch.save(self.policy.state_dict(), filename + ".policy")
-        torch.save(self.optimizer.state_dict(), filename + ".optimizer")
-
-    def load(self, filename):
-        print("load policy from file", filename)
-        if os.path.exists(filename + ".policy"):
-            print(' >> ', filename + ".policy")
-            try:
-                self.policy.load_state_dict(torch.load(filename + ".policy"))
-            except:
-                print(" >> failed!")
-                pass
-        if os.path.exists(filename + ".optimizer"):
-            print(' >> ', filename + ".optimizer")
-            try:
-                self.optimizer.load_state_dict(torch.load(filename + ".optimizer"))
-            except:
-                print(" >> failed!")
-                pass
diff --git a/reinforcement_learning/ppo/replay_memory.py b/reinforcement_learning/ppo/replay_memory.py
deleted file mode 100644
index 3e6619b40169597d7a4b379f4ce2c9ddccd4cd9b..0000000000000000000000000000000000000000
--- a/reinforcement_learning/ppo/replay_memory.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import torch
-import random
-import numpy as np
-from collections import namedtuple, deque, Iterable
-
-
-Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done"))
-
-
-class Episode:
-    memory = []
-
-    def reset(self):
-        self.memory = []
-
-    def push(self, *args):
-        self.memory.append(tuple(args))
-
-    def discount_rewards(self, gamma):
-        running_add = 0.
-        for i, (state, action, reward, *rest) in list(enumerate(self.memory))[::-1]:
-            running_add = running_add * gamma + reward
-            self.memory[i] = (state, action, running_add, *rest)
-
-
-class ReplayBuffer:
-    def __init__(self, buffer_size):
-        self.memory = deque(maxlen=buffer_size)
-
-    def push(self, state, action, reward, next_state, done):
-        self.memory.append(Transition(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done))
-
-    def push_episode(self, episode):
-        for step in episode.memory:
-            self.push(*step)
-
-    def sample(self, batch_size, device):
-        experiences = random.sample(self.memory, k=batch_size)
-
-        states      = torch.from_numpy(self.stack([e.state      for e in experiences])).float().to(device)
-        actions     = torch.from_numpy(self.stack([e.action     for e in experiences])).long().to(device)
-        rewards     = torch.from_numpy(self.stack([e.reward     for e in experiences])).float().to(device)
-        next_states = torch.from_numpy(self.stack([e.next_state for e in experiences])).float().to(device)
-        dones       = torch.from_numpy(self.stack([e.done       for e in experiences]).astype(np.uint8)).float().to(device)
-
-        return states, actions, rewards, next_states, dones
-
-    def stack(self, states):
-        sub_dims = states[0].shape[1:] if isinstance(states[0], Iterable) else [1]
-        return np.reshape(np.array(states), (len(states), *sub_dims))
-
-    def __len__(self):
-        return len(self.memory)
diff --git a/reinforcement_learning/ppo_agent.py b/reinforcement_learning/ppo_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a734c7626cf315d86c0463ea38383a28c75d31d
--- /dev/null
+++ b/reinforcement_learning/ppo_agent.py
@@ -0,0 +1,301 @@
+import copy
+import os
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.distributions import Categorical
+
+# Hyperparameters
+from reinforcement_learning.policy import LearningPolicy
+from reinforcement_learning.replay_buffer import ReplayBuffer
+
+# https://lilianweng.github.io/lil-log/2018/04/08/policy-gradient-algorithms.html
+
+class EpisodeBuffers:
+    def __init__(self):
+        self.reset()
+
+    def __len__(self):
+        """Return the current size of internal memory."""
+        return len(self.memory)
+
+    def reset(self):
+        self.memory = {}
+
+    def get_transitions(self, handle):
+        return self.memory.get(handle, [])
+
+    def push_transition(self, handle, transition):
+        transitions = self.get_transitions(handle)
+        transitions.append(transition)
+        self.memory.update({handle: transitions})
+
+
+class ActorCriticModel(nn.Module):
+
+    def __init__(self, state_size, action_size, device, hidsize1=512, hidsize2=256):
+        super(ActorCriticModel, self).__init__()
+        self.device = device
+        self.actor = nn.Sequential(
+            nn.Linear(state_size, hidsize1),
+            nn.Tanh(),
+            nn.Linear(hidsize1, hidsize2),
+            nn.Tanh(),
+            nn.Linear(hidsize2, action_size),
+            nn.Softmax(dim=-1)
+        ).to(self.device)
+
+        self.critic = nn.Sequential(
+            nn.Linear(state_size, hidsize1),
+            nn.Tanh(),
+            nn.Linear(hidsize1, hidsize2),
+            nn.Tanh(),
+            nn.Linear(hidsize2, 1)
+        ).to(self.device)
+
+    def forward(self, x):
+        raise NotImplementedError
+
+    def get_actor_dist(self, state):
+        action_probs = self.actor(state)
+        dist = Categorical(action_probs)
+        return dist
+
+    def evaluate(self, states, actions):
+        action_probs = self.actor(states)
+        dist = Categorical(action_probs)
+        action_logprobs = dist.log_prob(actions)
+        dist_entropy = dist.entropy()
+        state_value = self.critic(states)
+        return action_logprobs, torch.squeeze(state_value), dist_entropy
+
+    def save(self, filename):
+        # print("Saving model from checkpoint:", filename)
+        torch.save(self.actor.state_dict(), filename + ".actor")
+        torch.save(self.critic.state_dict(), filename + ".value")
+
+    def _load(self, obj, filename):
+        if os.path.exists(filename):
+            print(' >> ', filename)
+            try:
+                obj.load_state_dict(torch.load(filename, map_location=self.device))
+            except:
+                print(" >> failed!")
+        return obj
+
+    def load(self, filename):
+        print("load model from file", filename)
+        self.actor = self._load(self.actor, filename + ".actor")
+        self.critic = self._load(self.critic, filename + ".value")
+
+
+class PPOPolicy(LearningPolicy):
+    def __init__(self, state_size, action_size, use_replay_buffer=False, in_parameters=None):
+        print(">> PPOPolicy")
+        super(PPOPolicy, self).__init__()
+        # parameters
+        self.ppo_parameters = in_parameters
+        if self.ppo_parameters is not None:
+            self.hidsize = self.ppo_parameters.hidden_size
+            self.buffer_size = self.ppo_parameters.buffer_size
+            self.batch_size = self.ppo_parameters.batch_size
+            self.learning_rate = self.ppo_parameters.learning_rate
+            self.gamma = self.ppo_parameters.gamma
+            # Device
+            if self.ppo_parameters.use_gpu and torch.cuda.is_available():
+                self.device = torch.device("cuda:0")
+                # print("🐇 Using GPU")
+            else:
+                self.device = torch.device("cpu")
+                # print("🐢 Using CPU")
+        else:
+            self.hidsize = 128
+            self.learning_rate = 1.0e-3
+            self.gamma = 0.95
+            self.buffer_size = 32_000
+            self.batch_size = 1024
+            self.device = torch.device("cpu")
+
+        self.surrogate_eps_clip = 0.1
+        self.K_epoch = 10
+        self.weight_loss = 0.5
+        self.weight_entropy = 0.01
+
+        self.buffer_min_size = 0
+        self.use_replay_buffer = use_replay_buffer
+
+        self.current_episode_memory = EpisodeBuffers()
+        self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, self.device)
+        self.loss = 0
+        self.actor_critic_model = ActorCriticModel(state_size, action_size,self.device,
+                                                   hidsize1=self.hidsize,
+                                                   hidsize2=self.hidsize)
+        self.optimizer = optim.Adam(self.actor_critic_model.parameters(), lr=self.learning_rate)
+        self.loss_function = nn.MSELoss()  # nn.SmoothL1Loss()
+
+    def reset(self, env):
+        pass
+
+    def act(self, handle, state, eps=None):
+        # sample a action to take
+        torch_state = torch.tensor(state, dtype=torch.float).to(self.device)
+        dist = self.actor_critic_model.get_actor_dist(torch_state)
+        action = dist.sample()
+        return action.item()
+
+    def step(self, handle, state, action, reward, next_state, done):
+        # record transitions ([state] -> [action] -> [reward, next_state, done])
+        torch_action = torch.tensor(action, dtype=torch.float).to(self.device)
+        torch_state = torch.tensor(state, dtype=torch.float).to(self.device)
+        # evaluate actor
+        dist = self.actor_critic_model.get_actor_dist(torch_state)
+        action_logprobs = dist.log_prob(torch_action)
+        transition = (state, action, reward, next_state, action_logprobs.item(), done)
+        self.current_episode_memory.push_transition(handle, transition)
+
+    def _push_transitions_to_replay_buffer(self,
+                                           state_list,
+                                           action_list,
+                                           reward_list,
+                                           state_next_list,
+                                           done_list,
+                                           prob_a_list):
+        for idx in range(len(reward_list)):
+            state_i = state_list[idx]
+            action_i = action_list[idx]
+            reward_i = reward_list[idx]
+            state_next_i = state_next_list[idx]
+            done_i = done_list[idx]
+            prob_action_i = prob_a_list[idx]
+            self.memory.add(state_i, action_i, reward_i, state_next_i, done_i, prob_action_i)
+
+    def _convert_transitions_to_torch_tensors(self, transitions_array):
+        # build empty lists(arrays)
+        state_list, action_list, reward_list, state_next_list, prob_a_list, done_list = [], [], [], [], [], []
+
+        # set discounted_reward to zero
+        discounted_reward = 0
+        for transition in transitions_array[::-1]:
+            state_i, action_i, reward_i, state_next_i, prob_action_i, done_i = transition
+
+            state_list.insert(0, state_i)
+            action_list.insert(0, action_i)
+            if done_i:
+                discounted_reward = 0
+                done_list.insert(0, 1)
+            else:
+                done_list.insert(0, 0)
+
+            discounted_reward = reward_i + self.gamma * discounted_reward
+            reward_list.insert(0, discounted_reward)
+            state_next_list.insert(0, state_next_i)
+            prob_a_list.insert(0, prob_action_i)
+
+        if self.use_replay_buffer:
+            self._push_transitions_to_replay_buffer(state_list, action_list,
+                                                    reward_list, state_next_list,
+                                                    done_list, prob_a_list)
+
+        # convert data to torch tensors
+        states, actions, rewards, states_next, dones, prob_actions = \
+            torch.tensor(state_list, dtype=torch.float).to(self.device), \
+            torch.tensor(action_list).to(self.device), \
+            torch.tensor(reward_list, dtype=torch.float).to(self.device), \
+            torch.tensor(state_next_list, dtype=torch.float).to(self.device), \
+            torch.tensor(done_list, dtype=torch.float).to(self.device), \
+            torch.tensor(prob_a_list).to(self.device)
+
+        return states, actions, rewards, states_next, dones, prob_actions
+
+    def _get_transitions_from_replay_buffer(self, states, actions, rewards, states_next, dones, probs_action):
+        if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
+            states, actions, rewards, states_next, dones, probs_action = self.memory.sample()
+            actions = torch.squeeze(actions)
+            rewards = torch.squeeze(rewards)
+            states_next = torch.squeeze(states_next)
+            dones = torch.squeeze(dones)
+            probs_action = torch.squeeze(probs_action)
+        return states, actions, rewards, states_next, dones, probs_action
+
+    def train_net(self):
+        # All agents have to propagate their experiences made during past episode
+        for handle in range(len(self.current_episode_memory)):
+            # Extract agent's episode history (list of all transitions)
+            agent_episode_history = self.current_episode_memory.get_transitions(handle)
+            if len(agent_episode_history) > 0:
+                # Convert the replay buffer to torch tensors (arrays)
+                states, actions, rewards, states_next, dones, probs_action = \
+                    self._convert_transitions_to_torch_tensors(agent_episode_history)
+
+                # Optimize policy for K epochs:
+                for k_loop in range(int(self.K_epoch)):
+
+                    if self.use_replay_buffer:
+                        states, actions, rewards, states_next, dones, probs_action = \
+                            self._get_transitions_from_replay_buffer(
+                                states, actions, rewards, states_next, dones, probs_action
+                            )
+
+                    # Evaluating actions (actor) and values (critic)
+                    logprobs, state_values, dist_entropy = self.actor_critic_model.evaluate(states, actions)
+
+                    # Finding the ratios (pi_thetas / pi_thetas_replayed):
+                    ratios = torch.exp(logprobs - probs_action.detach())
+
+                    # Finding Surrogate Loos
+                    advantages = rewards - state_values.detach()
+                    surr1 = ratios * advantages
+                    surr2 = torch.clamp(ratios, 1. - self.surrogate_eps_clip, 1. + self.surrogate_eps_clip) * advantages
+
+                    # The loss function is used to estimate the gardient and use the entropy function based
+                    # heuristic to penalize the gradient function when the policy becomes deterministic this would let
+                    # the gradient becomes very flat and so the gradient is no longer useful.
+                    loss = \
+                        -torch.min(surr1, surr2) \
+                        + self.weight_loss * self.loss_function(state_values, rewards) \
+                        - self.weight_entropy * dist_entropy
+
+                    # Make a gradient step
+                    self.optimizer.zero_grad()
+                    loss.mean().backward()
+                    self.optimizer.step()
+
+                    # Transfer the current loss to the agents loss (information) for debug purpose only
+                    self.loss = loss.mean().detach().cpu().numpy()
+
+        # Reset all collect transition data
+        self.current_episode_memory.reset()
+
+    def end_episode(self, train):
+        if train:
+            self.train_net()
+
+    # Checkpointing methods
+    def save(self, filename):
+        # print("Saving model from checkpoint:", filename)
+        self.actor_critic_model.save(filename)
+        torch.save(self.optimizer.state_dict(), filename + ".optimizer")
+
+    def _load(self, obj, filename):
+        if os.path.exists(filename):
+            print(' >> ', filename)
+            try:
+                obj.load_state_dict(torch.load(filename, map_location=self.device))
+            except:
+                print(" >> failed!")
+        else:
+            print(" >> file not found!")
+        return obj
+
+    def load(self, filename):
+        print("load policy from file", filename)
+        self.actor_critic_model.load(filename)
+        print("load optimizer from file", filename)
+        self.optimizer = self._load(self.optimizer, filename + ".optimizer")
+
+    def clone(self):
+        policy = PPOPolicy(self.state_size, self.action_size)
+        policy.actor_critic_model = copy.deepcopy(self.actor_critic_model)
+        policy.optimizer = copy.deepcopy(self.optimizer)
+        return self
diff --git a/reinforcement_learning/replay_buffer.py b/reinforcement_learning/replay_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ba147d5808a507d17d742dce333a0902c32f710
--- /dev/null
+++ b/reinforcement_learning/replay_buffer.py
@@ -0,0 +1,57 @@
+import random
+from collections import namedtuple, deque, Iterable
+
+import numpy as np
+import torch
+
+Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done", "action_prob"])
+
+
+class ReplayBuffer:
+    """Fixed-size buffer to store experience tuples."""
+
+    def __init__(self, action_size, buffer_size, batch_size, device):
+        """Initialize a ReplayBuffer object.
+
+        Params
+        ======
+            action_size (int): dimension of each action
+            buffer_size (int): maximum size of buffer
+            batch_size (int): size of each training batch
+        """
+        self.action_size = action_size
+        self.memory = deque(maxlen=buffer_size)
+        self.batch_size = batch_size
+        self.device = device
+
+    def add(self, state, action, reward, next_state, done, action_prob=0.0):
+        """Add a new experience to memory."""
+        e = Experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done, action_prob)
+        self.memory.append(e)
+
+    def sample(self):
+        """Randomly sample a batch of experiences from memory."""
+        experiences = random.sample(self.memory, k=self.batch_size)
+        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
+            .float().to(self.device)
+        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
+            .long().to(self.device)
+        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
+            .float().to(self.device)
+        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
+            .float().to(self.device)
+        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
+            .float().to(self.device)
+        action_probs = torch.from_numpy(self.__v_stack_impr([e.action_prob for e in experiences if e is not None])) \
+            .float().to(self.device)
+
+        return states, actions, rewards, next_states, dones, action_probs
+
+    def __len__(self):
+        """Return the current size of internal memory."""
+        return len(self.memory)
+
+    def __v_stack_impr(self, states):
+        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
+        np_states = np.reshape(np.array(states), (len(states), sub_dim))
+        return np_states
diff --git a/reinforcement_learning/rl_agent_test.py b/reinforcement_learning/rl_agent_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..8687563f1886f549b415abbf8a3ff3b60e8bdb12
--- /dev/null
+++ b/reinforcement_learning/rl_agent_test.py
@@ -0,0 +1,92 @@
+from collections import deque
+from collections import namedtuple
+
+import gym
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+from reinforcement_learning.ppo_agent import PPOPolicy
+
+dddqn_param_nt = namedtuple('DDDQN_Param', ['hidden_size', 'buffer_size', 'batch_size', 'update_every', 'learning_rate',
+                                            'tau', 'gamma', 'buffer_min_size', 'use_gpu'])
+dddqn_param = dddqn_param_nt(hidden_size=128,
+                             buffer_size=1000,
+                             batch_size=64,
+                             update_every=10,
+                             learning_rate=1.e-3,
+                             tau=1.e-2,
+                             gamma=0.95,
+                             buffer_min_size=0,
+                             use_gpu=False)
+
+
+def cartpole(use_dddqn=False):
+    eps = 1.0
+    eps_decay = 0.99
+    min_eps = 0.01
+    training_mode = True
+
+    env = gym.make("CartPole-v1")
+    observation_space = env.observation_space.shape[0]
+    action_space = env.action_space.n
+    if not use_dddqn:
+        policy = PPOPolicy(observation_space, action_space, False)
+    else:
+        policy = DDDQNPolicy(observation_space, action_space, dddqn_param)
+    episode = 0
+    checkpoint_interval = 20
+    scores_window = deque(maxlen=100)
+
+    writer = SummaryWriter()
+
+    while True:
+        episode += 1
+        state = env.reset()
+        policy.reset(env)
+        handle = 0
+        tot_reward = 0
+
+        policy.start_episode(train=training_mode)
+        while True:
+            # env.render()
+            policy.start_step(train=training_mode)
+            action = policy.act(handle, state, eps)
+            state_next, reward, terminal, info = env.step(action)
+            policy.end_step(train=training_mode)
+            tot_reward += reward
+            # reward = reward if not terminal else -reward
+            reward = 0 if not terminal else -1
+            policy.step(handle, state, action, reward, state_next, terminal)
+            state = np.copy(state_next)
+            if terminal:
+                break
+
+        policy.end_episode(train=training_mode)
+        eps = max(min_eps, eps * eps_decay)
+        scores_window.append(tot_reward)
+        if episode % checkpoint_interval == 0:
+            print('\rEpisode: {:5}\treward: {:7.3f}\t avg: {:7.3f}\t eps: {:5.3f}\t replay buffer: {}'.format(episode,
+                                                                                                              tot_reward,
+                                                                                                              np.mean(
+                                                                                                                  scores_window),
+                                                                                                              eps,
+                                                                                                              len(
+                                                                                                                  policy.memory)))
+        else:
+            print('\rEpisode: {:5}\treward: {:7.3f}\t avg: {:7.3f}\t eps: {:5.3f}\t replay buffer: {}'.format(episode,
+                                                                                                              tot_reward,
+                                                                                                              np.mean(
+                                                                                                                  scores_window),
+                                                                                                              eps,
+                                                                                                              len(
+                                                                                                                  policy.memory)),
+                  end=" ")
+
+        writer.add_scalar("CartPole/value", tot_reward, episode)
+        writer.add_scalar("CartPole/smoothed_value", np.mean(scores_window), episode)
+        writer.flush()
+
+
+if __name__ == "__main__":
+    cartpole()
diff --git a/reinforcement_learning/sequential_agent.py b/reinforcement_learning/sequential_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2055a69576454a0252a24e21408db0f04131da0
--- /dev/null
+++ b/reinforcement_learning/sequential_agent.py
@@ -0,0 +1,85 @@
+import sys
+from pathlib import Path
+
+import numpy as np
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import complex_rail_generator
+from flatland.envs.schedule_generators import complex_schedule_generator
+from flatland.utils.rendertools import RenderTool
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from reinforcement_learning.ordered_policy import OrderedPolicy
+
+"""
+This file shows how to move agents in a sequential way: it moves the trains one by one, following a shortest path strategy.
+This is obviously very slow, but it's a good way to get familiar with the different Flatland components: RailEnv, TreeObsForRailEnv, etc...
+
+multi_agent_training.py is a better starting point to train your own solution!
+"""
+
+np.random.seed(2)
+
+x_dim = np.random.randint(8, 20)
+y_dim = np.random.randint(8, 20)
+n_agents = np.random.randint(3, 8)
+n_goals = n_agents + np.random.randint(0, 3)
+min_dist = int(0.75 * min(x_dim, y_dim))
+
+env = RailEnv(
+    width=x_dim,
+    height=y_dim,
+    rail_generator=complex_rail_generator(
+        nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist,
+        max_dist=99999,
+        seed=0
+    ),
+    schedule_generator=complex_schedule_generator(),
+    obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()),
+    number_of_agents=n_agents)
+env.reset(True, True)
+
+tree_depth = 1
+observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv())
+env_renderer = RenderTool(env, gl="PGL", )
+handle = env.get_agent_handles()
+n_episodes = 10
+max_steps = 100 * (env.height + env.width)
+record_images = False
+policy = OrderedPolicy()
+action_dict = dict()
+
+for trials in range(1, n_episodes + 1):
+    # Reset environment
+    obs, info = env.reset(True, True)
+    done = env.dones
+    env_renderer.reset()
+    frame_step = 0
+
+    # Run episode
+    for step in range(max_steps):
+        env_renderer.render_env(show=True, show_observations=False, show_predictions=True)
+
+        if record_images:
+            env_renderer.gl.save_image("./Images/flatland_frame_{:04d}.bmp".format(frame_step))
+            frame_step += 1
+
+        # Action
+        acting_agent = 0
+        for a in range(env.get_num_agents()):
+            if done[a]:
+                acting_agent += 1
+            if a == acting_agent:
+                action = policy.act(a, obs[a])
+            else:
+                action = 4
+            action_dict.update({a: action})
+
+        # Environment step
+        obs, all_rewards, done, _ = env.step(action_dict)
+
+        if done['__all__']:
+            break
diff --git a/reinforcement_learning/sequential_agent_training.py b/reinforcement_learning/sequential_agent_training.py
index ca19d1fcbbb4e3508a16b847d4b4cfcefc6aad98..d1ddd4348a462a9b7c17d6dae36c780acff1fd8b 100644
--- a/reinforcement_learning/sequential_agent_training.py
+++ b/reinforcement_learning/sequential_agent_training.py
@@ -1,13 +1,13 @@
 import sys
-import numpy as np
+from pathlib import Path
 
+import numpy as np
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator
 from flatland.utils.rendertools import RenderTool
-from pathlib import Path
 
 base_dir = Path(__file__).resolve().parent.parent
 sys.path.append(str(base_dir))
@@ -66,7 +66,7 @@ for trials in range(1, n_episodes + 1):
             if done[a]:
                 acting_agent += 1
             if a == acting_agent:
-                action = policy.act(obs[a])
+                action = policy.act(a, obs[a])
             else:
                 action = 4
             action_dict.update({a: action})
diff --git a/reinforcement_learning/single_agent_training.py b/reinforcement_learning/single_agent_training.py
index 79a88b25a8bc63011ef04208af53858dbb079d7d..dda07a9db5b6da3c2185f65d259fa0a9cf549c50 100644
--- a/reinforcement_learning/single_agent_training.py
+++ b/reinforcement_learning/single_agent_training.py
@@ -1,203 +1,209 @@
-import random
-import sys
-from argparse import ArgumentParser, Namespace
-from collections import deque
-from pathlib import Path
-
-base_dir = Path(__file__).resolve().parent.parent
-sys.path.append(str(base_dir))
-
-from reinforcement_learning.dddqn_policy import DDDQNPolicy
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
-
-from flatland.envs.rail_env import RailEnv
-from flatland.envs.rail_generators import sparse_rail_generator
-from flatland.envs.schedule_generators import sparse_schedule_generator
-from utils.observation_utils import normalize_observation
-from flatland.envs.observations import TreeObsForRailEnv
-
-"""
-This file shows how to train a single agent using a reinforcement learning approach.
-Documentation: https://flatland.aicrowd.com/getting-started/rl/single-agent.html
-
-This is a simple method used for demonstration purposes.
-multi_agent_training.py is a better starting point to train your own solution!
-"""
-
-
-def train_agent(n_episodes):
-    # Environment parameters
-    n_agents = 1
-    x_dim = 25
-    y_dim = 25
-    n_cities = 4
-    max_rails_between_cities = 2
-    max_rails_in_city = 3
-    seed = 42
-
-    # Observation parameters
-    observation_tree_depth = 2
-    observation_radius = 10
-
-    # Exploration parameters
-    eps_start = 1.0
-    eps_end = 0.01
-    eps_decay = 0.997  # for 2500ts
-
-    # Set the seeds
-    random.seed(seed)
-    np.random.seed(seed)
-
-    # Observation builder
-    tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth)
-
-    # Setup the environment
-    env = RailEnv(
-        width=x_dim,
-        height=y_dim,
-        rail_generator=sparse_rail_generator(
-            max_num_cities=n_cities,
-            seed=seed,
-            grid_mode=False,
-            max_rails_between_cities=max_rails_between_cities,
-            max_rails_in_city=max_rails_in_city
-        ),
-        schedule_generator=sparse_schedule_generator(),
-        number_of_agents=n_agents,
-        obs_builder_object=tree_observation
-    )
-
-    env.reset(True, True)
-
-    # Calculate the state size given the depth of the tree observation and the number of features
-    n_features_per_node = env.obs_builder.observation_dim
-    n_nodes = 0
-    for i in range(observation_tree_depth + 1):
-        n_nodes += np.power(4, i)
-    state_size = n_features_per_node * n_nodes
-
-    # The action space of flatland is 5 discrete actions
-    action_size = 5
-
-    # Max number of steps per episode
-    # This is the official formula used during evaluations
-    max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
-
-    action_dict = dict()
-
-    # And some variables to keep track of the progress
-    scores_window = deque(maxlen=100)  # todo smooth when rendering instead
-    completion_window = deque(maxlen=100)
-    scores = []
-    completion = []
-    action_count = [0] * action_size
-    agent_obs = [None] * env.get_num_agents()
-    agent_prev_obs = [None] * env.get_num_agents()
-    agent_prev_action = [2] * env.get_num_agents()
-    update_values = False
-
-    # Training parameters
-    training_parameters = {
-        'buffer_size': int(1e5),
-        'batch_size': 32,
-        'update_every': 8,
-        'learning_rate': 0.5e-4,
-        'tau': 1e-3,
-        'gamma': 0.99,
-        'buffer_min_size': 0,
-        'hidden_size': 256,
-        'use_gpu': False
-    }
-
-    # Double Dueling DQN policy
-    policy = DDDQNPolicy(state_size, action_size, Namespace(**training_parameters))
-
-    for episode_idx in range(n_episodes):
-        score = 0
-
-        # Reset environment
-        obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
-
-        # Build agent specific observations
-        for agent in env.get_agent_handles():
-            if obs[agent]:
-                agent_obs[agent] = normalize_observation(obs[agent], observation_tree_depth, observation_radius=observation_radius)
-                agent_prev_obs[agent] = agent_obs[agent].copy()
-
-        # Run episode
-        for step in range(max_steps - 1):
-            for agent in env.get_agent_handles():
-                if info['action_required'][agent]:
-                    # If an action is required, we want to store the obs at that step as well as the action
-                    update_values = True
-                    action = policy.act(agent_obs[agent], eps=eps_start)
-                    action_count[action] += 1
-                else:
-                    update_values = False
-                    action = 0
-                action_dict.update({agent: action})
-
-            # Environment step
-            next_obs, all_rewards, done, info = env.step(action_dict)
-
-            # Update replay buffer and train agent
-            for agent in range(env.get_num_agents()):
-                # Only update the values when we are done or when an action was taken and thus relevant information is present
-                if update_values or done[agent]:
-                    policy.step(agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent], agent_obs[agent], done[agent])
-
-                    agent_prev_obs[agent] = agent_obs[agent].copy()
-                    agent_prev_action[agent] = action_dict[agent]
-
-                if next_obs[agent]:
-                    agent_obs[agent] = normalize_observation(next_obs[agent], observation_tree_depth, observation_radius=10)
-
-                score += all_rewards[agent]
-
-            if done['__all__']:
-                break
-
-        # Epsilon decay
-        eps_start = max(eps_end, eps_decay * eps_start)
-
-        # Collection information about training
-        tasks_finished = np.sum([int(done[idx]) for idx in env.get_agent_handles()])
-        completion_window.append(tasks_finished / max(1, env.get_num_agents()))
-        scores_window.append(score / (max_steps * env.get_num_agents()))
-        completion.append((np.mean(completion_window)))
-        scores.append(np.mean(scores_window))
-        action_probs = action_count / np.sum(action_count)
-
-        if episode_idx % 100 == 0:
-            end = "\n"
-            torch.save(policy.qnetwork_local, './checkpoints/single-' + str(episode_idx) + '.pth')
-            action_count = [1] * action_size
-        else:
-            end = " "
-
-        print('\rTraining {} agents on {}x{}\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
-            env.get_num_agents(),
-            x_dim, y_dim,
-            episode_idx,
-            np.mean(scores_window),
-            100 * np.mean(completion_window),
-            eps_start,
-            action_probs
-        ), end=end)
-
-    # Plot overall training progress at the end
-    plt.plot(scores)
-    plt.show()
-
-    plt.plot(completion)
-    plt.show()
-
-
-if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument("-n", "--n_episodes", dest="n_episodes", help="number of episodes to run", default=500, type=int)
-    args = parser.parse_args()
-
-    train_agent(args.n_episodes)
+import random
+import sys
+from argparse import ArgumentParser, Namespace
+from collections import deque
+from pathlib import Path
+
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from utils.observation_utils import normalize_observation
+from flatland.envs.observations import TreeObsForRailEnv
+
+"""
+This file shows how to train a single agent using a reinforcement learning approach.
+Documentation: https://flatland.aicrowd.com/getting-started/rl/single-agent.html
+
+This is a simple method used for demonstration purposes.
+multi_agent_training.py is a better starting point to train your own solution!
+"""
+
+
+def train_agent(n_episodes):
+    # Environment parameters
+    n_agents = 1
+    x_dim = 25
+    y_dim = 25
+    n_cities = 4
+    max_rails_between_cities = 2
+    max_rails_in_city = 3
+    seed = 42
+
+    # Observation parameters
+    observation_tree_depth = 2
+    observation_radius = 10
+
+    # Exploration parameters
+    eps_start = 1.0
+    eps_end = 0.01
+    eps_decay = 0.997  # for 2500ts
+
+    # Set the seeds
+    random.seed(seed)
+    np.random.seed(seed)
+
+    # Observation builder
+    tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth)
+
+    # Setup the environment
+    env = RailEnv(
+        width=x_dim,
+        height=y_dim,
+        rail_generator=sparse_rail_generator(
+            max_num_cities=n_cities,
+            seed=seed,
+            grid_mode=False,
+            max_rails_between_cities=max_rails_between_cities,
+            max_rails_in_city=max_rails_in_city
+        ),
+        schedule_generator=sparse_schedule_generator(),
+        number_of_agents=n_agents,
+        obs_builder_object=tree_observation
+    )
+
+    env.reset(True, True)
+
+    # Calculate the state size given the depth of the tree observation and the number of features
+    n_features_per_node = env.obs_builder.observation_dim
+    n_nodes = 0
+    for i in range(observation_tree_depth + 1):
+        n_nodes += np.power(4, i)
+    state_size = n_features_per_node * n_nodes
+
+    # The action space of flatland is 5 discrete actions
+    action_size = 5
+
+    # Max number of steps per episode
+    # This is the official formula used during evaluations
+    max_steps = int(4 * 2 * (env.height + env.width + (n_agents / n_cities)))
+
+    action_dict = dict()
+
+    # And some variables to keep track of the progress
+    scores_window = deque(maxlen=100)  # todo smooth when rendering instead
+    completion_window = deque(maxlen=100)
+    scores = []
+    completion = []
+    action_count = [0] * action_size
+    agent_obs = [None] * env.get_num_agents()
+    agent_prev_obs = [None] * env.get_num_agents()
+    agent_prev_action = [2] * env.get_num_agents()
+    update_values = False
+
+    # Training parameters
+    training_parameters = {
+        'buffer_size': int(1e5),
+        'batch_size': 32,
+        'update_every': 8,
+        'learning_rate': 0.5e-4,
+        'tau': 1e-3,
+        'gamma': 0.99,
+        'buffer_min_size': 0,
+        'hidden_size': 256,
+        'use_gpu': False
+    }
+
+    # Double Dueling DQN policy
+    policy = DDDQNPolicy(state_size, action_size, Namespace(**training_parameters))
+
+    for episode_idx in range(n_episodes):
+        score = 0
+
+        # Reset environment
+        obs, info = env.reset(regenerate_rail=True, regenerate_schedule=True)
+
+        # Build agent specific observations
+        for agent in env.get_agent_handles():
+            if obs[agent]:
+                agent_obs[agent] = normalize_observation(obs[agent], observation_tree_depth,
+                                                         observation_radius=observation_radius)
+                agent_prev_obs[agent] = agent_obs[agent].copy()
+
+        # Run episode
+        for step in range(max_steps - 1):
+            for agent in env.get_agent_handles():
+                if info['action_required'][agent]:
+                    # If an action is required, we want to store the obs at that step as well as the action
+                    update_values = True
+                    action = policy.act(agent, agent_obs[agent], eps=eps_start)
+                    action_count[action] += 1
+                else:
+                    update_values = False
+                    action = 0
+                action_dict.update({agent: action})
+
+            # Environment step
+            next_obs, all_rewards, done, info = env.step(action_dict)
+
+            # Update replay buffer and train agent
+            for agent in range(env.get_num_agents()):
+                # Only update the values when we are done or when an action was taken and thus relevant information is present
+                if update_values or done[agent]:
+                    policy.step(agent,
+                                agent_prev_obs[agent], agent_prev_action[agent], all_rewards[agent],
+                                agent_obs[agent], done[agent])
+
+                    agent_prev_obs[agent] = agent_obs[agent].copy()
+                    agent_prev_action[agent] = action_dict[agent]
+
+                if next_obs[agent]:
+                    agent_obs[agent] = normalize_observation(next_obs[agent], observation_tree_depth,
+                                                             observation_radius=10)
+
+                score += all_rewards[agent]
+
+            if done['__all__']:
+                break
+
+        # Epsilon decay
+        eps_start = max(eps_end, eps_decay * eps_start)
+
+        # Collection information about training
+        tasks_finished = np.sum([int(done[idx]) for idx in env.get_agent_handles()])
+        completion_window.append(tasks_finished / max(1, env.get_num_agents()))
+        scores_window.append(score / (max_steps * env.get_num_agents()))
+        completion.append((np.mean(completion_window)))
+        scores.append(np.mean(scores_window))
+        action_probs = action_count / np.sum(action_count)
+
+        if episode_idx % 100 == 0:
+            end = "\n"
+            torch.save(policy.qnetwork_local, './checkpoints/single-' + str(episode_idx) + '.pth')
+            action_count = [1] * action_size
+        else:
+            end = " "
+
+        print(
+            '\rTraining {} agents on {}x{}\t Episode {}\t Average Score: {:.3f}\tDones: {:.2f}%\tEpsilon: {:.2f} \t Action Probabilities: \t {}'.format(
+                env.get_num_agents(),
+                x_dim, y_dim,
+                episode_idx,
+                np.mean(scores_window),
+                100 * np.mean(completion_window),
+                eps_start,
+                action_probs
+            ), end=end)
+
+    # Plot overall training progress at the end
+    plt.plot(scores)
+    plt.show()
+
+    plt.plot(completion)
+    plt.show()
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("-n", "--n_episodes", dest="n_episodes", help="number of episodes to run", default=500,
+                        type=int)
+    args = parser.parse_args()
+
+    train_agent(args.n_episodes)
diff --git a/replay_buffers/.gitkeep b/replay_buffers/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/run.py b/run.py
index 674ae41557941f1c6d26db1ea055bb0aa5c9cd6a..881d0b080cdd17a01d0acfbe41c06d92b3aebea0 100644
--- a/run.py
+++ b/run.py
@@ -1,201 +1,328 @@
+'''
+DDDQNPolicy experiments - EPSILON impact analysis
+----------------------------------------------------------------------------------------
+checkpoint = "./checkpoints/201124171810-7800.pth"  # Training on AGENTS=10 with Depth=2
+EPSILON = 0.000 # Sum Normalized Reward :  0.000000000000000 (primary score)
+EPSILON = 0.002 # Sum Normalized Reward : 18.445875081269286 (primary score)
+EPSILON = 0.005 # Sum Normalized Reward : 18.371733625865854 (primary score)
+EPSILON = 0.010 # Sum Normalized Reward : 18.249244799876152 (primary score)
+EPSILON = 0.020 # Sum Normalized Reward : 17.526987022691376 (primary score)
+EPSILON = 0.030 # Sum Normalized Reward : 16.796885571003942 (primary score)
+EPSILON = 0.040 # Sum Normalized Reward : 17.280787151431426 (primary score)
+EPSILON = 0.050 # Sum Normalized Reward : 16.256945636647025 (primary score)
+EPSILON = 0.100 # Sum Normalized Reward : 14.828347241759966 (primary score)
+EPSILON = 0.200 # Sum Normalized Reward : 11.192330074898457 (primary score)
+EPSILON = 0.300 # Sum Normalized Reward : 14.523067754608782 (primary score)
+EPSILON = 0.400 # Sum Normalized Reward : 12.901508220410834 (primary score)
+EPSILON = 0.500 # Sum Normalized Reward :  3.754660231871272 (primary score)
+EPSILON = 1.000 # Sum Normalized Reward :  1.397180159192391 (primary score)
+'''
+
+import sys
 import time
+from argparse import Namespace
+from pathlib import Path
 
 import numpy as np
 from flatland.core.env_observation_builder import DummyObservationBuilder
 from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.rail_env import RailEnvActions
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.evaluators.client import FlatlandRemoteClient
+from flatland.evaluators.client import TimeoutException
 
+from reinforcement_learning.dddqn_policy import DDDQNPolicy
+from reinforcement_learning.deadlockavoidance_with_decision_agent import DeadLockAvoidanceWithDecisionAgent
+from reinforcement_learning.multi_decision_agent import MultiDecisionAgent
+from reinforcement_learning.ppo_agent import PPOPolicy
+from utils.agent_action_config import get_action_size, map_actions, set_action_size_reduced, set_action_size_full
 from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
+from utils.deadlock_check import check_if_all_blocked
+from utils.fast_tree_obs import FastTreeObs
+from utils.observation_utils import normalize_observation
 
-#####################################################################
-# Instantiate a Remote Client
-#####################################################################
+base_dir = Path(__file__).resolve().parent.parent
+sys.path.append(str(base_dir))
+
+####################################################
+# EVALUATION PARAMETERS
+set_action_size_full()
+
+# Print per-step logs
+VERBOSE = True
+USE_FAST_TREEOBS = True
+
+if False:
+    # -------------------------------------------------------------------------------------------------------
+    # RL solution
+    # -------------------------------------------------------------------------------------------------------
+    # 116591 adrian_egli
+    # graded	71.305	0.633	RL	Successfully Graded ! More details about this submission can be found at:
+    # http://gitlab.aicrowd.com/adrian_egli/neurips2020-flatland-starter-kit/issues/51
+    # Fri, 22 Jan 2021 23:37:56
+    set_action_size_reduced()
+    load_policy = "DDDQN"
+    checkpoint = "./checkpoints/210122120236-3000.pth"  # 17.011131341978228
+    EPSILON = 0.0
+
+if False:
+    # -------------------------------------------------------------------------------------------------------
+    # RL solution
+    # -------------------------------------------------------------------------------------------------------
+    # 116658 adrian_egli
+    # graded	73.821	0.655	RL	Successfully Graded ! More details about this submission can be found at:
+    # http://gitlab.aicrowd.com/adrian_egli/neurips2020-flatland-starter-kit/issues/52
+    # Sat, 23 Jan 2021 07:41:35
+    set_action_size_reduced()
+    load_policy = "PPO"
+    checkpoint = "./checkpoints/210122235754-5000.pth"  # 16.00113400887389
+    EPSILON = 0.0
+
+if True:
+    # -------------------------------------------------------------------------------------------------------
+    # RL solution
+    # -------------------------------------------------------------------------------------------------------
+    # 116659 adrian_egli
+    # graded	80.579	0.715	RL	Successfully Graded ! More details about this submission can be found at:
+    # http://gitlab.aicrowd.com/adrian_egli/neurips2020-flatland-starter-kit/issues/53
+    # Sat, 23 Jan 2021 07:45:49
+    set_action_size_reduced()
+    load_policy = "DDDQN"
+    checkpoint = "./checkpoints/210122165109-5000.pth"  # 17.993750197899438
+    EPSILON = 0.0
+
+if False:
+    # -------------------------------------------------------------------------------------------------------
+    # !! This is not a RL solution !!!!
+    # -------------------------------------------------------------------------------------------------------
+    # 116727 adrian_egli
+    # graded	106.786	0.768	RL	Successfully Graded ! More details about this submission can be found at:
+    # http://gitlab.aicrowd.com/adrian_egli/neurips2020-flatland-starter-kit/issues/54
+    # Sat, 23 Jan 2021 14:31:50
+    set_action_size_reduced()
+    load_policy = "DeadLockAvoidance"
+    checkpoint = None
+    EPSILON = 0.0
+
+# load_policy = "DeadLockAvoidance" # 22.13346834815911
+
+# Use last action cache
+USE_ACTION_CACHE = False
+
+# Observation parameters (must match training parameters!)
+observation_tree_depth = 2
+observation_radius = 10
+observation_max_path_depth = 30
+
+####################################################
 
 remote_client = FlatlandRemoteClient()
 
+# Observation builder
+predictor = ShortestPathPredictorForRailEnv(observation_max_path_depth)
+if USE_FAST_TREEOBS:
+    def check_is_observation_valid(observation):
+        return True
 
-#####################################################################
-# Define your custom controller
-#
-# which can take an observation, and the number of agents and 
-# compute the necessary action for this step for all (or even some)
-# of the agents
-#####################################################################
-# def my_controller_RL(extra: Extra, observation, info):
-#     return extra.rl_agent_act(observation, info)
-
-def my_controller(policy, info):
-    policy.start_step()
-    actions = {}
-    # print("-------- act ------------")
-    for handle in range(policy.env.get_num_agents()):
-        if info['action_required'][handle] and handle < policy.env._elapsed_steps:
-            a = policy.act(handle, None, 0)
-        else:
-            a = RailEnvActions.DO_NOTHING
-            agent = policy.env.agents[handle]
-        actions.update({handle: a})
-    policy.end_step()
-    return actions
 
+    def get_normalized_observation(observation, tree_depth: int, observation_radius=0):
+        return observation
 
-#####################################################################
-# Instantiate your custom Observation Builder
-# 
-# You can build your own Observation Builder by following 
-# the example here : 
-# https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
-#####################################################################
-# my_observation_builder = Extra(max_depth=1)
-my_observation_builder = DummyObservationBuilder()
 
-# Or if you want to use your own approach to build the observation from the env_step, 
-# please feel free to pass a DummyObservationBuilder() object as mentioned below,
-# and that will just return a placeholder True for all observation, and you 
-# can build your own Observation for all the agents as your please.
-# my_observation_builder = DummyObservationBuilder()
+    tree_observation = FastTreeObs(max_depth=observation_tree_depth)
+    state_size = tree_observation.observation_dim
+else:
+    def check_is_observation_valid(observation):
+        return observation
+
+
+    def get_normalized_observation(observation, tree_depth: int, observation_radius=0):
+        return normalize_observation(observation, tree_depth, observation_radius)
+
 
+    tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth, predictor=predictor)
+    # Calculate the state size given the depth of the tree observation and the number of features
+    n_features_per_node = tree_observation.observation_dim
+    n_nodes = sum([np.power(4, i) for i in range(observation_tree_depth + 1)])
+    state_size = n_features_per_node * n_nodes
 
 #####################################################################
 # Main evaluation loop
-#
-# This iterates over an arbitrary number of env evaluations
 #####################################################################
 evaluation_number = 0
-while True:
 
+while True:
     evaluation_number += 1
-    # Switch to a new evaluation environemnt
-    # 
-    # a remote_client.env_create is similar to instantiating a 
-    # RailEnv and then doing a env.reset()
-    # hence it returns the first observation from the 
-    # env.reset()
-    # 
-    # You can also pass your custom observation_builder object
-    # to allow you to have as much control as you wish 
-    # over the observation of your choice.
+
+    # We use a dummy observation and call TreeObsForRailEnv ourselves when needed.
+    # This way we decide if we want to calculate the observations or not instead
+    # of having them calculated every time we perform an env step.
     time_start = time.time()
     observation, info = remote_client.env_create(
-        obs_builder_object=my_observation_builder
+        obs_builder_object=DummyObservationBuilder()
     )
+    env_creation_time = time.time() - time_start
+
     if not observation:
-        #
         # If the remote_client returns False on a `env_create` call,
-        # then it basically means that your agent has already been 
+        # then it basically means that your agent has already been
         # evaluated on all the required evaluation environments,
-        # and hence its safe to break out of the main evaluation loop
+        # and hence it's safe to break out of the main evaluation loop.
         break
 
-    print("Evaluation Number : {}".format(evaluation_number))
-
-    #####################################################################
-    # Access to a local copy of the environment
-    # 
-    #####################################################################
-    # Note: You can access a local copy of the environment 
-    # by using : 
-    #       remote_client.env 
-    # 
-    # But please ensure to not make any changes (or perform any action) on 
-    # the local copy of the env, as then it will diverge from 
-    # the state of the remote copy of the env, and the observations and 
-    # rewards, etc will behave unexpectedly
-    # 
-    # You can however probe the local_env instance to get any information
-    # you need from the environment. It is a valid RailEnv instance.
+    print("Env Path : ", remote_client.current_env_path)
+    print("Env Creation Time : ", env_creation_time)
+
     local_env = remote_client.env
-    number_of_agents = len(local_env.agents)
+    nb_agents = len(local_env.agents)
+    max_nb_steps = local_env._max_episode_steps
+
+    tree_observation.set_env(local_env)
+    tree_observation.reset()
+
+    # Creates the policy. No GPU on evaluation server.
+    if load_policy == "DDDQN":
+        policy = DDDQNPolicy(state_size, get_action_size(), Namespace(**{'use_gpu': False}), evaluation_mode=True)
+    elif load_policy == "PPO":
+        policy = PPOPolicy(state_size, get_action_size())
+    elif load_policy == "DeadLockAvoidance":
+        policy = DeadLockAvoidanceAgent(local_env, get_action_size(), enable_eps=False)
+    elif load_policy == "DeadLockAvoidanceWithDecision":
+        # inter_policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False, in_parameters=train_params)
+        inter_policy = DDDQNPolicy(state_size, get_action_size(), Namespace(**{'use_gpu': False}), evaluation_mode=True)
+        policy = DeadLockAvoidanceWithDecisionAgent(local_env, state_size, get_action_size(), inter_policy)
+    elif load_policy == "MultiDecision":
+        policy = MultiDecisionAgent(state_size, get_action_size(), Namespace(**{'use_gpu': False}))
+    else:
+        policy = PPOPolicy(state_size, get_action_size(), use_replay_buffer=False,
+                           in_parameters=Namespace(**{'use_gpu': False}))
+
+    policy.load(checkpoint)
 
-    policy = DeadLockAvoidanceAgent(local_env, None, None)
+    policy.reset(local_env)
+    observation = tree_observation.get_many(list(range(nb_agents)))
+
+    print("Evaluation {}: {} agents in {}x{}".format(evaluation_number, nb_agents, local_env.width, local_env.height))
 
     # Now we enter into another infinite loop where we
     # compute the actions for all the individual steps in this episode
     # until the episode is `done`
-    # 
-    # An episode is considered done when either all the agents have 
-    # reached their target destination
-    # or when the number of time steps has exceed max_time_steps, which 
-    # is defined by : 
-    #
-    # max_time_steps = int(4 * 2 * (env.width + env.height + 20))
-    #
+    steps = 0
+
+    # Bookkeeping
     time_taken_by_controller = []
     time_taken_per_step = []
-    steps = 0
 
-    extra = my_observation_builder
-    env_creation_time = time.time() - time_start
-    print("Env Creation Time : ", env_creation_time)
-    print("Agents : ", extra.env.get_num_agents())
-    print("w : ", extra.env.width)
-    print("h : ", extra.env.height)
+    # Action cache: keep track of last observation to avoid running the same inferrence multiple times.
+    # This only makes sense for deterministic policies.
+    agent_last_obs = {}
+    agent_last_action = {}
+    nb_hit = 0
 
-    old_total_done = 0
-    old_total_active = 0
+    policy.start_episode(train=False)
     while True:
-        #####################################################################
-        # Evaluation of a single episode
-        #
-        #####################################################################
-        # Compute the action for this step by using the previously 
-        # defined controller
-        time_start = time.time()
-        action = my_controller(policy, info)
-        time_taken = time.time() - time_start
-        time_taken_by_controller.append(time_taken)
-
-        # Perform the chosen action on the environment.
-        # The action gets applied to both the local and the remote copy 
-        # of the environment instance, and the observation is what is 
-        # returned by the local copy of the env, and the rewards, and done and info
-        # are returned by the remote copy of the env
-        time_start = time.time()
-        observation, all_rewards, done, info = remote_client.env_step(action)
-        steps += 1
-        time_taken = time.time() - time_start
-        time_taken_per_step.append(time_taken)
-
-        total_done = 0
-        total_active = 0
-        for a in range(local_env.get_num_agents()):
-            x = (local_env.agents[a].status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
-            total_done += int(x)
-            total_active += int(local_env.agents[a].status == RailAgentStatus.ACTIVE)
-        if old_total_done != total_done or old_total_active != total_active:
-            print("total_done:", total_done, "\ttotal_active", total_active, "\t num agents",
-                  local_env.get_num_agents())
-        old_total_done = total_done
-        old_total_active = total_active
-
-        if done['__all__']:
-            print("Reward : ", sum(list(all_rewards.values())))
-            #
-            # When done['__all__'] == True, then the evaluation of this 
-            # particular Env instantiation is complete, and we can break out 
-            # of this loop, and move onto the next Env evaluation
+        try:
+            #####################################################################
+            # Evaluation of a single episode
+            #####################################################################
+            steps += 1
+            obs_time, agent_time, step_time = 0.0, 0.0, 0.0
+            no_ops_mode = False
+
+            if not check_if_all_blocked(env=local_env):
+                time_start = time.time()
+                action_dict = {}
+                policy.start_step(train=False)
+                for agent_handle in range(nb_agents):
+                    if info['action_required'][agent_handle]:
+                        if agent_handle in agent_last_obs and np.all(
+                                agent_last_obs[agent_handle] == observation[agent_handle]):
+                            # cache hit
+                            action = agent_last_action[agent_handle]
+                            nb_hit += 1
+                        else:
+                            normalized_observation = get_normalized_observation(observation[agent_handle],
+                                                                                observation_tree_depth,
+                                                                                observation_radius=observation_radius)
+
+                            action = policy.act(agent_handle, normalized_observation, eps=EPSILON)
+
+                    action_dict[agent_handle] = action
+
+                    if USE_ACTION_CACHE:
+                        agent_last_obs[agent_handle] = observation[agent_handle]
+                        agent_last_action[agent_handle] = action
+
+                policy.end_step(train=False)
+                agent_time = time.time() - time_start
+                time_taken_by_controller.append(agent_time)
+
+                time_start = time.time()
+                _, all_rewards, done, info = remote_client.env_step(map_actions(action_dict))
+                step_time = time.time() - time_start
+                time_taken_per_step.append(step_time)
+
+                time_start = time.time()
+                observation = tree_observation.get_many(list(range(nb_agents)))
+                obs_time = time.time() - time_start
+
+            else:
+                # Fully deadlocked: perform no-ops
+                no_ops_mode = True
+
+                time_start = time.time()
+                _, all_rewards, done, info = remote_client.env_step({})
+                step_time = time.time() - time_start
+                time_taken_per_step.append(step_time)
+
+            nb_agents_done = 0
+            for i_agent, agent in enumerate(local_env.agents):
+                # manage the boolean flag to check if all agents are indeed done (or done_removed)
+                if (agent.status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]):
+                    nb_agents_done += 1
+
+            if VERBOSE or done['__all__']:
+                print(
+                    "Step {}/{}\tAgents done: {}\t Obs time {:.3f}s\t Inference time {:.5f}s\t Step time {:.3f}s\t Cache hits {}\t No-ops? {}".format(
+                        str(steps).zfill(4),
+                        max_nb_steps,
+                        nb_agents_done,
+                        obs_time,
+                        agent_time,
+                        step_time,
+                        nb_hit,
+                        no_ops_mode
+                    ), end="\r")
+
+            if done['__all__']:
+                # When done['__all__'] == True, then the evaluation of this
+                # particular Env instantiation is complete, and we can break out
+                # of this loop, and move onto the next Env evaluation
+                print()
+                break
+
+        except TimeoutException as err:
+            # A timeout occurs, won't get any reward for this episode :-(
+            # Skip to next episode as further actions in this one will be ignored.
+            # The whole evaluation will be stopped if there are 10 consecutive timeouts.
+            print("Timeout! Will skip this episode and go to the next.", err)
             break
 
+    policy.end_episode(train=False)
+
     np_time_taken_by_controller = np.array(time_taken_by_controller)
     np_time_taken_per_step = np.array(time_taken_per_step)
-    print("=" * 100)
-    print("=" * 100)
-    print("Evaluation Number : ", evaluation_number)
-    print("Current Env Path : ", remote_client.current_env_path)
-    print("Env Creation Time : ", env_creation_time)
-    print("Number of Steps : ", steps)
     print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(),
           np_time_taken_by_controller.std())
     print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
     print("=" * 100)
 
-print("Evaluation of all environments complete...")
+print("Evaluation of all environments complete!")
 ########################################################################
 # Submit your Results
-# 
-# Please do not forget to include this call, as this triggers the 
+#
+# Please do not forget to include this call, as this triggers the
 # final computation of the score statistics, video generation, etc
-# and is necesaary to have your submission marked as successfully evaluated
+# and is necessary to have your submission marked as successfully evaluated
 ########################################################################
 print(remote_client.submit())
diff --git a/runs_bench/Jan14_10-56-32_K57261_PPO_reduced/events.out.tfevents.1610618195.K57261.15412.0 b/runs_bench/Jan14_10-56-32_K57261_PPO_reduced/events.out.tfevents.1610618195.K57261.15412.0
new file mode 100644
index 0000000000000000000000000000000000000000..af1eddbb10807e69b8bd41a7325ac3842a57d74c
Binary files /dev/null and b/runs_bench/Jan14_10-56-32_K57261_PPO_reduced/events.out.tfevents.1610618195.K57261.15412.0 differ
diff --git a/runs_bench/Jan18_09-32-17_K57261_DDDQN_reduced/events.out.tfevents.1610958740.K57261.6608.0 b/runs_bench/Jan18_09-32-17_K57261_DDDQN_reduced/events.out.tfevents.1610958740.K57261.6608.0
new file mode 100644
index 0000000000000000000000000000000000000000..87417209df52224b344a0504297d78c818450bfa
Binary files /dev/null and b/runs_bench/Jan18_09-32-17_K57261_DDDQN_reduced/events.out.tfevents.1610958740.K57261.6608.0 differ
diff --git a/runs_bench/Jan18_09-34-10_K57261_DeadLockAvoidance_EPS_reduced/events.out.tfevents.1610958853.K57261.10660.0 b/runs_bench/Jan18_09-34-10_K57261_DeadLockAvoidance_EPS_reduced/events.out.tfevents.1610958853.K57261.10660.0
new file mode 100644
index 0000000000000000000000000000000000000000..99a0f3dc0b36172a60eab9a5ccc188df538e8ea2
Binary files /dev/null and b/runs_bench/Jan18_09-34-10_K57261_DeadLockAvoidance_EPS_reduced/events.out.tfevents.1610958853.K57261.10660.0 differ
diff --git a/runs_bench/Jan18_11-47-54_K57261_DeadLockAvoidance_reduced/events.out.tfevents.1610966876.K57261.4332.0 b/runs_bench/Jan18_11-47-54_K57261_DeadLockAvoidance_reduced/events.out.tfevents.1610966876.K57261.4332.0
new file mode 100644
index 0000000000000000000000000000000000000000..dcb4ee1b567ea4c0cefea764772f5d7e4438c5b7
Binary files /dev/null and b/runs_bench/Jan18_11-47-54_K57261_DeadLockAvoidance_reduced/events.out.tfevents.1610966876.K57261.4332.0 differ
diff --git a/runs_bench/Jan18_11-56-16_K57261_DeadLockAvoidanceWithDecision_reduced/events.out.tfevents.1610967379.K57261.14680.0 b/runs_bench/Jan18_11-56-16_K57261_DeadLockAvoidanceWithDecision_reduced/events.out.tfevents.1610967379.K57261.14680.0
new file mode 100644
index 0000000000000000000000000000000000000000..6cd3eb9ab4c551a9e6f79161015fff9a0f6528e1
Binary files /dev/null and b/runs_bench/Jan18_11-56-16_K57261_DeadLockAvoidanceWithDecision_reduced/events.out.tfevents.1610967379.K57261.14680.0 differ
diff --git a/runs_bench/Jan18_13-46-59_K57261_MultiDecisionAgent_reduced/events.out.tfevents.1610974021.K57261.12972.0 b/runs_bench/Jan18_13-46-59_K57261_MultiDecisionAgent_reduced/events.out.tfevents.1610974021.K57261.12972.0
new file mode 100644
index 0000000000000000000000000000000000000000..63d1f310c3a08c5b1a58aed74d563633ff6a75a4
Binary files /dev/null and b/runs_bench/Jan18_13-46-59_K57261_MultiDecisionAgent_reduced/events.out.tfevents.1610974021.K57261.12972.0 differ
diff --git a/runs_bench/Jan18_14-53-57_K57261_PPO_full/events.out.tfevents.1610978039.K57261.484.0 b/runs_bench/Jan18_14-53-57_K57261_PPO_full/events.out.tfevents.1610978039.K57261.484.0
new file mode 100644
index 0000000000000000000000000000000000000000..09e669d9f22804c398ffda38b4e8a4a68cee9e7a
Binary files /dev/null and b/runs_bench/Jan18_14-53-57_K57261_PPO_full/events.out.tfevents.1610978039.K57261.484.0 differ
diff --git a/runs_bench/Jan18_14-57-56_K57261_DDDQN_full/events.out.tfevents.1610978281.K57261.19984.0 b/runs_bench/Jan18_14-57-56_K57261_DDDQN_full/events.out.tfevents.1610978281.K57261.19984.0
new file mode 100644
index 0000000000000000000000000000000000000000..dad2128599d46e5f29250304d7c14d9b3e10ddc0
Binary files /dev/null and b/runs_bench/Jan18_14-57-56_K57261_DDDQN_full/events.out.tfevents.1610978281.K57261.19984.0 differ
diff --git a/runs_bench/Jan18_16-05-23_K57261_DeadLockAvoidance_EPS_full/events.out.tfevents.1610982327.K57261.6264.0 b/runs_bench/Jan18_16-05-23_K57261_DeadLockAvoidance_EPS_full/events.out.tfevents.1610982327.K57261.6264.0
new file mode 100644
index 0000000000000000000000000000000000000000..ddc9a93711dd9c17f90e725b9a059ed5bb144648
Binary files /dev/null and b/runs_bench/Jan18_16-05-23_K57261_DeadLockAvoidance_EPS_full/events.out.tfevents.1610982327.K57261.6264.0 differ
diff --git a/runs_bench/Jan18_16-14-19_K57261_DeadLockAvoidance_full/events.out.tfevents.1610982862.K57261.14612.0 b/runs_bench/Jan18_16-14-19_K57261_DeadLockAvoidance_full/events.out.tfevents.1610982862.K57261.14612.0
new file mode 100644
index 0000000000000000000000000000000000000000..cb56ffbe6656314c84fb9c402e26eeda2ab3bade
Binary files /dev/null and b/runs_bench/Jan18_16-14-19_K57261_DeadLockAvoidance_full/events.out.tfevents.1610982862.K57261.14612.0 differ
diff --git a/runs_bench/Jan18_16-43-41_K57261_DeadLockAvoidanceWithDecision_full/events.out.tfevents.1610984623.K57261.17628.0 b/runs_bench/Jan18_16-43-41_K57261_DeadLockAvoidanceWithDecision_full/events.out.tfevents.1610984623.K57261.17628.0
new file mode 100644
index 0000000000000000000000000000000000000000..a7944f80744cddafc0c4164726e4858b69391960
Binary files /dev/null and b/runs_bench/Jan18_16-43-41_K57261_DeadLockAvoidanceWithDecision_full/events.out.tfevents.1610984623.K57261.17628.0 differ
diff --git a/runs_bench/Jan18_16-45-04_K57261_MultiDecision_full/events.out.tfevents.1610984709.K57261.1796.0 b/runs_bench/Jan18_16-45-04_K57261_MultiDecision_full/events.out.tfevents.1610984709.K57261.1796.0
new file mode 100644
index 0000000000000000000000000000000000000000..122ed94ad037caa702ac7331cca63a7d7ea28666
Binary files /dev/null and b/runs_bench/Jan18_16-45-04_K57261_MultiDecision_full/events.out.tfevents.1610984709.K57261.1796.0 differ
diff --git a/runs_bench/Screenshots/full.png b/runs_bench/Screenshots/full.png
new file mode 100644
index 0000000000000000000000000000000000000000..a7328bf9fce92f71480b4b78601b044fba885358
Binary files /dev/null and b/runs_bench/Screenshots/full.png differ
diff --git a/runs_bench/Screenshots/reduced.png b/runs_bench/Screenshots/reduced.png
new file mode 100644
index 0000000000000000000000000000000000000000..9c2058748bda163154c10e139d3c7fd63f163e5b
Binary files /dev/null and b/runs_bench/Screenshots/reduced.png differ
diff --git a/sweep.yaml b/sweep.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bd9299caaf8a6c0bf2ee07c2d5292731cc2f2a28
--- /dev/null
+++ b/sweep.yaml
@@ -0,0 +1,21 @@
+# This sweep file can be used to run hyper-parameter search using Weight & Biases tools
+# See: https://docs.wandb.com/sweeps
+program: reinforcement_learning/multi_agent_training.py
+method: bayes
+metric:
+    name: evaluation/smoothed_score
+    goal: maximize
+parameters:
+    n_episodes:
+        values: [2000]
+    hidden_size:
+        # default: 256
+        values: [128, 256, 512]
+    buffer_size:
+        # default: 50000
+        values: [50000, 100000, 500000, 1000000]
+    batch_size:
+        # default: 32
+        values: [16, 32, 64, 128]
+    training_env_config:
+        values: [0, 1, 2]
\ No newline at end of file
diff --git a/utils/agent_action_config.py b/utils/agent_action_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..29750c9222965a044e4f447b7cdb8a97517b9953
--- /dev/null
+++ b/utils/agent_action_config.py
@@ -0,0 +1,73 @@
+from flatland.envs.rail_env import RailEnvActions
+
+# global action size
+global _agent_action_config_action_size
+_agent_action_config_action_size = 5
+
+
+def get_flatland_full_action_size():
+    # The action space of flatland is 5 discrete actions
+    return 5
+
+
+def set_action_size_full():
+    global _agent_action_config_action_size
+    # The agents (DDDQN, PPO, ... ) have this actions space
+    _agent_action_config_action_size = 5
+
+
+def set_action_size_reduced():
+    global _agent_action_config_action_size
+    # The agents (DDDQN, PPO, ... ) have this actions space
+    _agent_action_config_action_size = 4
+
+
+def get_action_size():
+    global _agent_action_config_action_size
+    # The agents (DDDQN, PPO, ... ) have this actions space
+    return _agent_action_config_action_size
+
+
+def map_actions(actions):
+    # Map the
+    if get_action_size() != get_flatland_full_action_size():
+        for key in actions:
+            value = actions.get(key, 0)
+            actions.update({key: map_action(value)})
+    return actions
+
+
+def map_action_policy(action):
+    if get_action_size() != get_flatland_full_action_size():
+        return action - 1
+    return action
+
+
+def map_action(action):
+    if get_action_size() == get_flatland_full_action_size():
+        return action
+
+    if action == 0:
+        return RailEnvActions.MOVE_LEFT
+    if action == 1:
+        return RailEnvActions.MOVE_FORWARD
+    if action == 2:
+        return RailEnvActions.MOVE_RIGHT
+    if action == 3:
+        return RailEnvActions.STOP_MOVING
+
+
+def map_rail_env_action(action):
+    if get_action_size() == get_flatland_full_action_size():
+        return action
+
+    if action == RailEnvActions.MOVE_LEFT:
+        return 0
+    elif action == RailEnvActions.MOVE_FORWARD:
+        return 1
+    elif action == RailEnvActions.MOVE_RIGHT:
+        return 2
+    elif action == RailEnvActions.STOP_MOVING:
+        return 3
+    # action == RailEnvActions.DO_NOTHING:
+    return 3
diff --git a/utils/agent_can_choose_helper.py b/utils/agent_can_choose_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..95636c6a13dc847e798ff283bfb4ccf8fc871a64
--- /dev/null
+++ b/utils/agent_can_choose_helper.py
@@ -0,0 +1,107 @@
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import fast_count_nonzero
+
+
+class AgentCanChooseHelper:
+    def __init__(self):
+        pass
+
+    def build_data(self, env):
+        self.env = env
+        if self.env is not None:
+            self.env.dev_obs_dict = {}
+        self.switches = {}
+        self.switches_neighbours = {}
+        if self.env is not None:
+            self.find_all_cell_where_agent_can_choose()
+
+    def find_all_switches(self):
+        # Search the environment (rail grid) for all switch cells. A switch is a cell where more than one tranisation
+        # exists and collect all direction where the switch is a switch.
+        self.switches = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                pos = (h, w)
+                for dir in range(4):
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    num_transitions = fast_count_nonzero(possible_transitions)
+                    if num_transitions > 1:
+                        if pos not in self.switches.keys():
+                            self.switches.update({pos: [dir]})
+                        else:
+                            self.switches[pos].append(dir)
+
+    def find_all_switch_neighbours(self):
+        # Collect all cells where is a neighbour to a switch cell. All cells are neighbour where the agent can make
+        # just one step and he stands on a switch. A switch is a cell where the agents has more than one transition.
+        self.switches_neighbours = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                # look one step forward
+                for dir in range(4):
+                    pos = (h, w)
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    for d in range(4):
+                        if possible_transitions[d] == 1:
+                            new_cell = get_new_position(pos, d)
+                            if new_cell in self.switches.keys() and pos not in self.switches.keys():
+                                if pos not in self.switches_neighbours.keys():
+                                    self.switches_neighbours.update({pos: [dir]})
+                                else:
+                                    self.switches_neighbours[pos].append(dir)
+
+    def find_all_cell_where_agent_can_choose(self):
+        # prepare the memory - collect all cells where the agent can choose more than FORWARD/STOP.
+        self.find_all_switches()
+        self.find_all_switch_neighbours()
+
+    def check_agent_decision(self, position, direction):
+        # Decide whether the agent is
+        # - on a switch
+        # - at a switch neighbour (near to switch). The switch must be a switch where the agent has more option than
+        #   FORWARD/STOP
+        # - all switch : doesn't matter whether the agent has more options than FORWARD/STOP
+        # - all switch neightbors : doesn't matter the agent has more then one options (transistion) when he reach the
+        #   switch
+        agents_on_switch = False
+        agents_on_switch_all = False
+        agents_near_to_switch = False
+        agents_near_to_switch_all = False
+        if position in self.switches.keys():
+            agents_on_switch = direction in self.switches[position]
+            agents_on_switch_all = True
+
+        if position in self.switches_neighbours.keys():
+            new_cell = get_new_position(position, direction)
+            if new_cell in self.switches.keys():
+                if not direction in self.switches[new_cell]:
+                    agents_near_to_switch = direction in self.switches_neighbours[position]
+            else:
+                agents_near_to_switch = direction in self.switches_neighbours[position]
+
+            agents_near_to_switch_all = direction in self.switches_neighbours[position]
+
+        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
+
+    def required_agent_decision(self):
+        agents_can_choose = {}
+        agents_on_switch = {}
+        agents_on_switch_all = {}
+        agents_near_to_switch = {}
+        agents_near_to_switch_all = {}
+        for a in range(self.env.get_num_agents()):
+            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
+                self.check_agent_decision(
+                    self.env.agents[a].position,
+                    self.env.agents[a].direction)
+            agents_on_switch.update({a: ret_agents_on_switch})
+            agents_on_switch_all.update({a: ret_agents_on_switch_all})
+            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
+            agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
+
+            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
+
+            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
+
+        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
diff --git a/utils/dead_lock_avoidance_agent.py b/utils/dead_lock_avoidance_agent.py
index 39090a14edbe533297556853f8aecc8b34ddef14..4c4c9033d83a6ab578f183c985df308e914d002e 100644
--- a/utils/dead_lock_avoidance_agent.py
+++ b/utils/dead_lock_avoidance_agent.py
@@ -1,10 +1,27 @@
+from typing import Optional, List
+
 import matplotlib.pyplot as plt
 import numpy as np
+from flatland.core.env_observation_builder import DummyObservationBuilder
 from flatland.envs.agent_utils import RailAgentStatus
 from flatland.envs.rail_env import RailEnv, RailEnvActions, fast_count_nonzero
 
-from reinforcement_learning.policy import Policy
-from utils.shortest_Distance_walker import ShortestDistanceWalker
+from reinforcement_learning.policy import HeuristicPolicy, DummyMemory
+from utils.agent_action_config import map_rail_env_action
+from utils.shortest_distance_walker import ShortestDistanceWalker
+
+
+class DeadlockAvoidanceObservation(DummyObservationBuilder):
+    def __init__(self):
+        self.counter = 0
+
+    def get_many(self, handles: Optional[List[int]] = None) -> bool:
+        self.counter += 1
+        obs = np.ones(len(handles), 2)
+        for handle in handles:
+            obs[handle][0] = handle
+            obs[handle][1] = self.counter
+        return obs
 
 
 class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
@@ -49,29 +66,40 @@ class DeadlockAvoidanceShortestDistanceWalker(ShortestDistanceWalker):
                 self.shortest_distance_agent_map[(handle, position[0], position[1])] = 1
         self.full_shortest_distance_agent_map[(handle, position[0], position[1])] = 1
 
-
-class DeadLockAvoidanceAgent(Policy):
-    def __init__(self, env: RailEnv, state_size, action_size, show_debug_plot=False):
+class DeadLockAvoidanceAgent(HeuristicPolicy):
+    def __init__(self, env: RailEnv, action_size, enable_eps=False, show_debug_plot=False):
+        print(">> DeadLockAvoidance")
         self.env = env
-        self.action_size = action_size
-        self.state_size = state_size
-        self.memory = []
+        self.memory = DummyMemory()
         self.loss = 0
+        self.action_size = action_size
         self.agent_can_move = {}
+        self.agent_can_move_value = {}
         self.switches = {}
         self.show_debug_plot = show_debug_plot
+        self.enable_eps = enable_eps
 
     def step(self, handle, state, action, reward, next_state, done):
         pass
 
     def act(self, handle, state, eps=0.):
-        # agent = self.env.agents[handle]
+        # Epsilon-greedy action selection
+        if self.enable_eps:
+            if np.random.random() < eps:
+                return np.random.choice(np.arange(self.action_size))
+
+        # agent = self.env.agents[state[0]]
         check = self.agent_can_move.get(handle, None)
-        if check is None:
-            return RailEnvActions.STOP_MOVING
-        return check[3]
+        act = RailEnvActions.STOP_MOVING
+        if check is not None:
+            act = check[3]
+        return map_rail_env_action(act)
+
+    def get_agent_can_move_value(self, handle):
+        return self.agent_can_move_value.get(handle, np.inf)
 
-    def reset(self):
+    def reset(self, env):
+        self.env = env
         self.agent_positions = None
         self.shortest_distance_walker = None
         self.switches = {}
@@ -87,12 +115,12 @@ class DeadLockAvoidanceAgent(Policy):
                         else:
                             self.switches[pos].append(dir)
 
-    def start_step(self):
+    def start_step(self, train):
         self.build_agent_position_map()
         self.shortest_distance_mapper()
         self.extract_agent_can_move()
 
-    def end_step(self):
+    def end_step(self, train):
         pass
 
     def get_actions(self):
@@ -122,7 +150,9 @@ class DeadLockAvoidanceAgent(Policy):
         for handle in range(self.env.get_num_agents()):
             agent = self.env.agents[handle]
             if agent.status < RailAgentStatus.DONE:
-                next_step_ok = self.check_agent_can_move(shortest_distance_agent_map[handle],
+                next_step_ok = self.check_agent_can_move(handle,
+                                                         shortest_distance_agent_map[handle],
+                                                         self.shortest_distance_walker.same_agent_map.get(handle, []),
                                                          self.shortest_distance_walker.opp_agent_map.get(handle, []),
                                                          full_shortest_distance_agent_map)
                 if next_step_ok:
@@ -139,7 +169,9 @@ class DeadLockAvoidanceAgent(Policy):
             plt.pause(0.01)
 
     def check_agent_can_move(self,
+                             handle,
                              my_shortest_walking_path,
+                             same_agents,
                              opp_agents,
                              full_shortest_distance_agent_map):
         agent_positions_map = (self.agent_positions > -1).astype(int)
@@ -147,7 +179,16 @@ class DeadLockAvoidanceAgent(Policy):
         next_step_ok = True
         for opp_a in opp_agents:
             opp = full_shortest_distance_agent_map[opp_a]
-            delta = ((delta - opp - agent_positions_map) > 0).astype(int)
-            if (np.sum(delta) < 2 + len(opp_agents)):
+            delta = ((my_shortest_walking_path - opp - agent_positions_map) > 0).astype(int)
+            if np.sum(delta) < (3 + len(opp_agents)):
                 next_step_ok = False
+            v = self.agent_can_move_value.get(handle, np.inf)
+            v = min(v, np.sum(delta))
+            self.agent_can_move_value.update({handle: v})
         return next_step_ok
+
+    def save(self, filename):
+        pass
+
+    def load(self, filename):
+        pass
diff --git a/utils/deadlock_check.py b/utils/deadlock_check.py
index 28c65fa6185fa9131fbc493a33fa26529e0290db..4df6731c951bf5599f832735344402da58d8caaf 100644
--- a/utils/deadlock_check.py
+++ b/utils/deadlock_check.py
@@ -1,42 +1,95 @@
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import RailAgentStatus
-
-
-def check_if_all_blocked(env):
-    """
-    Checks whether all the agents are blocked (full deadlock situation).
-    In that case it is pointless to keep running inference as no agent will be able to move.
-    :param env: current environment
-    :return:
-    """
-
-    # First build a map of agents in each position
-    location_has_agent = {}
-    for agent in env.agents:
-        if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and agent.position:
-            location_has_agent[tuple(agent.position)] = 1
-
-    # Looks for any agent that can still move
-    for handle in env.get_agent_handles():
-        agent = env.agents[handle]
-        if agent.status == RailAgentStatus.READY_TO_DEPART:
-            agent_virtual_position = agent.initial_position
-        elif agent.status == RailAgentStatus.ACTIVE:
-            agent_virtual_position = agent.position
-        elif agent.status == RailAgentStatus.DONE:
-            agent_virtual_position = agent.target
-        else:
-            continue
-
-        possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
-        orientation = agent.direction
-
-        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
-            if possible_transitions[branch_direction]:
-                new_position = get_new_position(agent_virtual_position, branch_direction)
-
-                if new_position not in location_has_agent:
-                    return False
-
-    # No agent can move at all: full deadlock!
-    return True
+import numpy as np
+
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import fast_count_nonzero
+
+
+def get_agent_positions(env):
+    agent_positions: np.ndarray = np.full((env.height, env.width), -1)
+    for agent_handle in env.get_agent_handles():
+        agent = env.agents[agent_handle]
+        if agent.status == RailAgentStatus.ACTIVE:
+            position = agent.position
+            if position is None:
+                position = agent.initial_position
+            agent_positions[position] = agent_handle
+    return agent_positions
+
+
+def get_agent_targets(env):
+    agent_targets = []
+    for agent_handle in env.get_agent_handles():
+        agent = env.agents[agent_handle]
+        if agent.status == RailAgentStatus.ACTIVE:
+            agent_targets.append(agent.target)
+    return agent_targets
+
+
+def check_for_deadlock(handle, env, agent_positions, check_position=None, check_direction=None):
+    agent = env.agents[handle]
+    if agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED:
+        return False
+
+    position = agent.position
+    if position is None:
+        position = agent.initial_position
+    if check_position is not None:
+        position = check_position
+    direction = agent.direction
+    if check_direction is not None:
+        direction = check_direction
+
+    possible_transitions = env.rail.get_transitions(*position, direction)
+    num_transitions = fast_count_nonzero(possible_transitions)
+    for dir_loop in range(4):
+        if possible_transitions[dir_loop] == 1:
+            new_position = get_new_position(position, dir_loop)
+            opposite_agent = agent_positions[new_position]
+            if opposite_agent != handle and opposite_agent != -1:
+                num_transitions -= 1
+            else:
+                return False
+
+    is_deadlock = num_transitions <= 0
+    return is_deadlock
+
+
+def check_if_all_blocked(env):
+    """
+    Checks whether all the agents are blocked (full deadlock situation).
+    In that case it is pointless to keep running inference as no agent will be able to move.
+    :param env: current environment
+    :return:
+    """
+
+    # First build a map of agents in each position
+    location_has_agent = {}
+    for agent in env.agents:
+        if agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and agent.position:
+            location_has_agent[tuple(agent.position)] = 1
+
+    # Looks for any agent that can still move
+    for handle in env.get_agent_handles():
+        agent = env.agents[handle]
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            agent_virtual_position = agent.target
+        else:
+            continue
+
+        possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
+        orientation = agent.direction
+
+        for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]:
+            if possible_transitions[branch_direction]:
+                new_position = get_new_position(agent_virtual_position, branch_direction)
+
+                if new_position not in location_has_agent:
+                    return False
+
+    # No agent can move at all: full deadlock!
+    return True
diff --git a/utils/extra.py b/utils/extra.py
deleted file mode 100644
index 89ed0bb9ea2b7a993fe9eecd9332f0910294cbce..0000000000000000000000000000000000000000
--- a/utils/extra.py
+++ /dev/null
@@ -1,363 +0,0 @@
-import numpy as np
-from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.rail_env import RailEnvActions, fast_argmax, fast_count_nonzero
-
-from reinforcement_learning.policy import Policy
-from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent, DeadlockAvoidanceShortestDistanceWalker
-
-
-class ExtraPolicy(Policy):
-    def __init__(self, state_size, action_size):
-        self.state_size = state_size
-        self.action_size = action_size
-        self.memory = []
-        self.loss = 0
-
-    def load(self, filename):
-        pass
-
-    def save(self, filename):
-        pass
-
-    def step(self, handle, state, action, reward, next_state, done):
-        pass
-
-    def act(self, handle, state, eps=0.):
-        a = 0
-        b = 4
-        action = RailEnvActions.STOP_MOVING
-        if state[2] == 1 and state[10 + a] == 0:
-            action = RailEnvActions.MOVE_LEFT
-        elif state[3] == 1 and state[11 + a] == 0:
-            action = RailEnvActions.MOVE_FORWARD
-        elif state[4] == 1 and state[12 + a] == 0:
-            action = RailEnvActions.MOVE_RIGHT
-        elif state[5] == 1 and state[13 + a] == 0:
-            action = RailEnvActions.MOVE_FORWARD
-
-        elif state[6] == 1 and state[10 + b] == 0:
-            action = RailEnvActions.MOVE_LEFT
-        elif state[7] == 1 and state[11 + b] == 0:
-            action = RailEnvActions.MOVE_FORWARD
-        elif state[8] == 1 and state[12 + b] == 0:
-            action = RailEnvActions.MOVE_RIGHT
-        elif state[9] == 1 and state[13 + b] == 0:
-            action = RailEnvActions.MOVE_FORWARD
-
-        return action
-
-    def test(self):
-        pass
-
-
-class Extra(ObservationBuilder):
-
-    def __init__(self, max_depth):
-        self.max_depth = max_depth
-        self.observation_dim = 31
-
-    def build_data(self):
-        self.dead_lock_avoidance_agent = None
-        if self.env is not None:
-            self.env.dev_obs_dict = {}
-            self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, None, None)
-
-        self.switches = {}
-        self.switches_neighbours = {}
-        self.debug_render_list = []
-        self.debug_render_path_list = []
-        if self.env is not None:
-            self.find_all_cell_where_agent_can_choose()
-
-    def find_all_cell_where_agent_can_choose(self):
-
-        switches = {}
-        for h in range(self.env.height):
-            for w in range(self.env.width):
-                pos = (h, w)
-                for dir in range(4):
-                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
-                    num_transitions = fast_count_nonzero(possible_transitions)
-                    if num_transitions > 1:
-                        if pos not in switches.keys():
-                            switches.update({pos: [dir]})
-                        else:
-                            switches[pos].append(dir)
-
-        switches_neighbours = {}
-        for h in range(self.env.height):
-            for w in range(self.env.width):
-                # look one step forward
-                for dir in range(4):
-                    pos = (h, w)
-                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
-                    for d in range(4):
-                        if possible_transitions[d] == 1:
-                            new_cell = get_new_position(pos, d)
-                            if new_cell in switches.keys() and pos not in switches.keys():
-                                if pos not in switches_neighbours.keys():
-                                    switches_neighbours.update({pos: [dir]})
-                                else:
-                                    switches_neighbours[pos].append(dir)
-
-        self.switches = switches
-        self.switches_neighbours = switches_neighbours
-
-    def check_agent_descision(self, position, direction):
-        switches = self.switches
-        switches_neighbours = self.switches_neighbours
-        agents_on_switch = False
-        agents_on_switch_all = False
-        agents_near_to_switch = False
-        agents_near_to_switch_all = False
-        if position in switches.keys():
-            agents_on_switch = direction in switches[position]
-            agents_on_switch_all = True
-
-        if position in switches_neighbours.keys():
-            new_cell = get_new_position(position, direction)
-            if new_cell in switches.keys():
-                if not direction in switches[new_cell]:
-                    agents_near_to_switch = direction in switches_neighbours[position]
-            else:
-                agents_near_to_switch = direction in switches_neighbours[position]
-
-            agents_near_to_switch_all = direction in switches_neighbours[position]
-
-        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
-
-    def required_agent_descision(self):
-        agents_can_choose = {}
-        agents_on_switch = {}
-        agents_on_switch_all = {}
-        agents_near_to_switch = {}
-        agents_near_to_switch_all = {}
-        for a in range(self.env.get_num_agents()):
-            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all, ret_agents_on_switch_all = \
-                self.check_agent_descision(
-                    self.env.agents[a].position,
-                    self.env.agents[a].direction)
-            agents_on_switch.update({a: ret_agents_on_switch})
-            agents_on_switch_all.update({a: ret_agents_on_switch_all})
-            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
-            agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
-
-            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
-
-            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
-
-        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all, agents_on_switch_all
-
-    def debug_render(self, env_renderer):
-        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
-            self.required_agent_descision()
-        self.env.dev_obs_dict = {}
-        for a in range(max(3, self.env.get_num_agents())):
-            self.env.dev_obs_dict.update({a: []})
-
-        selected_agent = None
-        if agents_can_choose[0]:
-            if self.env.agents[0].position is not None:
-                self.debug_render_list.append(self.env.agents[0].position)
-            else:
-                self.debug_render_list.append(self.env.agents[0].initial_position)
-
-        if self.env.agents[0].position is not None:
-            self.debug_render_path_list.append(self.env.agents[0].position)
-        else:
-            self.debug_render_path_list.append(self.env.agents[0].initial_position)
-
-        env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000")
-        env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600")
-        env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666")
-        env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")
-
-        self.env.dev_obs_dict[0] = self.debug_render_list
-        self.env.dev_obs_dict[1] = self.switches.keys()
-        self.env.dev_obs_dict[2] = self.switches_neighbours.keys()
-        self.env.dev_obs_dict[3] = self.debug_render_path_list
-
-    def reset(self):
-        self.build_data()
-        return
-
-
-    def _check_dead_lock_at_branching_position(self, handle, new_position, branch_direction):
-        _, full_shortest_distance_agent_map = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
-        opp_agents = self.dead_lock_avoidance_agent.shortest_distance_walker.opp_agent_map.get(handle, [])
-        local_walker = DeadlockAvoidanceShortestDistanceWalker(
-            self.env,
-            self.dead_lock_avoidance_agent.shortest_distance_walker.agent_positions,
-            self.dead_lock_avoidance_agent.shortest_distance_walker.switches)
-        local_walker.walk_to_target(handle, new_position, branch_direction)
-        shortest_distance_agent_map, _ = self.dead_lock_avoidance_agent.shortest_distance_walker.getData()
-        my_shortest_path_to_check = shortest_distance_agent_map[handle]
-        next_step_ok = self.dead_lock_avoidance_agent.check_agent_can_move(my_shortest_path_to_check,
-                                                                           opp_agents,
-                                                                           full_shortest_distance_agent_map)
-        return next_step_ok
-
-    def _explore(self, handle, new_position, new_direction, depth=0):
-
-        has_opp_agent = 0
-        has_same_agent = 0
-        has_switch = 0
-        visited = []
-
-        # stop exploring (max_depth reached)
-        if depth >= self.max_depth:
-            return has_opp_agent, has_same_agent, has_switch, visited
-
-        # max_explore_steps = 100
-        cnt = 0
-        while cnt < 100:
-            cnt += 1
-
-            visited.append(new_position)
-            opp_a = self.env.agent_positions[new_position]
-            if opp_a != -1 and opp_a != handle:
-                if self.env.agents[opp_a].direction != new_direction:
-                    # opp agent found
-                    has_opp_agent = 1
-                    return has_opp_agent, has_same_agent, has_switch, visited
-                else:
-                    has_same_agent = 1
-                    return has_opp_agent, has_same_agent, has_switch, visited
-
-            # convert one-hot encoding to 0,1,2,3
-            agents_on_switch, \
-            agents_near_to_switch, \
-            agents_near_to_switch_all, \
-            agents_on_switch_all = \
-                self.check_agent_descision(new_position, new_direction)
-            if agents_near_to_switch:
-                return has_opp_agent, has_same_agent, has_switch, visited
-
-            possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
-            if agents_on_switch:
-                f = 0
-                for dir_loop in range(4):
-                    if possible_transitions[dir_loop] == 1:
-                        f += 1
-                        hoa, hsa, hs, v = self._explore(handle,
-                                                        get_new_position(new_position, dir_loop),
-                                                        dir_loop,
-                                                        depth + 1)
-                        visited.append(v)
-                        has_opp_agent += hoa
-                        has_same_agent += hsa
-                        has_switch += hs
-                f = max(f, 1.0)
-                return has_opp_agent / f, has_same_agent / f, has_switch / f, visited
-            else:
-                new_direction = fast_argmax(possible_transitions)
-                new_position = get_new_position(new_position, new_direction)
-
-        return has_opp_agent, has_same_agent, has_switch, visited
-
-    def get(self, handle):
-
-        if handle == 0:
-            self.dead_lock_avoidance_agent.start_step()
-
-        # all values are [0,1]
-        # observation[0]  : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
-        # observation[1]  : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
-        # observation[2]  : 1 path towards target (direction 2) / otherwise 0 -> path is longer or there is no path
-        # observation[3]  : 1 path towards target (direction 3) / otherwise 0 -> path is longer or there is no path
-        # observation[4]  : int(agent.status == RailAgentStatus.READY_TO_DEPART)
-        # observation[5]  : int(agent.status == RailAgentStatus.ACTIVE)
-        # observation[6]  : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED)
-        # observation[7]  : current agent is located at a switch, where it can take a routing decision
-        # observation[8]  : current agent is located at a cell, where it has to take a stop-or-go decision
-        # observation[9]  : current agent is located one step before/after a switch
-        # observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0)
-        # observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1)
-        # observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2)
-        # observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3)
-        # observation[14] : If there is a path with step (direction 0) and there is a agent with opposite direction -> 1
-        # observation[15] : If there is a path with step (direction 1) and there is a agent with opposite direction -> 1
-        # observation[16] : If there is a path with step (direction 2) and there is a agent with opposite direction -> 1
-        # observation[17] : If there is a path with step (direction 3) and there is a agent with opposite direction -> 1
-        # observation[18] : If there is a path with step (direction 0) and there is a agent with same direction -> 1
-        # observation[19] : If there is a path with step (direction 1) and there is a agent with same direction -> 1
-        # observation[20] : If there is a path with step (direction 2) and there is a agent with same direction -> 1
-        # observation[21] : If there is a path with step (direction 3) and there is a agent with same direction -> 1
-        # observation[22] : If there is a switch on the path which agent can not use -> 1
-        # observation[23] : If there is a switch on the path which agent can not use -> 1
-        # observation[24] : If there is a switch on the path which agent can not use -> 1
-        # observation[25] : If there is a switch on the path which agent can not use -> 1
-        # observation[26] : Is there a deadlock signal on shortest path walk(s) (direction 0)-> 1
-        # observation[27] : Is there a deadlock signal on shortest path walk(s) (direction 1)-> 1
-        # observation[28] : Is there a deadlock signal on shortest path walk(s) (direction 2)-> 1
-        # observation[29] : Is there a deadlock signal on shortest path walk(s) (direction 3)-> 1
-        # observation[30] : Is there a deadlock signal on shortest path walk(s) (current position check)-> 1
-
-        observation = np.zeros(self.observation_dim)
-        visited = []
-        agent = self.env.agents[handle]
-
-        agent_done = False
-        if agent.status == RailAgentStatus.READY_TO_DEPART:
-            agent_virtual_position = agent.initial_position
-            observation[4] = 1
-        elif agent.status == RailAgentStatus.ACTIVE:
-            agent_virtual_position = agent.position
-            observation[5] = 1
-        else:
-            observation[6] = 1
-            agent_virtual_position = (-1, -1)
-            agent_done = True
-
-        if not agent_done:
-            visited.append(agent_virtual_position)
-            distance_map = self.env.distance_map.get()
-            current_cell_dist = distance_map[handle,
-                                             agent_virtual_position[0], agent_virtual_position[1],
-                                             agent.direction]
-            possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
-            orientation = agent.direction
-            if fast_count_nonzero(possible_transitions) == 1:
-                orientation = fast_argmax(possible_transitions)
-
-            for dir_loop, branch_direction in enumerate([(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]):
-                if possible_transitions[branch_direction]:
-                    new_position = get_new_position(agent_virtual_position, branch_direction)
-                    new_cell_dist = distance_map[handle,
-                                                 new_position[0], new_position[1],
-                                                 branch_direction]
-                    if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
-                        observation[dir_loop] = int(new_cell_dist < current_cell_dist)
-
-                    has_opp_agent, has_same_agent, has_switch, v = self._explore(handle, new_position, branch_direction)
-                    visited.append(v)
-
-                    observation[10 + dir_loop] = 1
-                    observation[14 + dir_loop] = has_opp_agent
-                    observation[18 + dir_loop] = has_same_agent
-                    observation[22 + dir_loop] = has_switch
-
-                    next_step_ok = self._check_dead_lock_at_branching_position(handle, new_position, branch_direction)
-                    if next_step_ok:
-                        observation[26 + dir_loop] = 1
-
-        agents_on_switch, \
-        agents_near_to_switch, \
-        agents_near_to_switch_all, \
-        agents_on_switch_all = \
-            self.check_agent_descision(agent_virtual_position, agent.direction)
-        observation[7] = int(agents_on_switch)
-        observation[8] = int(agents_near_to_switch)
-        observation[9] = int(agents_near_to_switch_all)
-
-        observation[30] = int(self.dead_lock_avoidance_agent.act(handle, None, 0) == RailEnvActions.STOP_MOVING)
-
-        self.env.dev_obs_dict.update({handle: visited})
-
-        return observation
-
-    @staticmethod
-    def agent_can_choose(observation):
-        return observation[7] == 1 or observation[8] == 1
diff --git a/utils/fast_tree_obs.py b/utils/fast_tree_obs.py
new file mode 100755
index 0000000000000000000000000000000000000000..4d4ce4b0396f790b2e71879801f2821792116052
--- /dev/null
+++ b/utils/fast_tree_obs.py
@@ -0,0 +1,267 @@
+from typing import List, Optional, Any
+
+import numpy as np
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.envs.rail_env import fast_count_nonzero, fast_argmax, RailEnvActions
+
+from utils.agent_action_config import get_flatland_full_action_size
+from utils.agent_can_choose_helper import AgentCanChooseHelper
+from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
+from utils.deadlock_check import get_agent_positions, get_agent_targets
+
+"""
+LICENCE for the FastTreeObs Observation Builder  
+
+The observation can be used freely and reused for further submissions. Only the author needs to be referred to
+/mentioned in any submissions - if the entire observation or parts, or the main idea is used.
+
+Author: Adrian Egli (adrian.egli@gmail.com)
+
+[Linkedin](https://www.researchgate.net/profile/Adrian_Egli2)
+[Researchgate](https://www.linkedin.com/in/adrian-egli-733a9544/)
+"""
+
+
+class FastTreeObs(ObservationBuilder):
+
+    def __init__(self, max_depth: Any):
+        self.max_depth = max_depth
+        self.observation_dim = 35
+        self.agent_can_choose_helper = None
+        self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(None, get_flatland_full_action_size())
+
+    def debug_render(self, env_renderer):
+        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
+            self.agent_can_choose_helper.required_agent_decision()
+        self.env.dev_obs_dict = {}
+        for a in range(max(3, self.env.get_num_agents())):
+            self.env.dev_obs_dict.update({a: []})
+
+        selected_agent = None
+        if agents_can_choose[0]:
+            if self.env.agents[0].position is not None:
+                self.debug_render_list.append(self.env.agents[0].position)
+            else:
+                self.debug_render_list.append(self.env.agents[0].initial_position)
+
+        if self.env.agents[0].position is not None:
+            self.debug_render_path_list.append(self.env.agents[0].position)
+        else:
+            self.debug_render_path_list.append(self.env.agents[0].initial_position)
+
+        env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000")
+        env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600")
+        env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666")
+        env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")
+
+        self.env.dev_obs_dict[0] = self.debug_render_list
+        self.env.dev_obs_dict[1] = self.agent_can_choose_helper.switches.keys()
+        self.env.dev_obs_dict[2] = self.agent_can_choose_helper.switches_neighbours.keys()
+        self.env.dev_obs_dict[3] = self.debug_render_path_list
+
+    def reset(self):
+        if self.agent_can_choose_helper is None:
+            self.agent_can_choose_helper = AgentCanChooseHelper()
+        self.agent_can_choose_helper.build_data(self.env)
+        self.debug_render_list = []
+        self.debug_render_path_list = []
+
+    def _explore(self, handle, new_position, new_direction, distance_map, depth=0):
+        has_opp_agent = 0
+        has_same_agent = 0
+        has_target = 0
+        has_opp_target = 0
+        visited = []
+        min_dist = distance_map[handle, new_position[0], new_position[1], new_direction]
+
+        # stop exploring (max_depth reached)
+        if depth >= self.max_depth:
+            return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
+
+        # max_explore_steps = 100 -> just to ensure that the exploration ends
+        cnt = 0
+        while cnt < 100:
+            cnt += 1
+
+            visited.append(new_position)
+            opp_a = self.env.agent_positions[new_position]
+            if opp_a != -1 and opp_a != handle:
+                if self.env.agents[opp_a].direction != new_direction:
+                    # opp agent found -> stop exploring. This would be a strong signal.
+                    has_opp_agent = 1
+                    return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
+                else:
+                    # same agent found
+                    # the agent can follow the agent, because this agent is still moving ahead and there shouldn't
+                    # be any dead-lock nor other issue -> agent is just walking -> if other agent has a deadlock
+                    # this should be avoided by other agents -> one edge case would be when other agent has it's
+                    # target on this branch -> thus the agents should scan further whether there will be an opposite
+                    # agent walking on same track
+                    has_same_agent = 1
+                    # !NOT stop exploring!
+                    return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
+
+            # agents_on_switch == TRUE -> Current cell is a switch where the agent can decide (branch) in exploration
+            # agent_near_to_switch == TRUE -> One cell before the switch, where the agent can decide
+            #
+            agents_on_switch, agents_near_to_switch, _, _ = \
+                self.agent_can_choose_helper.check_agent_decision(new_position, new_direction)
+
+            if agents_near_to_switch:
+                # The exploration was walking on a path where the agent can not decide
+                # Best option would be MOVE_FORWARD -> Skip exploring - just walking
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
+
+            if self.env.agents[handle].target in self.agents_target:
+                has_opp_target = 1
+
+            if self.env.agents[handle].target == new_position:
+                has_target = 1
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
+
+            possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
+            if agents_on_switch:
+                orientation = new_direction
+                possible_transitions_nonzero = fast_count_nonzero(possible_transitions)
+                if possible_transitions_nonzero == 1:
+                    orientation = fast_argmax(possible_transitions)
+
+                for dir_loop, branch_direction in enumerate(
+                        [(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]):
+                    # branch the exploration path and aggregate the found information
+                    # --- OPEN RESEARCH QUESTION ---> is this good or shall we use full detailed information as
+                    # we did in the TreeObservation (FLATLAND) ?
+                    if possible_transitions[dir_loop] == 1:
+                        hoa, hsa, ht, hot, v, m_dist = self._explore(handle,
+                                                                     get_new_position(new_position, dir_loop),
+                                                                     dir_loop,
+                                                                     distance_map,
+                                                                     depth + 1)
+                        visited.append(v)
+                        has_opp_agent = max(hoa, has_opp_agent)
+                        has_same_agent = max(hsa, has_same_agent)
+                        has_target = max(has_target, ht)
+                        has_opp_target = max(has_opp_target, hot)
+                        min_dist = min(min_dist, m_dist)
+                return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
+            else:
+                new_direction = fast_argmax(possible_transitions)
+                new_position = get_new_position(new_position, new_direction)
+
+            min_dist = min(min_dist, distance_map[handle, new_position[0], new_position[1], new_direction])
+
+        return has_opp_agent, has_same_agent, has_target, has_opp_target, visited, min_dist
+
+    def get_many(self, handles: Optional[List[int]] = None):
+        self.dead_lock_avoidance_agent.reset(self.env)
+        self.dead_lock_avoidance_agent.start_step(False)
+        self.agent_positions = get_agent_positions(self.env)
+        self.agents_target = get_agent_targets(self.env)
+        observations = super().get_many(handles)
+        self.dead_lock_avoidance_agent.end_step(False)
+        return observations
+
+    def get(self, handle: int = 0):
+        # all values are [0,1]
+        # observation[0]  : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
+        # observation[1]  : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
+        # observation[2]  : 1 path towards target (direction 2) / otherwise 0 -> path is longer or there is no path
+        # observation[3]  : 1 path towards target (direction 3) / otherwise 0 -> path is longer or there is no path
+        # observation[4]  : int(agent.status == RailAgentStatus.READY_TO_DEPART)
+        # observation[5]  : int(agent.status == RailAgentStatus.ACTIVE)
+        # observation[6]  : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED)
+        # observation[7]  : current agent is located at a switch, where it can take a routing decision
+        # observation[8]  : current agent is located at a cell, where it has to take a stop-or-go decision
+        # observation[9]  : current agent is located one step before/after a switch
+        # observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0)
+        # observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1)
+        # observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2)
+        # observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3)
+        # observation[14] : If there is a path with step (direction 0) and there is a agent with opposite direction -> 1
+        # observation[15] : If there is a path with step (direction 1) and there is a agent with opposite direction -> 1
+        # observation[16] : If there is a path with step (direction 2) and there is a agent with opposite direction -> 1
+        # observation[17] : If there is a path with step (direction 3) and there is a agent with opposite direction -> 1
+        # observation[18] : If there is a path with step (direction 0) and there is a agent with same direction -> 1
+        # observation[19] : If there is a path with step (direction 1) and there is a agent with same direction -> 1
+        # observation[20] : If there is a path with step (direction 2) and there is a agent with same direction -> 1
+        # observation[21] : If there is a path with step (direction 3) and there is a agent with same direction -> 1
+        # observation[22] : If there is a switch on the path which agent can not use -> 1
+        # observation[23] : If there is a switch on the path which agent can not use -> 1
+        # observation[24] : If there is a switch on the path which agent can not use -> 1
+        # observation[25] : If there is a switch on the path which agent can not use -> 1
+
+        observation = np.zeros(self.observation_dim)
+        visited = []
+        agent = self.env.agents[handle]
+
+        agent_done = False
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+            observation[4] = 1
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+            observation[5] = 1
+        else:
+            observation[6] = 1
+            agent_virtual_position = (-1, -1)
+            agent_done = True
+
+        if not agent_done:
+            visited.append(agent_virtual_position)
+            distance_map = self.env.distance_map.get()
+            current_cell_dist = distance_map[handle,
+                                             agent_virtual_position[0], agent_virtual_position[1],
+                                             agent.direction]
+            possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
+            orientation = agent.direction
+            if fast_count_nonzero(possible_transitions) == 1:
+                orientation = fast_argmax(possible_transitions)
+
+            for dir_loop, branch_direction in enumerate([(orientation + dir_loop) % 4 for dir_loop in range(-1, 3)]):
+                if possible_transitions[branch_direction]:
+                    new_position = get_new_position(agent_virtual_position, branch_direction)
+                    new_cell_dist = distance_map[handle,
+                                                 new_position[0], new_position[1],
+                                                 branch_direction]
+                    if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
+                        observation[dir_loop] = int(new_cell_dist < current_cell_dist)
+
+                    has_opp_agent, has_same_agent, has_target, has_opp_target, v, min_dist = self._explore(handle,
+                                                                                                           new_position,
+                                                                                                           branch_direction,
+                                                                                                           distance_map)
+                    visited.append(v)
+
+                    if not (np.math.isinf(min_dist) and np.math.isinf(current_cell_dist)):
+                        observation[11 + dir_loop] = int(min_dist < current_cell_dist)
+                    observation[15 + dir_loop] = has_opp_agent
+                    observation[19 + dir_loop] = has_same_agent
+                    observation[23 + dir_loop] = has_target
+                    observation[27 + dir_loop] = has_opp_target
+
+            agents_on_switch, \
+            agents_near_to_switch, \
+            agents_near_to_switch_all, \
+            agents_on_switch_all = \
+                self.agent_can_choose_helper.check_agent_decision(agent_virtual_position, agent.direction)
+
+            observation[7] = int(agents_on_switch)
+            observation[8] = int(agents_on_switch_all)
+            observation[9] = int(agents_near_to_switch)
+            observation[10] = int(agents_near_to_switch_all)
+
+            action = self.dead_lock_avoidance_agent.act(handle, None, eps=0)
+            observation[30] = action == RailEnvActions.DO_NOTHING
+            observation[31] = action == RailEnvActions.MOVE_LEFT
+            observation[32] = action == RailEnvActions.MOVE_FORWARD
+            observation[33] = action == RailEnvActions.MOVE_RIGHT
+            observation[34] = action == RailEnvActions.STOP_MOVING
+
+        self.env.dev_obs_dict.update({handle: visited})
+
+        observation[np.isinf(observation)] = -1
+        observation[np.isnan(observation)] = -1
+
+        return observation
diff --git a/utils/observation_utils.py b/utils/observation_utils.py
index 3dc767f2b76197fbb98d63f326d7720dbbbdc020..aa18aefe8ba2aeeb75ff0287b53a001db174db42 100644
--- a/utils/observation_utils.py
+++ b/utils/observation_utils.py
@@ -1,124 +1,124 @@
-import numpy as np
-from flatland.envs.observations import TreeObsForRailEnv
-
-def max_lt(seq, val):
-    """
-    Return greatest item in seq for which item < val applies.
-    None is returned if seq was empty or all items in seq were >= val.
-    """
-    max = 0
-    idx = len(seq) - 1
-    while idx >= 0:
-        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
-            max = seq[idx]
-        idx -= 1
-    return max
-
-
-def min_gt(seq, val):
-    """
-    Return smallest item in seq for which item > val applies.
-    None is returned if seq was empty or all items in seq were >= val.
-    """
-    min = np.inf
-    idx = len(seq) - 1
-    while idx >= 0:
-        if seq[idx] >= val and seq[idx] < min:
-            min = seq[idx]
-        idx -= 1
-    return min
-
-
-def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False):
-    """
-    This function returns the difference between min and max value of an observation
-    :param obs: Observation that should be normalized
-    :param clip_min: min value where observation will be clipped
-    :param clip_max: max value where observation will be clipped
-    :return: returnes normalized and clipped observatoin
-    """
-    if fixed_radius > 0:
-        max_obs = fixed_radius
-    else:
-        max_obs = max(1, max_lt(obs, 1000)) + 1
-
-    min_obs = 0  # min(max_obs, min_gt(obs, 0))
-    if normalize_to_range:
-        min_obs = min_gt(obs, 0)
-    if min_obs > max_obs:
-        min_obs = max_obs
-    if max_obs == min_obs:
-        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
-    norm = np.abs(max_obs - min_obs)
-    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
-
-
-def _split_node_into_feature_groups(node) -> (np.ndarray, np.ndarray, np.ndarray):
-    data = np.zeros(6)
-    distance = np.zeros(1)
-    agent_data = np.zeros(4)
-
-    data[0] = node.dist_own_target_encountered
-    data[1] = node.dist_other_target_encountered
-    data[2] = node.dist_other_agent_encountered
-    data[3] = node.dist_potential_conflict
-    data[4] = node.dist_unusable_switch
-    data[5] = node.dist_to_next_branch
-
-    distance[0] = node.dist_min_to_target
-
-    agent_data[0] = node.num_agents_same_direction
-    agent_data[1] = node.num_agents_opposite_direction
-    agent_data[2] = node.num_agents_malfunctioning
-    agent_data[3] = node.speed_min_fractional
-
-    return data, distance, agent_data
-
-
-def _split_subtree_into_feature_groups(node, current_tree_depth: int, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
-    if node == -np.inf:
-        remaining_depth = max_tree_depth - current_tree_depth
-        # reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
-        num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1))
-        return [-np.inf] * num_remaining_nodes * 6, [-np.inf] * num_remaining_nodes, [-np.inf] * num_remaining_nodes * 4
-
-    data, distance, agent_data = _split_node_into_feature_groups(node)
-
-    if not node.childs:
-        return data, distance, agent_data
-
-    for direction in TreeObsForRailEnv.tree_explored_actions_char:
-        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(node.childs[direction], current_tree_depth + 1, max_tree_depth)
-        data = np.concatenate((data, sub_data))
-        distance = np.concatenate((distance, sub_distance))
-        agent_data = np.concatenate((agent_data, sub_agent_data))
-
-    return data, distance, agent_data
-
-
-def split_tree_into_feature_groups(tree, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
-    """
-    This function splits the tree into three difference arrays of values
-    """
-    data, distance, agent_data = _split_node_into_feature_groups(tree)
-
-    for direction in TreeObsForRailEnv.tree_explored_actions_char:
-        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(tree.childs[direction], 1, max_tree_depth)
-        data = np.concatenate((data, sub_data))
-        distance = np.concatenate((distance, sub_distance))
-        agent_data = np.concatenate((agent_data, sub_agent_data))
-
-    return data, distance, agent_data
-
-
-def normalize_observation(observation, tree_depth: int, observation_radius=0):
-    """
-    This function normalizes the observation used by the RL algorithm
-    """
-    data, distance, agent_data = split_tree_into_feature_groups(observation, tree_depth)
-
-    data = norm_obs_clip(data, fixed_radius=observation_radius)
-    distance = norm_obs_clip(distance, normalize_to_range=True)
-    agent_data = np.clip(agent_data, -1, 1)
-    normalized_obs = np.concatenate((np.concatenate((data, distance)), agent_data))
-    return normalized_obs
+import numpy as np
+from flatland.envs.observations import TreeObsForRailEnv
+
+def max_lt(seq, val):
+    """
+    Return greatest item in seq for which item < val applies.
+    None is returned if seq was empty or all items in seq were >= val.
+    """
+    max = 0
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] < val and seq[idx] >= 0 and seq[idx] > max:
+            max = seq[idx]
+        idx -= 1
+    return max
+
+
+def min_gt(seq, val):
+    """
+    Return smallest item in seq for which item > val applies.
+    None is returned if seq was empty or all items in seq were >= val.
+    """
+    min = np.inf
+    idx = len(seq) - 1
+    while idx >= 0:
+        if seq[idx] >= val and seq[idx] < min:
+            min = seq[idx]
+        idx -= 1
+    return min
+
+
+def norm_obs_clip(obs, clip_min=-1, clip_max=1, fixed_radius=0, normalize_to_range=False):
+    """
+    This function returns the difference between min and max value of an observation
+    :param obs: Observation that should be normalized
+    :param clip_min: min value where observation will be clipped
+    :param clip_max: max value where observation will be clipped
+    :return: returnes normalized and clipped observatoin
+    """
+    if fixed_radius > 0:
+        max_obs = fixed_radius
+    else:
+        max_obs = max(1, max_lt(obs, 1000)) + 1
+
+    min_obs = 0  # min(max_obs, min_gt(obs, 0))
+    if normalize_to_range:
+        min_obs = min_gt(obs, 0)
+    if min_obs > max_obs:
+        min_obs = max_obs
+    if max_obs == min_obs:
+        return np.clip(np.array(obs) / max_obs, clip_min, clip_max)
+    norm = np.abs(max_obs - min_obs)
+    return np.clip((np.array(obs) - min_obs) / norm, clip_min, clip_max)
+
+
+def _split_node_into_feature_groups(node) -> (np.ndarray, np.ndarray, np.ndarray):
+    data = np.zeros(6)
+    distance = np.zeros(1)
+    agent_data = np.zeros(4)
+
+    data[0] = node.dist_own_target_encountered
+    data[1] = node.dist_other_target_encountered
+    data[2] = node.dist_other_agent_encountered
+    data[3] = node.dist_potential_conflict
+    data[4] = node.dist_unusable_switch
+    data[5] = node.dist_to_next_branch
+
+    distance[0] = node.dist_min_to_target
+
+    agent_data[0] = node.num_agents_same_direction
+    agent_data[1] = node.num_agents_opposite_direction
+    agent_data[2] = node.num_agents_malfunctioning
+    agent_data[3] = node.speed_min_fractional
+
+    return data, distance, agent_data
+
+
+def _split_subtree_into_feature_groups(node, current_tree_depth: int, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
+    if node == -np.inf:
+        remaining_depth = max_tree_depth - current_tree_depth
+        # reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
+        num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1))
+        return [-np.inf] * num_remaining_nodes * 6, [-np.inf] * num_remaining_nodes, [-np.inf] * num_remaining_nodes * 4
+
+    data, distance, agent_data = _split_node_into_feature_groups(node)
+
+    if not node.childs:
+        return data, distance, agent_data
+
+    for direction in TreeObsForRailEnv.tree_explored_actions_char:
+        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(node.childs[direction], current_tree_depth + 1, max_tree_depth)
+        data = np.concatenate((data, sub_data))
+        distance = np.concatenate((distance, sub_distance))
+        agent_data = np.concatenate((agent_data, sub_agent_data))
+
+    return data, distance, agent_data
+
+
+def split_tree_into_feature_groups(tree, max_tree_depth: int) -> (np.ndarray, np.ndarray, np.ndarray):
+    """
+    This function splits the tree into three difference arrays of values
+    """
+    data, distance, agent_data = _split_node_into_feature_groups(tree)
+
+    for direction in TreeObsForRailEnv.tree_explored_actions_char:
+        sub_data, sub_distance, sub_agent_data = _split_subtree_into_feature_groups(tree.childs[direction], 1, max_tree_depth)
+        data = np.concatenate((data, sub_data))
+        distance = np.concatenate((distance, sub_distance))
+        agent_data = np.concatenate((agent_data, sub_agent_data))
+
+    return data, distance, agent_data
+
+
+def normalize_observation(observation, tree_depth: int, observation_radius=0):
+    """
+    This function normalizes the observation used by the RL algorithm
+    """
+    data, distance, agent_data = split_tree_into_feature_groups(observation, tree_depth)
+
+    data = norm_obs_clip(data, fixed_radius=observation_radius)
+    distance = norm_obs_clip(distance, normalize_to_range=True)
+    agent_data = np.clip(agent_data, -1, 1)
+    normalized_obs = np.concatenate((np.concatenate((data, distance)), agent_data))
+    return normalized_obs
diff --git a/utils/shortest_Distance_walker.py b/utils/shortest_distance_walker.py
similarity index 100%
rename from utils/shortest_Distance_walker.py
rename to utils/shortest_distance_walker.py
diff --git a/utils/shortest_path_walker_heuristic_agent.py b/utils/shortest_path_walker_heuristic_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2cbab04f407edeae3fba5030a0b7b3309560cfc
--- /dev/null
+++ b/utils/shortest_path_walker_heuristic_agent.py
@@ -0,0 +1,57 @@
+import numpy as np
+from flatland.envs.rail_env import RailEnvActions
+
+from reinforcement_learning.policy import Policy
+
+
+class ShortestPathWalkerHeuristicPolicy(Policy):
+    def step(self, state, action, reward, next_state, done):
+        pass
+
+    def act(self, handle, node, eps=0.):
+
+        left_node = node.childs.get('L')
+        forward_node = node.childs.get('F')
+        right_node = node.childs.get('R')
+
+        dist_map = np.zeros(5)
+        dist_map[RailEnvActions.DO_NOTHING] = np.inf
+        dist_map[RailEnvActions.STOP_MOVING] = 100000
+        # left
+        if left_node == -np.inf:
+            dist_map[RailEnvActions.MOVE_LEFT] = np.inf
+        else:
+            if left_node.num_agents_opposite_direction == 0:
+                dist_map[RailEnvActions.MOVE_LEFT] = left_node.dist_min_to_target
+            else:
+                dist_map[RailEnvActions.MOVE_LEFT] = np.inf
+        # forward
+        if forward_node == -np.inf:
+            dist_map[RailEnvActions.MOVE_FORWARD] = np.inf
+        else:
+            if forward_node.num_agents_opposite_direction == 0:
+                dist_map[RailEnvActions.MOVE_FORWARD] = forward_node.dist_min_to_target
+            else:
+                dist_map[RailEnvActions.MOVE_FORWARD] = np.inf
+        # right
+        if right_node == -np.inf:
+            dist_map[RailEnvActions.MOVE_RIGHT] = np.inf
+        else:
+            if right_node.num_agents_opposite_direction == 0:
+                dist_map[RailEnvActions.MOVE_RIGHT] = right_node.dist_min_to_target
+            else:
+                dist_map[RailEnvActions.MOVE_RIGHT] = np.inf
+        return np.argmin(dist_map)
+
+    def save(self, filename):
+        pass
+
+    def load(self, filename):
+        pass
+
+
+policy = ShortestPathWalkerHeuristicPolicy()
+
+
+def normalize_observation(observation, tree_depth: int, observation_radius=0):
+    return observation
diff --git a/utils/timer.py b/utils/timer.py
index 6e397c79cc46c9f49e967365c3a2ad9bbf7cd5f6..aa02e9f18bb01fc9730c484c079aa96556027f2b 100644
--- a/utils/timer.py
+++ b/utils/timer.py
@@ -1,33 +1,33 @@
-from timeit import default_timer
-
-
-class Timer(object):
-    """
-    Utility to measure times.
-
-    TODO:
-    - add "lap" method to make it easier to measure average time (+std) when measuring the same thing multiple times.
-    """
-
-    def __init__(self):
-        self.total_time = 0.0
-        self.start_time = 0.0
-        self.end_time = 0.0
-
-    def start(self):
-        self.start_time = default_timer()
-
-    def end(self):
-        self.total_time += default_timer() - self.start_time
-
-    def get(self):
-        return self.total_time
-
-    def get_current(self):
-        return default_timer() - self.start_time
-
-    def reset(self):
-        self.__init__()
-
-    def __repr__(self):
+from timeit import default_timer
+
+
+class Timer(object):
+    """
+    Utility to measure times.
+
+    TODO:
+    - add "lap" method to make it easier to measure average time (+std) when measuring the same thing multiple times.
+    """
+
+    def __init__(self):
+        self.total_time = 0.0
+        self.start_time = 0.0
+        self.end_time = 0.0
+
+    def start(self):
+        self.start_time = default_timer()
+
+    def end(self):
+        self.total_time += default_timer() - self.start_time
+
+    def get(self):
+        return self.total_time
+
+    def get_current(self):
+        return default_timer() - self.start_time
+
+    def reset(self):
+        self.__init__()
+
+    def __repr__(self):
         return self.get()
\ No newline at end of file