diff --git a/.gitignore b/.gitignore
index da214e57a3cb40bddeb8e0e2d0b518b3de06f20e..54ef0cba10a309cb13c6077b720208c5e142257b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,127 +1,127 @@
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-pip-wheel-metadata/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-#  Usually these files are written by a python script from a template
-#  before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-.hypothesis/
-.pytest_cache/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-.python-version
-
-# pipenv
-#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-#   However, in case of collaboration, if having platform-specific dependencies or dependencies
-#   having no cross-platform support, pipenv may install dependencies that don't work, or not
-#   install all needed dependencies.
-#Pipfile.lock
-
-# celery beat schedule file
-celerybeat-schedule
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-scratch/test-envs/
-scratch/
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+scratch/test-envs/
+scratch/
diff --git a/README.md b/README.md
index c4e93dc4ae84e6debde8a2483c7b3f94edbc0d5e..22846ae8c532fa537e79ca807fe71860100f1f23 100644
--- a/README.md
+++ b/README.md
@@ -1,26 +1,26 @@
-![AIcrowd-Logo](https://raw.githubusercontent.com/AIcrowd/AIcrowd/master/app/assets/images/misc/aicrowd-horizontal.png)
-
-# Flatland Challenge Starter Kit
-
-**[Follow these instructions to submit your solutions!](http://flatland.aicrowd.com/getting-started/first-submission.html)**
-
-
-![flatland](https://i.imgur.com/0rnbSLY.gif)
-
-Main links
----
-
-* [Flatland documentation](https://flatland.aicrowd.com/)
-* [NeurIPS 2020 Challenge](https://www.aicrowd.com/challenges/neurips-2020-flatland-challenge/)
-
-Communication
----
-
-* [Discord Channel](https://discord.com/invite/hCR3CZG)
-* [Discussion Forum](https://discourse.aicrowd.com/c/neurips-2020-flatland-challenge)
-* [Issue Tracker](https://gitlab.aicrowd.com/flatland/flatland/issues/)
-
-Author
----
-
-- **[Sharada Mohanty](https://twitter.com/MeMohanty)**
+![AIcrowd-Logo](https://raw.githubusercontent.com/AIcrowd/AIcrowd/master/app/assets/images/misc/aicrowd-horizontal.png)
+
+# Flatland Challenge Starter Kit
+
+**[Follow these instructions to submit your solutions!](http://flatland.aicrowd.com/getting-started/first-submission.html)**
+
+
+![flatland](https://i.imgur.com/0rnbSLY.gif)
+
+Main links
+---
+
+* [Flatland documentation](https://flatland.aicrowd.com/)
+* [NeurIPS 2020 Challenge](https://www.aicrowd.com/challenges/neurips-2020-flatland-challenge/)
+
+Communication
+---
+
+* [Discord Channel](https://discord.com/invite/hCR3CZG)
+* [Discussion Forum](https://discourse.aicrowd.com/c/neurips-2020-flatland-challenge)
+* [Issue Tracker](https://gitlab.aicrowd.com/flatland/flatland/issues/)
+
+Author
+---
+
+- **[Sharada Mohanty](https://twitter.com/MeMohanty)**
diff --git a/aicrowd.json b/aicrowd.json
index 68c76af4fd222127604c5e5e3252429f9795fa4c..de611d36b277f16014795935206b0c9250f2bc37 100644
--- a/aicrowd.json
+++ b/aicrowd.json
@@ -1,7 +1,7 @@
-{
-  "challenge_id": "neurips-2020-flatland-challenge",
-  "grader_id": "neurips-2020-flatland-challenge",
-  "debug": false,
-  "tags": ["RL"]
-}
-
+{
+  "challenge_id": "neurips-2020-flatland-challenge",
+  "grader_id": "neurips-2020-flatland-challenge",
+  "debug": false,
+  "tags": ["RL"]
+}
+
diff --git a/apt.txt b/apt.txt
index d02172562addd7910a88661afa1bf72348d21d02..9c46ae885a5f3bcf73cf5689c501d39f76550bdc 100644
--- a/apt.txt
+++ b/apt.txt
@@ -1,6 +1,6 @@
-curl
-git
-vim
-ssh
-gcc
-python-cairo-dev
+curl
+git
+vim
+ssh
+gcc
+python-cairo-dev
diff --git a/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.meta b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..4f703391c1f977f6a63e4d8320ad8cdfb10e9a97
Binary files /dev/null and b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.optimizer b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..d92c6b5ad82ef53eca4332ccc23056a80a3699a2
Binary files /dev/null and b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.policy b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..1db257665e9b57e80c24c1e1bcc165fbcdb80d71
Binary files /dev/null and b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.policy differ
diff --git a/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.meta b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..84ffd934079e5114f881aad914112c35e5b0f777
Binary files /dev/null and b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.optimizer b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..63b4fc065ee61195c655c5aef7b41b6e354bf75d
Binary files /dev/null and b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.policy b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..7430a2d9b643bf30f9522c8ef3390bac8009f4c1
Binary files /dev/null and b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.policy differ
diff --git a/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.meta b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..8f5843e3dc95213ff24db6375e9a4ba65dfb29ef
Binary files /dev/null and b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.optimizer b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..7bc95369b9607e0cd3528f8be5db73eac34c0e17
Binary files /dev/null and b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.policy b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..339210ff779ad9a89efbf3620151939facd12c77
Binary files /dev/null and b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.policy differ
diff --git a/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.meta b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..7617876cf3d7031f066a779fde687404b0a1cc6f
Binary files /dev/null and b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.optimizer b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..b93d28155360433cbb1574b2e797bf1e293c2f6c
Binary files /dev/null and b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.policy b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..bc21bc40897b530a65966b1cbbbaeb41835f7b69
Binary files /dev/null and b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.policy differ
diff --git a/checkpoints/No_col_20/model_checkpoint.meta b/checkpoints/No_col_20/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..fb226078f5ccd715028992499f09f6edfcc4857e
Binary files /dev/null and b/checkpoints/No_col_20/model_checkpoint.meta differ
diff --git a/checkpoints/No_col_20/model_checkpoint.optimizer b/checkpoints/No_col_20/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..c725a9ccd7791016f89cf8426746dd549dc23410
Binary files /dev/null and b/checkpoints/No_col_20/model_checkpoint.optimizer differ
diff --git a/checkpoints/No_col_20/model_checkpoint.policy b/checkpoints/No_col_20/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..91e0a1abe40721649a98ccc76271f34295f7932c
Binary files /dev/null and b/checkpoints/No_col_20/model_checkpoint.policy differ
diff --git a/checkpoints/best_0.4757/ppo/model_checkpoint.meta b/checkpoints/best_0.4757/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..8c645aa2e8794ff4edf9974b297ccca3eed5b013
Binary files /dev/null and b/checkpoints/best_0.4757/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/best_0.4757/ppo/model_checkpoint.optimizer b/checkpoints/best_0.4757/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..bd1ca4d0a9df1ccd8eff82001b542b0edadc3796
Binary files /dev/null and b/checkpoints/best_0.4757/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/best_0.4757/ppo/model_checkpoint.policy b/checkpoints/best_0.4757/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..d0bcce75eb2b058ef3a63abdfb48d1d78e5f1ef9
Binary files /dev/null and b/checkpoints/best_0.4757/ppo/model_checkpoint.policy differ
diff --git a/checkpoints/best_0.4893/ppo/model_checkpoint.meta b/checkpoints/best_0.4893/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..17b21c3612387374893f2a5ce283dc671e0bb66c
Binary files /dev/null and b/checkpoints/best_0.4893/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/best_0.4893/ppo/model_checkpoint.optimizer b/checkpoints/best_0.4893/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..bf6ef14481c4832e7fd8289f775cd1da08d27d25
Binary files /dev/null and b/checkpoints/best_0.4893/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/best_0.4893/ppo/model_checkpoint.policy b/checkpoints/best_0.4893/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..91722141aa59d4376ad3a7695a08d016908af6ab
Binary files /dev/null and b/checkpoints/best_0.4893/ppo/model_checkpoint.policy differ
diff --git a/checkpoints/best_0.5003/ppo/model_checkpoint.meta b/checkpoints/best_0.5003/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..0cbaf636a4b6be40fe1bca6e522239cc733171b4
Binary files /dev/null and b/checkpoints/best_0.5003/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/best_0.5003/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5003/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..87a0386cbdbedd9cdaa4d8de81e7131caa71b815
Binary files /dev/null and b/checkpoints/best_0.5003/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/best_0.5003/ppo/model_checkpoint.policy b/checkpoints/best_0.5003/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..dfde8badb9c05dc8599bcccf81f27580b4e0ff08
Binary files /dev/null and b/checkpoints/best_0.5003/ppo/model_checkpoint.policy differ
diff --git a/checkpoints/best_0.5109/ppo/model_checkpoint.meta b/checkpoints/best_0.5109/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..4079d5979b350916b55315bc305fdaf842b7be27
Binary files /dev/null and b/checkpoints/best_0.5109/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/best_0.5109/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5109/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..a7f7fb013513f9a3dfe716acd7ef49c2fa0d57d9
Binary files /dev/null and b/checkpoints/best_0.5109/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/best_0.5109/ppo/model_checkpoint.policy b/checkpoints/best_0.5109/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..a613c4b42171de1169cbfa42d68a932457399cba
Binary files /dev/null and b/checkpoints/best_0.5109/ppo/model_checkpoint.policy differ
diff --git a/checkpoints/best_0.5172/ppo/model_checkpoint.meta b/checkpoints/best_0.5172/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..40125f2717372d6f7b2d9d3c86cb11c592ba99b5
Binary files /dev/null and b/checkpoints/best_0.5172/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/best_0.5172/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5172/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..7feedff1803645d8b131523fc21f990f96bc681a
Binary files /dev/null and b/checkpoints/best_0.5172/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/best_0.5172/ppo/model_checkpoint.policy b/checkpoints/best_0.5172/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..fe9741cdd7a7abd2d171c84aa95b1f1d0a17841c
Binary files /dev/null and b/checkpoints/best_0.5172/ppo/model_checkpoint.policy differ
diff --git a/checkpoints/best_0.5355/ppo/model_checkpoint.meta b/checkpoints/best_0.5355/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..b45d9b8937c572df6febfa2f0ac5a9d4cda4eb0e
Binary files /dev/null and b/checkpoints/best_0.5355/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/best_0.5355/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5355/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..915d64ab8e782a47c0de68a8ccdb81842a074c85
Binary files /dev/null and b/checkpoints/best_0.5355/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/best_0.5355/ppo/model_checkpoint.policy b/checkpoints/best_0.5355/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..3300d40d010b85f3e4395c9aed630f6beea486af
Binary files /dev/null and b/checkpoints/best_0.5355/ppo/model_checkpoint.policy differ
diff --git a/checkpoints/best_0.5435/ppo/model_checkpoint.meta b/checkpoints/best_0.5435/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..56a27a0763598ba9748c4b337fcb59e95ccdf612
Binary files /dev/null and b/checkpoints/best_0.5435/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/best_0.5435/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5435/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..1cec17631653a3834677441f541024d582d406b4
Binary files /dev/null and b/checkpoints/best_0.5435/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/best_0.5435/ppo/model_checkpoint.policy b/checkpoints/best_0.5435/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..8a3f4e113a4fd5ca7fe4ef27b5dd81d682939df7
Binary files /dev/null and b/checkpoints/best_0.5435/ppo/model_checkpoint.policy differ
diff --git a/checkpoints/best_0.8620/ppo/model_checkpoint.meta b/checkpoints/best_0.8620/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..8f5843e3dc95213ff24db6375e9a4ba65dfb29ef
Binary files /dev/null and b/checkpoints/best_0.8620/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/best_0.8620/ppo/model_checkpoint.optimizer b/checkpoints/best_0.8620/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..13b15ba48dde84af9e70f18cb0e4395737351f00
Binary files /dev/null and b/checkpoints/best_0.8620/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/best_0.8620/ppo/model_checkpoint.policy b/checkpoints/best_0.8620/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..7049ae616c44eb2c20edc1eb6aaf26f16e839851
Binary files /dev/null and b/checkpoints/best_0.8620/ppo/model_checkpoint.policy differ
diff --git a/checkpoints/dqn/README.md b/checkpoints/dqn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..7877436bc8fc766ce1409c3810a48b13da31ac39
--- /dev/null
+++ b/checkpoints/dqn/README.md
@@ -0,0 +1 @@
+DQN checkpoints will be saved here
diff --git a/checkpoints/dqn/model_checkpoint.local b/checkpoints/dqn/model_checkpoint.local
new file mode 100644
index 0000000000000000000000000000000000000000..cca8687b9d333d2a050d8f6910960ef80f0680b7
Binary files /dev/null and b/checkpoints/dqn/model_checkpoint.local differ
diff --git a/checkpoints/dqn/model_checkpoint.meta b/checkpoints/dqn/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..502812ecee5de1aa38caa05cc9766f7d2f04ba7e
Binary files /dev/null and b/checkpoints/dqn/model_checkpoint.meta differ
diff --git a/checkpoints/dqn/model_checkpoint.optimizer b/checkpoints/dqn/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..5badbab158a8d003b069c5a2bb0b8628f25dd89b
Binary files /dev/null and b/checkpoints/dqn/model_checkpoint.optimizer differ
diff --git a/checkpoints/dqn/model_checkpoint.target b/checkpoints/dqn/model_checkpoint.target
new file mode 100644
index 0000000000000000000000000000000000000000..6a853ac88d2554eef6d0b6a414d87d9993c22998
Binary files /dev/null and b/checkpoints/dqn/model_checkpoint.target differ
diff --git a/checkpoints/ppo/README.md b/checkpoints/ppo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..fce8cd3428ab5c13ea082dd922054080beae4822
--- /dev/null
+++ b/checkpoints/ppo/README.md
@@ -0,0 +1 @@
+PPO checkpoints will be saved here
diff --git a/checkpoints/ppo/model_checkpoint.meta b/checkpoints/ppo/model_checkpoint.meta
new file mode 100644
index 0000000000000000000000000000000000000000..7617876cf3d7031f066a779fde687404b0a1cc6f
Binary files /dev/null and b/checkpoints/ppo/model_checkpoint.meta differ
diff --git a/checkpoints/ppo/model_checkpoint.optimizer b/checkpoints/ppo/model_checkpoint.optimizer
new file mode 100644
index 0000000000000000000000000000000000000000..b93d28155360433cbb1574b2e797bf1e293c2f6c
Binary files /dev/null and b/checkpoints/ppo/model_checkpoint.optimizer differ
diff --git a/checkpoints/ppo/model_checkpoint.policy b/checkpoints/ppo/model_checkpoint.policy
new file mode 100644
index 0000000000000000000000000000000000000000..bc21bc40897b530a65966b1cbbbaeb41835f7b69
Binary files /dev/null and b/checkpoints/ppo/model_checkpoint.policy differ
diff --git a/docker_run.sh b/docker_run.sh
index eeec29823b7603c361b65c0752f62a3b328d31c9..f14996e5e254c6d266e2f4d0bb47033b8547aaec 100755
--- a/docker_run.sh
+++ b/docker_run.sh
@@ -1,18 +1,18 @@
-#!/bin/bash
-
-
-if [ -e environ_secret.sh ]
-then
-    echo "Note: Gathering environment variables from environ_secret.sh"
-    source environ_secret.sh
-else
-    echo "Note: Gathering environment variables from environ.sh"
-    source environ.sh
-fi
-
-# Expected Env variables : in environ.sh
-sudo docker run \
-    --net=host \
-    -v ./scratch/test-envs:/flatland_envs:z \
-    -it ${IMAGE_NAME}:${IMAGE_TAG} \
-    /home/aicrowd/run.sh
+#!/bin/bash
+
+
+if [ -e environ_secret.sh ]
+then
+    echo "Note: Gathering environment variables from environ_secret.sh"
+    source environ_secret.sh
+else
+    echo "Note: Gathering environment variables from environ.sh"
+    source environ.sh
+fi
+
+# Expected environment variables: defined in environ.sh
+sudo docker run \
+    --net=host \
+    -v ./scratch/test-envs:/flatland_envs:z \
+    -it ${IMAGE_NAME}:${IMAGE_TAG} \
+    /home/aicrowd/run.sh
diff --git a/environment.yml b/environment.yml
index a4109dcac710713829a6589eb06685d464984b7b..e79148acd6e4f97031ea83a2fc01f5f941a8f1e0 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,112 +1,112 @@
-name: flatland-rl
-channels:
-  - anaconda
-  - conda-forge
-  - defaults
-dependencies:
-  - tk=8.6.8
-  - cairo=1.16.0
-  - cairocffi=1.1.0
-  - cairosvg=2.4.2
-  - cffi=1.12.3
-  - cssselect2=0.2.1
-  - defusedxml=0.6.0
-  - fontconfig=2.13.1
-  - freetype=2.10.0
-  - gettext=0.19.8.1
-  - glib=2.58.3
-  - icu=64.2
-  - jpeg=9c
-  - libiconv=1.15
-  - libpng=1.6.37
-  - libtiff=4.0.10
-  - libuuid=2.32.1
-  - libxcb=1.13
-  - libxml2=2.9.9
-  - lz4-c=1.8.3
-  - olefile=0.46
-  - pcre=8.41
-  - pillow=5.3.0
-  - pixman=0.38.0
-  - pthread-stubs=0.4
-  - pycairo=1.18.1
-  - pycparser=2.19
-  - tinycss2=1.0.2
-  - webencodings=0.5.1
-  - xorg-kbproto=1.0.7
-  - xorg-libice=1.0.10
-  - xorg-libsm=1.2.3
-  - xorg-libx11=1.6.8
-  - xorg-libxau=1.0.9
-  - xorg-libxdmcp=1.1.3
-  - xorg-libxext=1.3.4
-  - xorg-libxrender=0.9.10
-  - xorg-renderproto=0.11.1
-  - xorg-xextproto=7.3.0
-  - xorg-xproto=7.0.31
-  - zstd=1.4.0
-  - _libgcc_mutex=0.1
-  - ca-certificates=2019.5.15
-  - certifi=2019.6.16
-  - libedit=3.1.20181209
-  - libffi=3.2.1
-  - ncurses=6.1
-  - openssl=1.1.1c
-  - pip=19.1.1
-  - python=3.6.8
-  - readline=7.0
-  - setuptools=41.0.1
-  - sqlite=3.29.0
-  - wheel=0.33.4
-  - xz=5.2.4
-  - zlib=1.2.11
-  - pip:
-    - atomicwrites==1.3.0
-    - importlib-metadata==0.19
-    - importlib-resources==1.0.2
-    - attrs==19.1.0
-    - chardet==3.0.4
-    - click==7.0
-    - cloudpickle==1.2.2
-    - crowdai-api==0.1.21
-    - cycler==0.10.0
-    - filelock==3.0.12
-    - flatland-rl==2.2.1
-    - future==0.17.1
-    - gym==0.14.0
-    - idna==2.8
-    - kiwisolver==1.1.0
-    - lxml==4.4.0
-    - matplotlib==3.1.1
-    - more-itertools==7.2.0
-    - msgpack==0.6.1
-    - msgpack-numpy==0.4.4.3
-    - numpy==1.17.0
-    - packaging==19.0
-    - pandas==0.25.0
-    - pluggy==0.12.0
-    - py==1.8.0
-    - pyarrow==0.14.1
-    - pyglet==1.3.2
-    - pyparsing==2.4.1.1
-    - pytest==5.0.1
-    - pytest-runner==5.1
-    - python-dateutil==2.8.0
-    - python-gitlab==1.10.0
-    - pytz==2019.1
-    - torch==1.5.0
-    - recordtype==1.3
-    - redis==3.3.2
-    - requests==2.22.0
-    - scipy==1.3.1
-    - six==1.12.0
-    - svgutils==0.3.1
-    - timeout-decorator==0.4.1
-    - toml==0.10.0
-    - tox==3.13.2
-    - urllib3==1.25.3
-    - ushlex==0.99.1
-    - virtualenv==16.7.2
-    - wcwidth==0.1.7
-    - xarray==0.12.3
-    - zipp==0.5.2
+name: flatland-rl
+channels:
+  - anaconda
+  - conda-forge
+  - defaults
+dependencies:
+  - tk=8.6.8
+  - cairo=1.16.0
+  - cairocffi=1.1.0
+  - cairosvg=2.4.2
+  - cffi=1.12.3
+  - cssselect2=0.2.1
+  - defusedxml=0.6.0
+  - fontconfig=2.13.1
+  - freetype=2.10.0
+  - gettext=0.19.8.1
+  - glib=2.58.3
+  - icu=64.2
+  - jpeg=9c
+  - libiconv=1.15
+  - libpng=1.6.37
+  - libtiff=4.0.10
+  - libuuid=2.32.1
+  - libxcb=1.13
+  - libxml2=2.9.9
+  - lz4-c=1.8.3
+  - olefile=0.46
+  - pcre=8.41
+  - pillow=5.3.0
+  - pixman=0.38.0
+  - pthread-stubs=0.4
+  - pycairo=1.18.1
+  - pycparser=2.19
+  - tinycss2=1.0.2
+  - webencodings=0.5.1
+  - xorg-kbproto=1.0.7
+  - xorg-libice=1.0.10
+  - xorg-libsm=1.2.3
+  - xorg-libx11=1.6.8
+  - xorg-libxau=1.0.9
+  - xorg-libxdmcp=1.1.3
+  - xorg-libxext=1.3.4
+  - xorg-libxrender=0.9.10
+  - xorg-renderproto=0.11.1
+  - xorg-xextproto=7.3.0
+  - xorg-xproto=7.0.31
+  - zstd=1.4.0
+  - _libgcc_mutex=0.1
+  - ca-certificates=2019.5.15
+  - certifi=2019.6.16
+  - libedit=3.1.20181209
+  - libffi=3.2.1
+  - ncurses=6.1
+  - openssl=1.1.1c
+  - pip=19.1.1
+  - python=3.6.8
+  - readline=7.0
+  - setuptools=41.0.1
+  - sqlite=3.29.0
+  - wheel=0.33.4
+  - xz=5.2.4
+  - zlib=1.2.11
+  - pip:
+    - atomicwrites==1.3.0
+    - importlib-metadata==0.19
+    - importlib-resources==1.0.2
+    - attrs==19.1.0
+    - chardet==3.0.4
+    - click==7.0
+    - cloudpickle==1.2.2
+    - crowdai-api==0.1.21
+    - cycler==0.10.0
+    - filelock==3.0.12
+    - flatland-rl==2.2.1
+    - future==0.17.1
+    - gym==0.14.0
+    - idna==2.8
+    - kiwisolver==1.1.0
+    - lxml==4.4.0
+    - matplotlib==3.1.1
+    - more-itertools==7.2.0
+    - msgpack==0.6.1
+    - msgpack-numpy==0.4.4.3
+    - numpy==1.17.0
+    - packaging==19.0
+    - pandas==0.25.0
+    - pluggy==0.12.0
+    - py==1.8.0
+    - pyarrow==0.14.1
+    - pyglet==1.3.2
+    - pyparsing==2.4.1.1
+    - pytest==5.0.1
+    - pytest-runner==5.1
+    - python-dateutil==2.8.0
+    - python-gitlab==1.10.0
+    - pytz==2019.1
+    - torch==1.5.0
+    - recordtype==1.3
+    - redis==3.3.2
+    - requests==2.22.0
+    - scipy==1.3.1
+    - six==1.12.0
+    - svgutils==0.3.1
+    - timeout-decorator==0.4.1
+    - toml==0.10.0
+    - tox==3.13.2
+    - urllib3==1.25.3
+    - ushlex==0.99.1
+    - virtualenv==16.7.2
+    - wcwidth==0.1.7
+    - xarray==0.12.3
+    - zipp==0.5.2
diff --git a/my_observation_builder.py b/my_observation_builder.py
index 915ff839cb6cad922fe6ca7513465a4a9edde705..482eecfc36c53e9390b1600b7d0ab9a02f032d9a 100644
--- a/my_observation_builder.py
+++ b/my_observation_builder.py
@@ -1,101 +1,101 @@
-#!/usr/bin/env python 
-
-import collections
-from typing import Optional, List, Dict, Tuple
-
-import numpy as np
-
-from flatland.core.env import Environment
-from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.env_prediction_builder import PredictionBuilder
-from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
-
-
-class CustomObservationBuilder(ObservationBuilder):
-    """
-    Template for building a custom observation builder for the RailEnv class
-
-    The observation in this case composed of the following elements:
-
-        - transition map array with dimensions (env.height, env.width),\
-          where the value at X,Y will represent the 16 bits encoding of transition-map at that point.
-        
-        - the individual agent object (with position, direction, target information available)
-
-    """
-    def __init__(self):
-        super(CustomObservationBuilder, self).__init__()
-
-    def set_env(self, env: Environment):
-        super().set_env(env)
-        # Note :
-        # The instantiations which depend on parameters of the Env object should be 
-        # done here, as it is only here that the updated self.env instance is available
-        self.rail_obs = np.zeros((self.env.height, self.env.width))
-
-    def reset(self):
-        """
-        Called internally on every env.reset() call, 
-        to reset any observation specific variables that are being used
-        """
-        self.rail_obs[:] = 0        
-        for _x in range(self.env.width):
-            for _y in range(self.env.height):
-                # Get the transition map value at location _x, _y
-                transition_value = self.env.rail.get_full_transitions(_y, _x)
-                self.rail_obs[_y, _x] = transition_value
-
-    def get(self, handle: int = 0):
-        """
-        Returns the built observation for a single agent with handle : handle
-
-        In this particular case, we return 
-        - the global transition_map of the RailEnv,
-        - a tuple containing, the current agent's:
-            - state
-            - position
-            - direction
-            - initial_position
-            - target
-        """
-
-        agent = self.env.agents[handle]
-        """
-        Available information for each agent object : 
-
-        - agent.status : [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
-        - agent.position : Current position of the agent
-        - agent.direction : Current direction of the agent
-        - agent.initial_position : Initial Position of the agent
-        - agent.target : Target position of the agent
-        """
-
-        status = agent.status
-        position = agent.position
-        direction = agent.direction
-        initial_position = agent.initial_position
-        target = agent.target
-
-        
-        """
-        You can also optionally access the states of the rest of the agents by 
-        using something similar to 
-
-        for i in range(len(self.env.agents)):
-            other_agent: EnvAgent = self.env.agents[i]
-
-            # ignore other agents not in the grid any more
-            if other_agent.status == RailAgentStatus.DONE_REMOVED:
-                continue
-
-            ## Gather other agent specific params 
-            other_agent_status = other_agent.status
-            other_agent_position = other_agent.position
-            other_agent_direction = other_agent.direction
-            other_agent_initial_position = other_agent.initial_position
-            other_agent_target = other_agent.target
-
-            ## Do something nice here if you wish
-        """
-        return self.rail_obs, (status, position, direction, initial_position, target)
-
+#!/usr/bin/env python 
+
+import collections
+from typing import Optional, List, Dict, Tuple
+
+import numpy as np
+
+from flatland.core.env import Environment
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.env_prediction_builder import PredictionBuilder
+from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
+
+
+class CustomObservationBuilder(ObservationBuilder):
+    """
+    Template for building a custom observation builder for the RailEnv class
+
+    The observation in this case is composed of the following elements:
+
+        - transition map array with dimensions (env.height, env.width),\
+          where the value at (X, Y) is the 16-bit encoding of the transition map at that point.
+        
+        - the individual agent object (with position, direction, target information available)
+
+    """
+    def __init__(self):
+        super(CustomObservationBuilder, self).__init__()
+
+    def set_env(self, env: Environment):
+        super().set_env(env)
+        # Note :
+        # The instantiations which depend on parameters of the Env object should be 
+        # done here, as it is only here that the updated self.env instance is available
+        self.rail_obs = np.zeros((self.env.height, self.env.width))
+
+    def reset(self):
+        """
+        Called internally on every env.reset() call,
+        to reset any observation-specific variables that are being used
+        """
+        self.rail_obs[:] = 0        
+        for _x in range(self.env.width):
+            for _y in range(self.env.height):
+                # Get the transition map value at location _x, _y
+                transition_value = self.env.rail.get_full_transitions(_y, _x)
+                self.rail_obs[_y, _x] = transition_value
+
+    def get(self, handle: int = 0):
+        """
+        Returns the built observation for a single agent with the given handle
+
+        In this particular case, we return 
+        - the global transition_map of the RailEnv,
+        - a tuple containing the current agent's:
+            - state
+            - position
+            - direction
+            - initial_position
+            - target
+        """
+
+        agent = self.env.agents[handle]
+        """
+        Available information for each agent object : 
+
+        - agent.status : [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
+        - agent.position : Current position of the agent
+        - agent.direction : Current direction of the agent
+        - agent.initial_position : Initial Position of the agent
+        - agent.target : Target position of the agent
+        """
+
+        status = agent.status
+        position = agent.position
+        direction = agent.direction
+        initial_position = agent.initial_position
+        target = agent.target
+
+        
+        """
+        You can also optionally access the states of the other agents by
+        using something similar to:
+
+        for i in range(len(self.env.agents)):
+            other_agent: EnvAgent = self.env.agents[i]
+
+            # ignore other agents not in the grid any more
+            if other_agent.status == RailAgentStatus.DONE_REMOVED:
+                continue
+
+            ## Gather other agent specific params 
+            other_agent_status = other_agent.status
+            other_agent_position = other_agent.position
+            other_agent_direction = other_agent.direction
+            other_agent_initial_position = other_agent.initial_position
+            other_agent_target = other_agent.target
+
+            ## Do something nice here if you wish
+        """
+        return self.rail_obs, (status, position, direction, initial_position, target)
+
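+
+# A minimal usage sketch of this builder, left as a comment so importing the
+# module has no side effects. The RailEnv parameters below (width, height,
+# number_of_agents) are illustrative assumptions, not values taken from the
+# challenge configuration:
+#
+#   from flatland.envs.rail_env import RailEnv
+#
+#   env = RailEnv(width=30, height=30, number_of_agents=5,
+#                 obs_builder_object=CustomObservationBuilder())
+#   obs, info = env.reset()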
diff --git a/run.py b/run.py
index 855dac19dcfdd6d4a971a8e74ee36c57591b45fe..27c3107896fde363315d78eadd778ffe2ac99f7e 100644
--- a/run.py
+++ b/run.py
@@ -1,174 +1,175 @@
-import time
-
-import numpy as np
-from flatland.envs.agent_utils import RailAgentStatus
-from flatland.evaluators.client import FlatlandRemoteClient
-
-#####################################################################
-# Instantiate a Remote Client
-#####################################################################
-from src.extra import Extra
-from src.observations import MyTreeObsForRailEnv
-
-remote_client = FlatlandRemoteClient()
-
-
-#####################################################################
-# Define your custom controller
-#
-# which can take an observation, and the number of agents and 
-# compute the necessary action for this step for all (or even some)
-# of the agents
-#####################################################################
-def my_controller(extra: Extra, observation, my_observation_builder):
-    return extra.rl_agent_act(observation, my_observation_builder.max_depth)
-
-
-#####################################################################
-# Instantiate your custom Observation Builder
-# 
-# You can build your own Observation Builder by following 
-# the example here : 
-# https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
-#####################################################################
-my_observation_builder = MyTreeObsForRailEnv(max_depth=3)
-
-# Or if you want to use your own approach to build the observation from the env_step, 
-# please feel free to pass a DummyObservationBuilder() object as mentioned below,
-# and that will just return a placeholder True for all observation, and you 
-# can build your own Observation for all the agents as your please.
-# my_observation_builder = DummyObservationBuilder()
-
-
-#####################################################################
-# Main evaluation loop
-#
-# This iterates over an arbitrary number of env evaluations
-#####################################################################
-evaluation_number = 0
-while True:
-
-    evaluation_number += 1
-    # Switch to a new evaluation environemnt
-    # 
-    # a remote_client.env_create is similar to instantiating a 
-    # RailEnv and then doing a env.reset()
-    # hence it returns the first observation from the 
-    # env.reset()
-    # 
-    # You can also pass your custom observation_builder object
-    # to allow you to have as much control as you wish 
-    # over the observation of your choice.
-    time_start = time.time()
-    observation, info = remote_client.env_create(
-        obs_builder_object=my_observation_builder
-    )
-    if not observation:
-        #
-        # If the remote_client returns False on a `env_create` call,
-        # then it basically means that your agent has already been 
-        # evaluated on all the required evaluation environments,
-        # and hence its safe to break out of the main evaluation loop
-        break
-
-    print("Evaluation Number : {}".format(evaluation_number))
-
-    #####################################################################
-    # Access to a local copy of the environment
-    # 
-    #####################################################################
-    # Note: You can access a local copy of the environment 
-    # by using : 
-    #       remote_client.env 
-    # 
-    # But please ensure to not make any changes (or perform any action) on 
-    # the local copy of the env, as then it will diverge from 
-    # the state of the remote copy of the env, and the observations and 
-    # rewards, etc will behave unexpectedly
-    # 
-    # You can however probe the local_env instance to get any information
-    # you need from the environment. It is a valid RailEnv instance.
-    local_env = remote_client.env
-    number_of_agents = len(local_env.agents)
-
-    # Now we enter into another infinite loop where we 
-    # compute the actions for all the individual steps in this episode
-    # until the episode is `done`
-    # 
-    # An episode is considered done when either all the agents have 
-    # reached their target destination
-    # or when the number of time steps has exceed max_time_steps, which 
-    # is defined by : 
-    #
-    # max_time_steps = int(4 * 2 * (env.width + env.height + 20))
-    #
-    time_taken_by_controller = []
-    time_taken_per_step = []
-    steps = 0
-
-    extra = Extra(local_env)
-    env_creation_time = time.time() - time_start
-    print("Env Creation Time : ", env_creation_time)
-    print("Agents : ", extra.env.get_num_agents())
-    print("w : ", extra.env.width)
-    print("h : ", extra.env.height)
-
-    while True:
-        #####################################################################
-        # Evaluation of a single episode
-        #
-        #####################################################################
-        # Compute the action for this step by using the previously 
-        # defined controller
-        time_start = time.time()
-        action = my_controller(extra, observation, my_observation_builder)
-        time_taken = time.time() - time_start
-        time_taken_by_controller.append(time_taken)
-
-        # Perform the chosen action on the environment.
-        # The action gets applied to both the local and the remote copy 
-        # of the environment instance, and the observation is what is 
-        # returned by the local copy of the env, and the rewards, and done and info
-        # are returned by the remote copy of the env
-        time_start = time.time()
-        observation, all_rewards, done, info = remote_client.env_step(action)
-        steps += 1
-        time_taken = time.time() - time_start
-        time_taken_per_step.append(time_taken)
-
-        total_done = 0
-        for a in range(local_env.get_num_agents()):
-            x = (local_env.agents[a].status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
-            total_done += int(x)
-        print("total_done:", total_done)
-
-        if done['__all__']:
-            print("Reward : ", sum(list(all_rewards.values())))
-            #
-            # When done['__all__'] == True, then the evaluation of this 
-            # particular Env instantiation is complete, and we can break out 
-            # of this loop, and move onto the next Env evaluation
-            break
-
-    np_time_taken_by_controller = np.array(time_taken_by_controller)
-    np_time_taken_per_step = np.array(time_taken_per_step)
-    print("=" * 100)
-    print("=" * 100)
-    print("Evaluation Number : ", evaluation_number)
-    print("Current Env Path : ", remote_client.current_env_path)
-    print("Env Creation Time : ", env_creation_time)
-    print("Number of Steps : ", steps)
-    print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(),
-          np_time_taken_by_controller.std())
-    print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
-    print("=" * 100)
-
-print("Evaluation of all environments complete...")
-########################################################################
-# Submit your Results
-# 
-# Please do not forget to include this call, as this triggers the 
-# final computation of the score statistics, video generation, etc
-# and is necesaary to have your submission marked as successfully evaluated
-########################################################################
-print(remote_client.submit())
+import time
+
+import numpy as np
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.evaluators.client import FlatlandRemoteClient
+
+#####################################################################
+# Instantiate a Remote Client
+#####################################################################
+from src.extra import Extra
+
+remote_client = FlatlandRemoteClient()
+
+
+#####################################################################
+# Define your custom controller
+#
+# which takes an observation and the step info and
+# computes the necessary action for this step for all (or even some)
+# of the agents
+#####################################################################
+def my_controller(extra: Extra, observation, info):
+    return extra.rl_agent_act(observation, info)
+
+
+#####################################################################
+# Instantiate your custom Observation Builder
+# 
+# You can build your own Observation Builder by following 
+# the example here : 
+# https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14
+#####################################################################
+my_observation_builder = Extra(max_depth=2)
+
+# Or if you want to use your own approach to build the observation from the env_step, 
+# please feel free to pass a DummyObservationBuilder() object as mentioned below,
+# and that will just return a placeholder True for all observations, and you
+# can build your own observation for all the agents as you please.
+# my_observation_builder = DummyObservationBuilder()
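+#
+# A minimal sketch of that alternative (this assumes DummyObservationBuilder
+# is importable from flatland.core.env_observation_builder in the installed
+# flatland-rl version -- verify the import path before relying on it):
+#
+#   from flatland.core.env_observation_builder import DummyObservationBuilder
+#   my_observation_builder = DummyObservationBuilder()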
+
+
+#####################################################################
+# Main evaluation loop
+#
+# This iterates over an arbitrary number of env evaluations
+#####################################################################
+evaluation_number = 0
+while True:
+
+    evaluation_number += 1
+    # Switch to a new evaluation environment
+    # 
+    # a remote_client.env_create call is similar to instantiating a
+    # RailEnv and then doing an env.reset()
+    # hence it returns the first observation from the 
+    # env.reset()
+    # 
+    # You can also pass your custom observation_builder object
+    # to allow you to have as much control as you wish 
+    # over the observation of your choice.
+    time_start = time.time()
+    observation, info = remote_client.env_create(
+        obs_builder_object=my_observation_builder
+    )
+    if not observation:
+        #
+        # If the remote_client returns False on an `env_create` call,
+        # then it basically means that your agent has already been
+        # evaluated on all the required evaluation environments,
+        # and hence it's safe to break out of the main evaluation loop
+        break
+
+    print("Evaluation Number : {}".format(evaluation_number))
+
+    #####################################################################
+    # Access to a local copy of the environment
+    # 
+    #####################################################################
+    # Note: You can access a local copy of the environment 
+    # by using : 
+    #       remote_client.env 
+    # 
+    # But please ensure that you do not make any changes (or perform any actions) on
+    # the local copy of the env, as it will then diverge from
+    # the state of the remote copy of the env, and the observations,
+    # rewards, etc. will behave unexpectedly
+    # 
+    # You can however probe the local_env instance to get any information
+    # you need from the environment. It is a valid RailEnv instance.
+    local_env = remote_client.env
+    number_of_agents = len(local_env.agents)
+
+    # Now we enter into another infinite loop where we 
+    # compute the actions for all the individual steps in this episode
+    # until the episode is `done`
+    # 
+    # An episode is considered done when either all the agents have 
+    # reached their target destination
+    # or when the number of time steps has exceeded max_time_steps, which
+    # is defined by : 
+    #
+    # max_time_steps = int(4 * 2 * (env.width + env.height + 20))
+    #
+    time_taken_by_controller = []
+    time_taken_per_step = []
+    steps = 0
+
+    extra = my_observation_builder
+    env_creation_time = time.time() - time_start
+    print("Env Creation Time : ", env_creation_time)
+    print("Agents : ", extra.env.get_num_agents())
+    print("w : ", extra.env.width)
+    print("h : ", extra.env.height)
+
+    while True:
+        #####################################################################
+        # Evaluation of a single episode
+        #
+        #####################################################################
+        # Compute the action for this step by using the previously 
+        # defined controller
+        time_start = time.time()
+        action = my_controller(extra, observation, info)
+        time_taken = time.time() - time_start
+        time_taken_by_controller.append(time_taken)
+
+        # Perform the chosen action on the environment.
+        # The action gets applied to both the local and the remote copy
+        # of the environment instance; the observation is what is
+        # returned by the local copy of the env, while the rewards, done and info
+        # are returned by the remote copy of the env
+        time_start = time.time()
+        observation, all_rewards, done, info = remote_client.env_step(action)
+        steps += 1
+        time_taken = time.time() - time_start
+        time_taken_per_step.append(time_taken)
+
+        total_done = 0
+        total_active = 0
+        for a in range(local_env.get_num_agents()):
+            x = (local_env.agents[a].status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED])
+            total_done += int(x)
+            total_active += int(local_env.agents[a].status == RailAgentStatus.ACTIVE)
+        # print("total_done:", total_done, "\ttotal_active", total_active, "\t num agents", local_env.get_num_agents())
+
+        if done['__all__']:
+            print("Reward : ", sum(list(all_rewards.values())))
+            #
+            # When done['__all__'] == True, then the evaluation of this 
+            # particular Env instantiation is complete, and we can break out 
+            # of this loop, and move onto the next Env evaluation
+            break
+
+    np_time_taken_by_controller = np.array(time_taken_by_controller)
+    np_time_taken_per_step = np.array(time_taken_per_step)
+    print("=" * 100)
+    print("=" * 100)
+    print("Evaluation Number : ", evaluation_number)
+    print("Current Env Path : ", remote_client.current_env_path)
+    print("Env Creation Time : ", env_creation_time)
+    print("Number of Steps : ", steps)
+    print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(),
+          np_time_taken_by_controller.std())
+    print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
+    print("=" * 100)
+
+print("Evaluation of all environments complete...")
+########################################################################
+# Submit your Results
+# 
+# Please do not forget to include this call, as this triggers the 
+# final computation of the score statistics, video generation, etc.,
+# and is necessary to have your submission marked as successfully evaluated
+########################################################################
+print(remote_client.submit())
diff --git a/run.sh b/run.sh
index 953c1660c6abafcc0a474c526ef7ffcedef6a5d8..c6d89fe35ae407a87904f15b601cc14974e8ef6a 100755
--- a/run.sh
+++ b/run.sh
@@ -1,3 +1,3 @@
-#!/bin/bash
-
-python ./run.py
+#!/bin/bash
+
+python ./run.py
diff --git a/src/agent/dueling_double_dqn.py b/src/agent/dueling_double_dqn.py
index 6d82fded3552e4b9d72c22dc4f85fa1fde574e48..f08e17602db84265079fa5f6f7d8422d5a5d4c89 100644
--- a/src/agent/dueling_double_dqn.py
+++ b/src/agent/dueling_double_dqn.py
@@ -1,512 +1,512 @@
-import torch
-import torch.optim as optim
-
-BUFFER_SIZE = int(1e5)  # replay buffer size
-BATCH_SIZE = 512  # minibatch size
-GAMMA = 0.99  # discount factor 0.99
-TAU = 0.5e-3  # for soft update of target parameters
-LR = 0.5e-4  # learning rate 0.5e-4 works
-
-# how often to update the network
-UPDATE_EVERY = 20
-UPDATE_EVERY_FINAL = 10
-UPDATE_EVERY_AGENT_CANT_CHOOSE = 200
-
-
-double_dqn = True  # If using double dqn algorithm
-input_channels = 5  # Number of Input channels
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-device = torch.device("cpu")
-print(device)
-
-USE_OPTIMIZER = optim.Adam
-# USE_OPTIMIZER = optim.RMSprop
-print(USE_OPTIMIZER)
-
-
-class Agent:
-    """Interacts with and learns from the environment."""
-
-    def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5):
-        """Initialize an Agent object.
-
-        Params
-        ======
-            state_size (int): dimension of each state
-            action_size (int): dimension of each action
-            seed (int): random seed
-        """
-        self.state_size = state_size
-        self.action_size = action_size
-        self.seed = random.seed(seed)
-        self.version = net_type
-        self.double_dqn = double_dqn
-        # Q-Network
-        if self.version == "Conv":
-            self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device)
-            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
-        else:
-            self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
-            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
-
-        self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR)
-
-        # Replay memory
-        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
-        self.memory_final = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
-        self.memory_agent_can_not_choose = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
-
-        self.final_step = {}
-
-        # Initialize time step (for updating every UPDATE_EVERY steps)
-        self.t_step = 0
-        self.t_step_final = 0
-        self.t_step_agent_can_not_choose = 0
-
-    def save(self, filename):
-        torch.save(self.qnetwork_local.state_dict(), filename + ".local")
-        torch.save(self.qnetwork_target.state_dict(), filename + ".target")
-
-    def load(self, filename):
-        if os.path.exists(filename + ".local"):
-            self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
-            print(filename + ".local -> ok")
-        if os.path.exists(filename + ".target"):
-            self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
-            print(filename + ".target -> ok")
-        self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR)
-
-    def _update_model(self, switch=0):
-        # Learn every UPDATE_EVERY time steps.
-        # If enough samples are available in memory, get random subset and learn
-        if switch == 0:
-            self.t_step = (self.t_step + 1) % UPDATE_EVERY
-            if self.t_step == 0:
-                if len(self.memory) > BATCH_SIZE:
-                    experiences = self.memory.sample()
-                    self.learn(experiences, GAMMA)
-        elif switch == 1:
-            self.t_step_final = (self.t_step_final + 1) % UPDATE_EVERY_FINAL
-            if self.t_step_final == 0:
-                if len(self.memory_final) > BATCH_SIZE:
-                    experiences = self.memory_final.sample()
-                    self.learn(experiences, GAMMA)
-        else:
-            # If enough samples are available in memory_agent_can_not_choose, get random subset and learn
-            self.t_step_agent_can_not_choose = (self.t_step_agent_can_not_choose + 1) % UPDATE_EVERY_AGENT_CANT_CHOOSE
-            if self.t_step_agent_can_not_choose == 0:
-                if len(self.memory_agent_can_not_choose) > BATCH_SIZE:
-                    experiences = self.memory_agent_can_not_choose.sample()
-                    self.learn(experiences, GAMMA)
-
-    def step(self, state, action, reward, next_state, done):
-        # Save experience in replay memory
-        self.memory.add(state, action, reward, next_state, done)
-        self._update_model(0)
-
-    def step_agent_can_not_choose(self, state, action, reward, next_state, done):
-        # Save experience in replay memory_agent_can_not_choose
-        self.memory_agent_can_not_choose.add(state, action, reward, next_state, done)
-        self._update_model(2)
-
-    def add_final_step(self, agent_handle, state, action, reward, next_state, done):
-        if self.final_step.get(agent_handle) is None:
-            self.final_step.update({agent_handle: [state, action, reward, next_state, done]})
-
-    def make_final_step(self, additional_reward=0):
-        for _, item in self.final_step.items():
-            state = item[0]
-            action = item[1]
-            reward = item[2] + additional_reward
-            next_state = item[3]
-            done = item[4]
-            self.memory_final.add(state, action, reward, next_state, done)
-            self._update_model(1)
-        self._reset_final_step()
-
-    def _reset_final_step(self):
-        self.final_step = {}
-
-    def act(self, state, eps=0.):
-        """Returns actions for given state as per current policy.
-
-        Params
-        ======
-            state (array_like): current state
-            eps (float): epsilon, for epsilon-greedy action selection
-        """
-        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
-        self.qnetwork_local.eval()
-        with torch.no_grad():
-            action_values = self.qnetwork_local(state)
-        self.qnetwork_local.train()
-
-        # Epsilon-greedy action selection
-        if random.random() > eps:
-            return np.argmax(action_values.cpu().data.numpy())
-        else:
-            return random.choice(np.arange(self.action_size))
-
-    def learn(self, experiences, gamma):
-
-        """Update value parameters using given batch of experience tuples.
-
-        Params
-        ======
-            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
-            gamma (float): discount factor
-        """
-        states, actions, rewards, next_states, dones = experiences
-
-        # Get expected Q values from local model
-        Q_expected = self.qnetwork_local(states).gather(1, actions)
-
-        if self.double_dqn:
-            # Double DQN
-            q_best_action = self.qnetwork_local(next_states).max(1)[1]
-            Q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1))
-        else:
-            # DQN
-            Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1)
-
-            # Compute Q targets for current states
-
-        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
-
-        # Compute loss
-        loss = F.mse_loss(Q_expected, Q_targets)
-        # Minimize the loss
-        self.optimizer.zero_grad()
-        loss.backward()
-        self.optimizer.step()
-
-        # ------------------- update target network ------------------- #
-        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
-
-    def soft_update(self, local_model, target_model, tau):
-        """Soft update model parameters.
-        θ_target = τ*θ_local + (1 - τ)*θ_target
-
-        Params
-        ======
-            local_model (PyTorch model): weights will be copied from
-            target_model (PyTorch model): weights will be copied to
-            tau (float): interpolation parameter
-        """
-        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
-            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
-
-
-class ReplayBuffer:
-    """Fixed-size buffer to store experience tuples."""
-
-    def __init__(self, action_size, buffer_size, batch_size, seed):
-        """Initialize a ReplayBuffer object.
-
-        Params
-        ======
-            action_size (int): dimension of each action
-            buffer_size (int): maximum size of buffer
-            batch_size (int): size of each training batch
-            seed (int): random seed
-        """
-        self.action_size = action_size
-        self.memory = deque(maxlen=buffer_size)
-        self.batch_size = batch_size
-        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
-        self.seed = random.seed(seed)
-
-    def add(self, state, action, reward, next_state, done):
-        """Add a new experience to memory."""
-        e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)
-        self.memory.append(e)
-
-    def sample(self):
-        """Randomly sample a batch of experiences from memory."""
-        experiences = random.sample(self.memory, k=self.batch_size)
-
-        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
-            .float().to(device)
-        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
-            .long().to(device)
-        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
-            .float().to(device)
-        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
-            .float().to(device)
-        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
-            .float().to(device)
-
-        return (states, actions, rewards, next_states, dones)
-
-    def __len__(self):
-        """Return the current size of internal memory."""
-        return len(self.memory)
-
-    def __v_stack_impr(self, states):
-        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
-        np_states = np.reshape(np.array(states), (len(states), sub_dim))
-        return np_states
-
-
-import copy
-import os
-import random
-from collections import namedtuple, deque, Iterable
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-import torch.optim as optim
-
-from src.agent.model import QNetwork2, QNetwork
-
-BUFFER_SIZE = int(1e5)  # replay buffer size
-BATCH_SIZE = 512  # minibatch size
-GAMMA = 0.95  # discount factor 0.99
-TAU = 0.5e-4  # for soft update of target parameters
-LR = 0.5e-3  # learning rate 0.5e-4 works
-
-# how often to update the network
-UPDATE_EVERY = 40
-UPDATE_EVERY_FINAL = 1000
-UPDATE_EVERY_AGENT_CANT_CHOOSE = 200
-
-double_dqn = True  # If using double dqn algorithm
-input_channels = 5  # Number of Input channels
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-device = torch.device("cpu")
-print(device)
-
-USE_OPTIMIZER = optim.Adam
-# USE_OPTIMIZER = optim.RMSprop
-print(USE_OPTIMIZER)
-
-
-class Agent:
-    """Interacts with and learns from the environment."""
-
-    def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5):
-        """Initialize an Agent object.
-
-        Params
-        ======
-            state_size (int): dimension of each state
-            action_size (int): dimension of each action
-            seed (int): random seed
-        """
-        self.state_size = state_size
-        self.action_size = action_size
-        self.seed = random.seed(seed)
-        self.version = net_type
-        self.double_dqn = double_dqn
-        # Q-Network
-        if self.version == "Conv":
-            self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device)
-            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
-        else:
-            self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
-            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
-
-        self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR)
-
-        # Replay memory
-        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
-        self.memory_final = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
-        self.memory_agent_can_not_choose = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
-
-        self.final_step = {}
-
-        # Initialize time step (for updating every UPDATE_EVERY steps)
-        self.t_step = 0
-        self.t_step_final = 0
-        self.t_step_agent_can_not_choose = 0
-
-    def save(self, filename):
-        torch.save(self.qnetwork_local.state_dict(), filename + ".local")
-        torch.save(self.qnetwork_target.state_dict(), filename + ".target")
-
-    def load(self, filename):
-        print("try to load: " + filename)
-        if os.path.exists(filename + ".local"):
-            self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
-            print(filename + ".local -> ok")
-        if os.path.exists(filename + ".target"):
-            self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
-            print(filename + ".target -> ok")
-        self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR)
-
-    def _update_model(self, switch=0):
-        # Learn every UPDATE_EVERY time steps.
-        # If enough samples are available in memory, get random subset and learn
-        if switch == 0:
-            self.t_step = (self.t_step + 1) % UPDATE_EVERY
-            if self.t_step == 0:
-                if len(self.memory) > BATCH_SIZE:
-                    experiences = self.memory.sample()
-                    self.learn(experiences, GAMMA)
-        elif switch == 1:
-            self.t_step_final = (self.t_step_final + 1) % UPDATE_EVERY_FINAL
-            if self.t_step_final == 0:
-                if len(self.memory_final) > BATCH_SIZE:
-                    experiences = self.memory_final.sample()
-                    self.learn(experiences, GAMMA)
-        else:
-            # If enough samples are available in memory_agent_can_not_choose, get random subset and learn
-            self.t_step_agent_can_not_choose = (self.t_step_agent_can_not_choose + 1) % UPDATE_EVERY_AGENT_CANT_CHOOSE
-            if self.t_step_agent_can_not_choose == 0:
-                if len(self.memory_agent_can_not_choose) > BATCH_SIZE:
-                    experiences = self.memory_agent_can_not_choose.sample()
-                    self.learn(experiences, GAMMA)
-
-    def step(self, state, action, reward, next_state, done):
-        # Save experience in replay memory
-        self.memory.add(state, action, reward, next_state, done)
-        self._update_model(0)
-
-    def step_agent_can_not_choose(self, state, action, reward, next_state, done):
-        # Save experience in replay memory_agent_can_not_choose
-        self.memory_agent_can_not_choose.add(state, action, reward, next_state, done)
-        self._update_model(2)
-
-    def add_final_step(self, agent_handle, state, action, reward, next_state, done):
-        if self.final_step.get(agent_handle) is None:
-            self.final_step.update({agent_handle: [state, action, reward, next_state, done]})
-            return True
-        else:
-            return False
-
-    def make_final_step(self, additional_reward=0):
-        for _, item in self.final_step.items():
-            state = item[0]
-            action = item[1]
-            reward = item[2] + additional_reward
-            next_state = item[3]
-            done = item[4]
-            self.memory_final.add(state, action, reward, next_state, done)
-            self._update_model(1)
-        self._reset_final_step()
-
-    def _reset_final_step(self):
-        self.final_step = {}
-
-    def act(self, state, eps=0.):
-        """Returns actions for given state as per current policy.
-
-        Params
-        ======
-            state (array_like): current state
-            eps (float): epsilon, for epsilon-greedy action selection
-        """
-        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
-        self.qnetwork_local.eval()
-        with torch.no_grad():
-            action_values = self.qnetwork_local(state)
-        self.qnetwork_local.train()
-
-        # Epsilon-greedy action selection
-        if random.random() > eps:
-            return np.argmax(action_values.cpu().data.numpy()), False
-        else:
-            return random.choice(np.arange(self.action_size)), True
-
-    def learn(self, experiences, gamma):
-
-        """Update value parameters using given batch of experience tuples.
-
-        Params
-        ======
-            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
-            gamma (float): discount factor
-        """
-        states, actions, rewards, next_states, dones = experiences
-
-        # Get expected Q values from local model
-        Q_expected = self.qnetwork_local(states).gather(1, actions)
-
-        if self.double_dqn:
-            # Double DQN
-            q_best_action = self.qnetwork_local(next_states).max(1)[1]
-            Q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1))
-        else:
-            # DQN
-            Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1)
-
-            # Compute Q targets for current states
-
-        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
-
-        # Compute loss
-        loss = F.mse_loss(Q_expected, Q_targets)
-        # Minimize the loss
-        self.optimizer.zero_grad()
-        loss.backward()
-        self.optimizer.step()
-
-        # ------------------- update target network ------------------- #
-        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
-
-    def soft_update(self, local_model, target_model, tau):
-        """Soft update model parameters.
-        θ_target = τ*θ_local + (1 - τ)*θ_target
-
-        Params
-        ======
-            local_model (PyTorch model): weights will be copied from
-            target_model (PyTorch model): weights will be copied to
-            tau (float): interpolation parameter
-        """
-        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
-            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
-
-
-class ReplayBuffer:
-    """Fixed-size buffer to store experience tuples."""
-
-    def __init__(self, action_size, buffer_size, batch_size, seed):
-        """Initialize a ReplayBuffer object.
-
-        Params
-        ======
-            action_size (int): dimension of each action
-            buffer_size (int): maximum size of buffer
-            batch_size (int): size of each training batch
-            seed (int): random seed
-        """
-        self.action_size = action_size
-        self.memory = deque(maxlen=buffer_size)
-        self.batch_size = batch_size
-        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
-        self.seed = random.seed(seed)
-
-    def add(self, state, action, reward, next_state, done):
-        """Add a new experience to memory."""
-        e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)
-        self.memory.append(e)
-
-    def sample(self):
-        """Randomly sample a batch of experiences from memory."""
-        experiences = random.sample(self.memory, k=self.batch_size)
-
-        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
-            .float().to(device)
-        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
-            .long().to(device)
-        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
-            .float().to(device)
-        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
-            .float().to(device)
-        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
-            .float().to(device)
-
-        return (states, actions, rewards, next_states, dones)
-
-    def __len__(self):
-        """Return the current size of internal memory."""
-        return len(self.memory)
-
-    def __v_stack_impr(self, states):
-        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
-        np_states = np.reshape(np.array(states), (len(states), sub_dim))
-        return np_states
+import copy
+import os
+import random
+from collections import namedtuple, deque
+from collections.abc import Iterable
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+
+from src.agent.model import QNetwork2, QNetwork
+
+BUFFER_SIZE = int(1e5)  # replay buffer size
+BATCH_SIZE = 512  # minibatch size
+GAMMA = 0.99  # discount factor 0.99
+TAU = 0.5e-3  # for soft update of target parameters
+LR = 0.5e-4  # learning rate 0.5e-4 works
+
+# how often to update the network
+UPDATE_EVERY = 20
+UPDATE_EVERY_FINAL = 10
+UPDATE_EVERY_AGENT_CANT_CHOOSE = 200
+
+
+double_dqn = True  # whether to use the double DQN algorithm
+input_channels = 5  # number of input channels (for the "Conv" network)
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")  # force CPU; remove this line to use the CUDA device selected above
+print(device)
+
+USE_OPTIMIZER = optim.Adam
+# USE_OPTIMIZER = optim.RMSprop
+print(USE_OPTIMIZER)
+
+
+class Agent:
+    """Interacts with and learns from the environment."""
+
+    def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5):
+        """Initialize an Agent object.
+
+        Params
+        ======
+            state_size (int): dimension of each state
+            action_size (int): dimension of each action
+            seed (int): random seed
+        """
+        self.state_size = state_size
+        self.action_size = action_size
+        self.seed = random.seed(seed)
+        self.version = net_type
+        self.double_dqn = double_dqn
+        # Q-Network
+        if self.version == "Conv":
+            self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device)
+            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+        else:
+            self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
+            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+
+        self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR)
+
+        # Replay memory
+        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
+        self.memory_final = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
+        self.memory_agent_can_not_choose = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
+
+        self.final_step = {}
+
+        # Initialize time step (for updating every UPDATE_EVERY steps)
+        self.t_step = 0
+        self.t_step_final = 0
+        self.t_step_agent_can_not_choose = 0
+
+    def save(self, filename):
+        torch.save(self.qnetwork_local.state_dict(), filename + ".local")
+        torch.save(self.qnetwork_target.state_dict(), filename + ".target")
+
+    def load(self, filename):
+        if os.path.exists(filename + ".local"):
+            self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
+            print(filename + ".local -> ok")
+        if os.path.exists(filename + ".target"):
+            self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+            print(filename + ".target -> ok")
+        self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR)
+
+    def _update_model(self, switch=0):
+        # Learn every UPDATE_EVERY time steps.
+        # If enough samples are available in memory, get random subset and learn
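+        # switch selects which replay buffer drives this update:
+        #   0 -> regular experience memory    (checked every UPDATE_EVERY calls)
+        #   1 -> final-step memory            (checked every UPDATE_EVERY_FINAL calls)
+        #   2 -> "agent cannot choose" memory (checked every UPDATE_EVERY_AGENT_CANT_CHOOSE calls)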
+        if switch == 0:
+            self.t_step = (self.t_step + 1) % UPDATE_EVERY
+            if self.t_step == 0:
+                if len(self.memory) > BATCH_SIZE:
+                    experiences = self.memory.sample()
+                    self.learn(experiences, GAMMA)
+        elif switch == 1:
+            self.t_step_final = (self.t_step_final + 1) % UPDATE_EVERY_FINAL
+            if self.t_step_final == 0:
+                if len(self.memory_final) > BATCH_SIZE:
+                    experiences = self.memory_final.sample()
+                    self.learn(experiences, GAMMA)
+        else:
+            # If enough samples are available in memory_agent_can_not_choose, get random subset and learn
+            self.t_step_agent_can_not_choose = (self.t_step_agent_can_not_choose + 1) % UPDATE_EVERY_AGENT_CANT_CHOOSE
+            if self.t_step_agent_can_not_choose == 0:
+                if len(self.memory_agent_can_not_choose) > BATCH_SIZE:
+                    experiences = self.memory_agent_can_not_choose.sample()
+                    self.learn(experiences, GAMMA)
+
+    def step(self, state, action, reward, next_state, done):
+        # Save experience in replay memory
+        self.memory.add(state, action, reward, next_state, done)
+        self._update_model(0)
+
+    def step_agent_can_not_choose(self, state, action, reward, next_state, done):
+        # Save experience in replay memory_agent_can_not_choose
+        self.memory_agent_can_not_choose.add(state, action, reward, next_state, done)
+        self._update_model(2)
+
+    def add_final_step(self, agent_handle, state, action, reward, next_state, done):
+        if self.final_step.get(agent_handle) is None:
+            self.final_step.update({agent_handle: [state, action, reward, next_state, done]})
+
+    def make_final_step(self, additional_reward=0):
+        for _, item in self.final_step.items():
+            state = item[0]
+            action = item[1]
+            reward = item[2] + additional_reward
+            next_state = item[3]
+            done = item[4]
+            self.memory_final.add(state, action, reward, next_state, done)
+            self._update_model(1)
+        self._reset_final_step()
+
+    def _reset_final_step(self):
+        self.final_step = {}
+
+    def act(self, state, eps=0.):
+        """Returns actions for given state as per current policy.
+
+        Params
+        ======
+            state (array_like): current state
+            eps (float): epsilon, for epsilon-greedy action selection
+        """
+        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
+        self.qnetwork_local.eval()
+        with torch.no_grad():
+            action_values = self.qnetwork_local(state)
+        self.qnetwork_local.train()
+
+        # Epsilon-greedy action selection
+        if random.random() > eps:
+            return np.argmax(action_values.cpu().data.numpy())
+        else:
+            return random.choice(np.arange(self.action_size))
+
+    def learn(self, experiences, gamma):
+        """Update value parameters using given batch of experience tuples.
+
+        Params
+        ======
+            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
+            gamma (float): discount factor
+        """
+        states, actions, rewards, next_states, dones = experiences
+
+        # Get expected Q values from local model
+        Q_expected = self.qnetwork_local(states).gather(1, actions)
+
+        if self.double_dqn:
+            # Double DQN: select the greedy action with the local network,
+            # evaluate it with the (detached) target network
+            q_best_action = self.qnetwork_local(next_states).max(1)[1]
+            Q_targets_next = self.qnetwork_target(next_states).detach().gather(1, q_best_action.unsqueeze(-1))
+        else:
+            # DQN
+            Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1)
+
+        # Compute Q targets for current states
+        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
+
+        # Compute loss
+        loss = F.mse_loss(Q_expected, Q_targets)
+        # Minimize the loss
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+
+        # ------------------- update target network ------------------- #
+        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
+
+    def soft_update(self, local_model, target_model, tau):
+        """Soft update model parameters.
+        θ_target = τ*θ_local + (1 - τ)*θ_target
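+        (e.g. with tau = 5e-4 the target weights move only 0.05% of the way
+        towards the local weights per call, so the target network changes slowly)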
+
+        Params
+        ======
+            local_model (PyTorch model): weights will be copied from
+            target_model (PyTorch model): weights will be copied to
+            tau (float): interpolation parameter
+        """
+        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
+            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
+
+
+class ReplayBuffer:
+    """Fixed-size buffer to store experience tuples."""
+
+    def __init__(self, action_size, buffer_size, batch_size, seed):
+        """Initialize a ReplayBuffer object.
+
+        Params
+        ======
+            action_size (int): dimension of each action
+            buffer_size (int): maximum size of buffer
+            batch_size (int): size of each training batch
+            seed (int): random seed
+        """
+        self.action_size = action_size
+        self.memory = deque(maxlen=buffer_size)
+        self.batch_size = batch_size
+        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
+        self.seed = random.seed(seed)
+
+    def add(self, state, action, reward, next_state, done):
+        """Add a new experience to memory."""
+        e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)
+        self.memory.append(e)
+
+    def sample(self):
+        """Randomly sample a batch of experiences from memory."""
+        experiences = random.sample(self.memory, k=self.batch_size)
+
+        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
+            .float().to(device)
+        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
+            .long().to(device)
+        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
+            .float().to(device)
+        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
+            .float().to(device)
+        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
+            .float().to(device)
+
+        return (states, actions, rewards, next_states, dones)
+
+    def __len__(self):
+        """Return the current size of internal memory."""
+        return len(self.memory)
+
+    def __v_stack_impr(self, states):
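+        # Stack the per-experience arrays into a single (batch_size, feature_dim) numpy array.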
+        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
+        np_states = np.reshape(np.array(states), (len(states), sub_dim))
+        return np_states
+
+
+import copy
+import os
+import random
+from collections import namedtuple, deque
+from collections.abc import Iterable
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+
+from src.agent.model import QNetwork2, QNetwork
+
+BUFFER_SIZE = int(1e5)  # replay buffer size
+BATCH_SIZE = 512  # minibatch size
+GAMMA = 0.95  # discount factor
+TAU = 0.5e-4  # interpolation factor for the soft update of the target network parameters
+LR = 0.5e-3  # learning rate (0.5e-4 also works)
+
+# how often to update the network
+UPDATE_EVERY = 40
+UPDATE_EVERY_FINAL = 1000
+UPDATE_EVERY_AGENT_CANT_CHOOSE = 200
+
+double_dqn = True  # whether to use the double DQN algorithm
+input_channels = 5  # number of input channels (for the "Conv" network)
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+device = torch.device("cpu")  # force CPU; remove this line to use the CUDA device selected above
+print(device)
+
+USE_OPTIMIZER = optim.Adam
+# USE_OPTIMIZER = optim.RMSprop
+print(USE_OPTIMIZER)
+
+
+class Agent:
+    """Interacts with and learns from the environment."""
+
+    def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5):
+        """Initialize an Agent object.
+
+        Params
+        ======
+            state_size (int): dimension of each state
+            action_size (int): dimension of each action
+            seed (int): random seed
+        """
+        self.state_size = state_size
+        self.action_size = action_size
+        self.seed = random.seed(seed)
+        self.version = net_type
+        self.double_dqn = double_dqn
+        # Q-Network
+        if self.version == "Conv":
+            self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device)
+            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+        else:
+            self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device)
+            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
+
+        self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR)
+
+        # Replay memory
+        self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
+        self.memory_final = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
+        self.memory_agent_can_not_choose = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed)
+
+        self.final_step = {}
+
+        # Initialize time step (for updating every UPDATE_EVERY steps)
+        self.t_step = 0
+        self.t_step_final = 0
+        self.t_step_agent_can_not_choose = 0
+
+    def save(self, filename):
+        torch.save(self.qnetwork_local.state_dict(), filename + ".local")
+        torch.save(self.qnetwork_target.state_dict(), filename + ".target")
+
+    def load(self, filename):
+        print("try to load: " + filename)
+        if os.path.exists(filename + ".local"):
+            self.qnetwork_local.load_state_dict(torch.load(filename + ".local"))
+            print(filename + ".local -> ok")
+        if os.path.exists(filename + ".target"):
+            self.qnetwork_target.load_state_dict(torch.load(filename + ".target"))
+            print(filename + ".target -> ok")
+        self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR)
+
+    def _update_model(self, switch=0):
+        # Learn every UPDATE_EVERY time steps.
+        # If enough samples are available in memory, get random subset and learn
+        if switch == 0:
+            self.t_step = (self.t_step + 1) % UPDATE_EVERY
+            if self.t_step == 0:
+                if len(self.memory) > BATCH_SIZE:
+                    experiences = self.memory.sample()
+                    self.learn(experiences, GAMMA)
+        elif switch == 1:
+            self.t_step_final = (self.t_step_final + 1) % UPDATE_EVERY_FINAL
+            if self.t_step_final == 0:
+                if len(self.memory_final) > BATCH_SIZE:
+                    experiences = self.memory_final.sample()
+                    self.learn(experiences, GAMMA)
+        else:
+            # If enough samples are available in memory_agent_can_not_choose, get random subset and learn
+            self.t_step_agent_can_not_choose = (self.t_step_agent_can_not_choose + 1) % UPDATE_EVERY_AGENT_CANT_CHOOSE
+            if self.t_step_agent_can_not_choose == 0:
+                if len(self.memory_agent_can_not_choose) > BATCH_SIZE:
+                    experiences = self.memory_agent_can_not_choose.sample()
+                    self.learn(experiences, GAMMA)
+
+    def step(self, state, action, reward, next_state, done):
+        # Save experience in replay memory
+        self.memory.add(state, action, reward, next_state, done)
+        self._update_model(0)
+
+    def step_agent_can_not_choose(self, state, action, reward, next_state, done):
+        # Save experience in replay memory_agent_can_not_choose
+        self.memory_agent_can_not_choose.add(state, action, reward, next_state, done)
+        self._update_model(2)
+
+    def add_final_step(self, agent_handle, state, action, reward, next_state, done):
+        if self.final_step.get(agent_handle) is None:
+            self.final_step.update({agent_handle: [state, action, reward, next_state, done]})
+            return True
+        else:
+            return False
+
+    def make_final_step(self, additional_reward=0):
+        for _, item in self.final_step.items():
+            state = item[0]
+            action = item[1]
+            reward = item[2] + additional_reward
+            next_state = item[3]
+            done = item[4]
+            self.memory_final.add(state, action, reward, next_state, done)
+            self._update_model(1)
+        self._reset_final_step()
+
+    def _reset_final_step(self):
+        self.final_step = {}
+
+    def act(self, state, eps=0.):
+        """Returns actions for given state as per current policy.
+
+        Params
+        ======
+            state (array_like): current state
+            eps (float): epsilon, for epsilon-greedy action selection
+        """
+        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
+        self.qnetwork_local.eval()
+        with torch.no_grad():
+            action_values = self.qnetwork_local(state)
+        self.qnetwork_local.train()
+
+        # Epsilon-greedy action selection
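+        # (the caller also gets a flag telling whether the action was sampled at random)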
+        if random.random() > eps:
+            return np.argmax(action_values.cpu().data.numpy()), False
+        else:
+            return random.choice(np.arange(self.action_size)), True
+
+    def learn(self, experiences, gamma):
+        """Update value parameters using given batch of experience tuples.
+
+        Params
+        ======
+            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
+            gamma (float): discount factor
+        """
+        states, actions, rewards, next_states, dones = experiences
+
+        # Get expected Q values from local model
+        Q_expected = self.qnetwork_local(states).gather(1, actions)
+
+        if self.double_dqn:
+            # Double DQN: select the greedy action with the local network,
+            # evaluate it with the (detached) target network
+            q_best_action = self.qnetwork_local(next_states).max(1)[1]
+            Q_targets_next = self.qnetwork_target(next_states).detach().gather(1, q_best_action.unsqueeze(-1))
+        else:
+            # DQN
+            Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1)
+
+        # Compute Q targets for current states
+        Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))
+
+        # Compute loss
+        loss = F.mse_loss(Q_expected, Q_targets)
+        # Minimize the loss
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+
+        # ------------------- update target network ------------------- #
+        self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
+
+    def soft_update(self, local_model, target_model, tau):
+        """Soft update model parameters.
+        θ_target = τ*θ_local + (1 - τ)*θ_target
+
+        Params
+        ======
+            local_model (PyTorch model): weights will be copied from
+            target_model (PyTorch model): weights will be copied to
+            tau (float): interpolation parameter
+        """
+        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
+            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
+
+
+class ReplayBuffer:
+    """Fixed-size buffer to store experience tuples."""
+
+    def __init__(self, action_size, buffer_size, batch_size, seed):
+        """Initialize a ReplayBuffer object.
+
+        Params
+        ======
+            action_size (int): dimension of each action
+            buffer_size (int): maximum size of buffer
+            batch_size (int): size of each training batch
+            seed (int): random seed
+        """
+        self.action_size = action_size
+        self.memory = deque(maxlen=buffer_size)
+        self.batch_size = batch_size
+        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
+        self.seed = random.seed(seed)
+
+    def add(self, state, action, reward, next_state, done):
+        """Add a new experience to memory."""
+        e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)
+        self.memory.append(e)
+
+    def sample(self):
+        """Randomly sample a batch of experiences from memory."""
+        experiences = random.sample(self.memory, k=self.batch_size)
+
+        states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \
+            .float().to(device)
+        actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \
+            .long().to(device)
+        rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \
+            .float().to(device)
+        next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \
+            .float().to(device)
+        dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \
+            .float().to(device)
+
+        return (states, actions, rewards, next_states, dones)
+
+    def __len__(self):
+        """Return the current size of internal memory."""
+        return len(self.memory)
+
+    def __v_stack_impr(self, states):
+        sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1
+        np_states = np.reshape(np.array(states), (len(states), sub_dim))
+        return np_states
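+
+
+# Minimal usage sketch (illustrative only; `obs`, `reward`, `next_obs` and `done` come
+# from a Flatland-style training loop and the sizes are example values):
+#
+#   agent = Agent(state_size=179, action_size=5, net_type="FC", seed=0)
+#   action, was_random = agent.act(obs, eps=0.1)
+#   agent.step(obs, action, reward, next_obs, done)
+#   agent.save("./nets/checkpoint")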
diff --git a/src/extra.py b/src/extra.py
index 044c32f757a397ec20387ce8f0f707418848c37e..48a3249377f6ef3ac8b5f6f0d3c756c3130f1ed4 100644
--- a/src/extra.py
+++ b/src/extra.py
@@ -1,234 +1,409 @@
-import numpy as np
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.envs.agent_utils import RailAgentStatus
-from flatland.envs.rail_env import RailEnv
-from flatland.envs.rail_env import RailEnvActions
-
-from src.agent.dueling_double_dqn import Agent
-from src.observations import normalize_observation
-
-state_size = 179
-action_size = 5
-print("state_size: ", state_size)
-print("action_size: ", action_size)
-# Now we load a Double dueling DQN agent
-global_rl_agent = Agent(state_size, action_size, "FC", 0)
-global_rl_agent.load('./nets/training_best_0.626_agents_5276.pth')
-
-
-class Extra:
-    global_rl_agent = None
-
-    def __init__(self, env: RailEnv):
-        self.env = env
-        self.rl_agent = global_rl_agent
-        self.switches = {}
-        self.switches_neighbours = {}
-        self.find_all_cell_where_agent_can_choose()
-        self.steps_counter = 0
-
-        self.debug_render_list = []
-        self.debug_render_path_list = []
-
-    def rl_agent_act(self, observation, max_depth, eps=0.0):
-
-        self.steps_counter += 1
-        print(self.steps_counter, self.env.get_num_agents())
-
-        agent_obs = [None] * self.env.get_num_agents()
-        for a in range(self.env.get_num_agents()):
-            if observation[a]:
-                agent_obs[a] = self.generate_state(a, observation, max_depth)
-
-        action_dict = {}
-        # estimate whether the agent(s) can freely choose an action
-        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
-            self.required_agent_descision()
-
-        for a in range(self.env.get_num_agents()):
-            if agent_obs[a] is not None:
-                if agents_can_choose[a]:
-                    act, agent_rnd = self.rl_agent.act(agent_obs[a], eps=eps)
-
-                    l = len(agent_obs[a])
-                    if agent_obs[a][l - 3] > 0 and agents_near_to_switch_all[a]:
-                        act = RailEnvActions.STOP_MOVING
-
-                    action_dict.update({a: act})
-                else:
-                    act = RailEnvActions.MOVE_FORWARD
-                    action_dict.update({a: act})
-            else:
-                action_dict.update({a: RailEnvActions.DO_NOTHING})
-        return action_dict
-
-    def find_all_cell_where_agent_can_choose(self):
-        switches = {}
-        for h in range(self.env.height):
-            for w in range(self.env.width):
-                pos = (h, w)
-                for dir in range(4):
-                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
-                    num_transitions = np.count_nonzero(possible_transitions)
-                    if num_transitions > 1:
-                        if pos not in switches.keys():
-                            switches.update({pos: [dir]})
-                        else:
-                            switches[pos].append(dir)
-
-        switches_neighbours = {}
-        for h in range(self.env.height):
-            for w in range(self.env.width):
-                # look one step forward
-                for dir in range(4):
-                    pos = (h, w)
-                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
-                    for d in range(4):
-                        if possible_transitions[d] == 1:
-                            new_cell = get_new_position(pos, d)
-                            if new_cell in switches.keys() and pos not in switches.keys():
-                                if pos not in switches_neighbours.keys():
-                                    switches_neighbours.update({pos: [dir]})
-                                else:
-                                    switches_neighbours[pos].append(dir)
-
-        self.switches = switches
-        self.switches_neighbours = switches_neighbours
-
-    def check_agent_descision(self, position, direction, switches, switches_neighbours):
-        agents_on_switch = False
-        agents_near_to_switch = False
-        agents_near_to_switch_all = False
-        if position in switches.keys():
-            agents_on_switch = direction in switches[position]
-
-        if position in switches_neighbours.keys():
-            new_cell = get_new_position(position, direction)
-            if new_cell in switches.keys():
-                if not direction in switches[new_cell]:
-                    agents_near_to_switch = direction in switches_neighbours[position]
-            else:
-                agents_near_to_switch = direction in switches_neighbours[position]
-
-            agents_near_to_switch_all = direction in switches_neighbours[position]
-
-        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all
-
-    def required_agent_descision(self):
-        agents_can_choose = {}
-        agents_on_switch = {}
-        agents_near_to_switch = {}
-        agents_near_to_switch_all = {}
-        for a in range(self.env.get_num_agents()):
-            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all = \
-                self.check_agent_descision(
-                    self.env.agents[a].position,
-                    self.env.agents[a].direction,
-                    self.switches,
-                    self.switches_neighbours)
-            agents_on_switch.update({a: ret_agents_on_switch})
-            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
-            agents_near_to_switch.update({a: (ret_agents_near_to_switch or ready_to_depart)})
-
-            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
-
-            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all or ready_to_depart)})
-
-        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all
-
-    def check_deadlock(self, only_next_cell_check=False, handle=None):
-        agents_with_deadlock = []
-        agents = range(self.env.get_num_agents())
-        if handle is not None:
-            agents = [handle]
-        for a in agents:
-            if self.env.agents[a].status < RailAgentStatus.DONE:
-                position = self.env.agents[a].position
-                first_step = True
-                if position is None:
-                    position = self.env.agents[a].initial_position
-                    first_step = True
-                direction = self.env.agents[a].direction
-                cnt = 0
-                while position is not None:  # and position != self.env.agents[a].target:
-                    possible_transitions = self.env.rail.get_transitions(*position, direction)
-                    # num_transitions = np.count_nonzero(possible_transitions)
-                    agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = self.check_agent_descision(
-                        position,
-                        direction,
-                        self.switches,
-                        self.switches_neighbours)
-
-                    if not agents_on_switch or first_step:
-                        first_step = False
-                        new_direction_me = np.argmax(possible_transitions)
-                        new_cell_me = get_new_position(position, new_direction_me)
-                        opp_agent = self.env.agent_positions[new_cell_me]
-                        if opp_agent != -1:
-                            opp_position = self.env.agents[opp_agent].position
-                            opp_direction = self.env.agents[opp_agent].direction
-                            opp_agents_on_switch, opp_agents_near_to_switch, agents_near_to_switch_all = \
-                                self.check_agent_descision(opp_position,
-                                                           opp_direction,
-                                                           self.switches,
-                                                           self.switches_neighbours)
-
-                            # opp_possible_transitions = self.env.rail.get_transitions(*opp_position, opp_direction)
-                            # opp_num_transitions = np.count_nonzero(opp_possible_transitions)
-                            if not opp_agents_on_switch:
-                                if opp_direction != direction:
-                                    agents_with_deadlock.append(a)
-                                    position = None
-                                else:
-                                    if only_next_cell_check:
-                                        position = None
-                                    else:
-                                        position = new_cell_me
-                                        direction = new_direction_me
-                            else:
-                                if only_next_cell_check:
-                                    position = None
-                                else:
-                                    position = new_cell_me
-                                    direction = new_direction_me
-                        else:
-                            if only_next_cell_check:
-                                position = None
-                            else:
-                                position = new_cell_me
-                                direction = new_direction_me
-                    else:
-                        position = None
-
-                cnt += 1
-                if cnt > 100:
-                    position = None
-
-        return agents_with_deadlock
-
-    def generate_state(self, handle: int, root, max_depth: int):
-        n_obs = normalize_observation(root[handle], max_depth)
-
-        position = self.env.agents[handle].position
-        direction = self.env.agents[handle].direction
-        cell_free_4_first_step = -1
-        deadlock_agents = []
-        if self.env.agents[handle].status == RailAgentStatus.READY_TO_DEPART:
-            if self.env.agent_positions[self.env.agents[handle].initial_position] == -1:
-                cell_free_4_first_step = 1
-            position = self.env.agents[handle].initial_position
-        else:
-            deadlock_agents = self.check_deadlock(only_next_cell_check=False, handle=handle)
-        agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = self.check_agent_descision(position,
-                                                                                                        direction,
-                                                                                                        self.switches,
-                                                                                                        self.switches_neighbours)
-
-        append_obs = [self.env.agents[handle].status - RailAgentStatus.ACTIVE,
-                      cell_free_4_first_step,
-                      2.0 * int(len(deadlock_agents)) - 1.0,
-                      2.0 * int(agents_on_switch) - 1.0,
-                      2.0 * int(agents_near_to_switch) - 1.0]
-        n_obs = np.append(n_obs, append_obs)
-
-        return n_obs
+import numpy as np
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.agent_utils import RailAgentStatus
+# Adrian Egli's performance fix (the fast_* helpers below bring a speed-up of more than 50%)
+from flatland.envs.rail_env import RailEnvActions
+
+from src.ppo.agent import Agent
+
+
+def fast_isclose(a, b, rtol):
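+    # NOTE: one-sided tolerance check; for rtol > 0 this is equivalent to a < b + rtol.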
+    return (a < (b + rtol)) or (a < (b - rtol))
+
+
+def fast_clip(position: (int, int), min_value: (int, int), max_value: (int, int)) -> (int, int):
+    return (
+        max(min_value[0], min(position[0], max_value[0])),
+        max(min_value[1], min(position[1], max_value[1]))
+    )
+
+
+def fast_argmax(possible_transitions: (int, int, int, int)) -> int:
+    if possible_transitions[0] == 1:
+        return 0
+    if possible_transitions[1] == 1:
+        return 1
+    if possible_transitions[2] == 1:
+        return 2
+    return 3
+
+
+def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool:
+    return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1]
+
+
+def fast_count_nonzero(possible_transitions: (int, int, int, int)):
+    return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3]
+
+
+class Extra(ObservationBuilder):
+
+    def __init__(self, max_depth):
+        super().__init__()
+        self.max_depth = max_depth
+        self.observation_dim = 22
+        self.agent = None
+
+    def loadAgent(self):
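+        # Lazily create the PPO agent (only once), sized from the observation dimension
+        # exposed by the env's observation builder, and load its weights from ./checkpoints/.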
+        if self.agent is not None:
+            return
+        self.state_size = self.env.obs_builder.observation_dim
+        self.action_size = 5
+        print("action_size: ", self.action_size)
+        print("state_size: ", self.state_size)
+        self.agent = Agent(self.state_size, self.action_size, 0)
+        self.agent.load('./checkpoints/', 0, 1.0)
+
+    def build_data(self):
+        if self.env is not None:
+            self.env.dev_obs_dict = {}
+        self.switches = {}
+        self.switches_neighbours = {}
+        self.debug_render_list = []
+        self.debug_render_path_list = []
+        if self.env is not None:
+            self.find_all_cell_where_agent_can_choose()
+
+    def find_all_cell_where_agent_can_choose(self):
+
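+        # Pre-compute two lookup tables over the whole grid:
+        #   switches[pos]            -> directions in which pos offers more than one transition (a decision cell)
+        #   switches_neighbours[pos] -> directions from which a switch is reachable in a single step
+        #                               (for cells that are not switches themselves)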
+        switches = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                pos = (h, w)
+                for dir in range(4):
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    num_transitions = fast_count_nonzero(possible_transitions)
+                    if num_transitions > 1:
+                        if pos not in switches.keys():
+                            switches.update({pos: [dir]})
+                        else:
+                            switches[pos].append(dir)
+
+        switches_neighbours = {}
+        for h in range(self.env.height):
+            for w in range(self.env.width):
+                # look one step forward
+                for dir in range(4):
+                    pos = (h, w)
+                    possible_transitions = self.env.rail.get_transitions(*pos, dir)
+                    for d in range(4):
+                        if possible_transitions[d] == 1:
+                            new_cell = get_new_position(pos, d)
+                            if new_cell in switches.keys() and pos not in switches.keys():
+                                if pos not in switches_neighbours.keys():
+                                    switches_neighbours.update({pos: [dir]})
+                                else:
+                                    switches_neighbours[pos].append(dir)
+
+        self.switches = switches
+        self.switches_neighbours = switches_neighbours
+
+    def check_agent_descision(self, position, direction):
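+        # Derive three (approximate) flags from the pre-computed switch maps:
+        #   agents_on_switch          -> the cell itself is a decision switch for this direction
+        #   agents_near_to_switch     -> a switch is one step ahead and the agent has no own choice on it
+        #   agents_near_to_switch_all -> a switch is one step ahead, regardless of who can choose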
+        switches = self.switches
+        switches_neighbours = self.switches_neighbours
+        agents_on_switch = False
+        agents_near_to_switch = False
+        agents_near_to_switch_all = False
+        if position in switches.keys():
+            agents_on_switch = direction in switches[position]
+
+        if position in switches_neighbours.keys():
+            new_cell = get_new_position(position, direction)
+            if new_cell in switches.keys():
+                if not direction in switches[new_cell]:
+                    agents_near_to_switch = direction in switches_neighbours[position]
+            else:
+                agents_near_to_switch = direction in switches_neighbours[position]
+
+            agents_near_to_switch_all = direction in switches_neighbours[position]
+
+        return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all
+
+    def required_agent_descision(self):
+        agents_can_choose = {}
+        agents_on_switch = {}
+        agents_near_to_switch = {}
+        agents_near_to_switch_all = {}
+        for a in range(self.env.get_num_agents()):
+            ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all = \
+                self.check_agent_descision(
+                    self.env.agents[a].position,
+                    self.env.agents[a].direction)
+            agents_on_switch.update({a: ret_agents_on_switch})
+            ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART
+            agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)})
+
+            agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]})
+
+            agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)})
+
+        return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all
+
+    def debug_render(self, env_renderer):
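+        # Overlay four debug layers in the renderer: the cells where agent 0 had a decision
+        # to make, all switch cells, all switch-neighbour cells, and agent 0's travelled path.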
+        agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
+            self.required_agent_descision()
+        self.env.dev_obs_dict = {}
+        for a in range(max(3, self.env.get_num_agents())):
+            self.env.dev_obs_dict.update({a: []})
+
+        selected_agent = None
+        if agents_can_choose[0]:
+            if self.env.agents[0].position is not None:
+                self.debug_render_list.append(self.env.agents[0].position)
+            else:
+                self.debug_render_list.append(self.env.agents[0].initial_position)
+
+        if self.env.agents[0].position is not None:
+            self.debug_render_path_list.append(self.env.agents[0].position)
+        else:
+            self.debug_render_path_list.append(self.env.agents[0].initial_position)
+
+        env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000")
+        env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600")
+        env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666")
+        env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000")
+
+        self.env.dev_obs_dict[0] = self.debug_render_list
+        self.env.dev_obs_dict[1] = self.switches.keys()
+        self.env.dev_obs_dict[2] = self.switches_neighbours.keys()
+        self.env.dev_obs_dict[3] = self.debug_render_path_list
+
+    def normalize_observation(self, obsData):
+        return obsData
+
+    def check_deadlock(self, only_next_cell_check=True, handle=None):
+        agents_with_deadlock = []
+        agents = range(self.env.get_num_agents())
+        if handle is not None:
+            agents = [handle]
+        for a in agents:
+            if self.env.agents[a].status < RailAgentStatus.DONE:
+                position = self.env.agents[a].position
+                first_step = True
+                if position is None:
+                    position = self.env.agents[a].initial_position
+                    first_step = True
+                direction = self.env.agents[a].direction
+                cnt = 0
+                while position is not None:  # and position != self.env.agents[a].target:
+                    possible_transitions = self.env.rail.get_transitions(*position, direction)
+                    # num_transitions = np.count_nonzero(possible_transitions)
+                    agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \
+                        self.check_agent_descision(position, direction)
+
+                    if not agents_on_switch or first_step:
+                        first_step = False
+                        new_direction_me = np.argmax(possible_transitions)
+                        new_cell_me = get_new_position(position, new_direction_me)
+                        opp_agent = self.env.agent_positions[new_cell_me]
+                        if opp_agent != -1:
+                            opp_position = self.env.agents[opp_agent].position
+                            opp_direction = self.env.agents[opp_agent].direction
+                            opp_agents_on_switch, opp_agents_near_to_switch, agents_near_to_switch_all = \
+                                self.check_agent_descision(opp_position, opp_direction)
+
+                            # opp_possible_transitions = self.env.rail.get_transitions(*opp_position, opp_direction)
+                            # opp_num_transitions = np.count_nonzero(opp_possible_transitions)
+                            if not opp_agents_on_switch:
+                                if opp_direction != direction:
+                                    agents_with_deadlock.append(a)
+                                    position = None
+                                else:
+                                    if only_next_cell_check:
+                                        position = None
+                                    else:
+                                        position = new_cell_me
+                                        direction = new_direction_me
+                            else:
+                                if only_next_cell_check:
+                                    position = None
+                                else:
+                                    position = new_cell_me
+                                    direction = new_direction_me
+                        else:
+                            if only_next_cell_check:
+                                position = None
+                            else:
+                                position = new_cell_me
+                                direction = new_direction_me
+                    else:
+                        position = None
+
+        return agents_with_deadlock
+
+    def is_collision(self, obsData):
+        if obsData[4] == 1:
+            # Agent is READY_TO_DEPART
+            return False
+        if obsData[6] == 1:
+            # Agent is DONE / DONE_REMOVED
+            return False
+
+        same_dir = obsData[18] + obsData[19] + obsData[20] + obsData[21]
+        if same_dir > 0:
+            # The agent detects another agent moving in the same direction and every cell
+            # between them is unoccupied (it is simply following that agent).
+            return False
+        freedom = obsData[10] + obsData[11] + obsData[12] + obsData[13]
+        blocked = obsData[14] + obsData[15] + obsData[16] + obsData[17]
+        # If the agent has no more free branches than branches blocked by agents travelling
+        # towards it (opposite direction), it cannot avoid them -> this can cause a deadlock (locally tested).
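+        # Example: freedom == 1 and blocked == 1 means the only open branch carries oncoming
+        # traffic -> a collision is reported; freedom == 2 and blocked == 1 leaves an escape
+        # route -> no collision is reported.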
+        return freedom <= blocked and freedom > 0
+
+    def reset(self):
+        self.build_data()
+        return
+
+    def fast_argmax(self, array):
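+        # Return the index of the first set entry of a 4-element (one-hot) transition vector;
+        # falls back to direction 3 if no entry is set.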
+        if array[0] == 1:
+            return 0
+        if array[1] == 1:
+            return 1
+        if array[2] == 1:
+            return 2
+        return 3
+
+    def _explore(self, handle, new_position, new_direction, depth=0):
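+        # Follow the track from new_position / new_direction until another agent, a decision
+        # cell or max_depth is reached. Returns soft indicators in [0, 1] for "opposing agent
+        # ahead" and "same-direction agent ahead" plus the visited cells; at switches the open
+        # branches are explored recursively and averaged with weight 0.5.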
+
+        has_opp_agent = 0
+        has_same_agent = 0
+        visited = []
+
+        # stop exploring (max_depth reached)
+        if depth >= self.max_depth:
+            return has_opp_agent, has_same_agent, visited
+
+        # max_explore_steps = 100
+        cnt = 0
+        while cnt < 100:
+            cnt += 1
+
+            visited.append(new_position)
+            opp_a = self.env.agent_positions[new_position]
+            if opp_a != -1 and opp_a != handle:
+                if self.env.agents[opp_a].direction != new_direction:
+                    # opp agent found
+                    has_opp_agent = 1
+                    return has_opp_agent, has_same_agent, visited
+                else:
+                    has_same_agent = 1
+                    return has_opp_agent, has_same_agent, visited
+
+            # convert one-hot encoding to 0,1,2,3
+            possible_transitions = self.env.rail.get_transitions(*new_position, new_direction)
+            agents_on_switch, \
+            agents_near_to_switch, \
+            agents_near_to_switch_all = \
+                self.check_agent_descision(new_position, new_direction)
+            if agents_near_to_switch:
+                return has_opp_agent, has_same_agent, visited
+
+            if agents_on_switch:
+                for dir_loop in range(4):
+                    if possible_transitions[dir_loop] == 1:
+                        # recurse into each open branch, starting from the neighbouring cell
+                        # in direction dir_loop (instead of re-exploring the current cell)
+                        hoa, hsa, v = self._explore(handle,
+                                                    get_new_position(new_position, dir_loop),
+                                                    dir_loop,
+                                                    depth + 1)
+                        visited.append(v)
+                        has_opp_agent = 0.5 * (has_opp_agent + hoa)
+                        has_same_agent = 0.5 * (has_same_agent + hsa)
+                return has_opp_agent, has_same_agent, visited
+            else:
+                new_direction = self.fast_argmax(possible_transitions)
+                new_position = get_new_position(new_position, new_direction)
+        return has_opp_agent, has_same_agent, visited
+
+    def get(self, handle):
+        # all values are [0,1]
+        # observation[0]  : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path
+        # observation[1]  : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path
+        # observation[2]  : 1 path towards target (direction 2) / otherwise 0 -> path is longer or there is no path
+        # observation[3]  : 1 path towards target (direction 3) / otherwise 0 -> path is longer or there is no path
+        # observation[4]  : int(agent.status == RailAgentStatus.READY_TO_DEPART)
+        # observation[5]  : int(agent.status == RailAgentStatus.ACTIVE)
+        # observation[6]  : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED)
+        # observation[7]  : current agent is located at a switch, where it can take a routing decision
+        # observation[8]  : current agent is located at a cell, where it has to take a stop-or-go decision
+        # observation[9]  : current agent is located one step before/after a switch
+        # observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0)
+        # observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1)
+        # observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2)
+        # observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3)
+        # observation[14] : If there is a path with step (direction 0) and there is an agent with opposite direction -> 1
+        # observation[15] : If there is a path with step (direction 1) and there is an agent with opposite direction -> 1
+        # observation[16] : If there is a path with step (direction 2) and there is an agent with opposite direction -> 1
+        # observation[17] : If there is a path with step (direction 3) and there is an agent with opposite direction -> 1
+        # observation[18] : If there is a path with step (direction 0) and there is an agent with same direction -> 1
+        # observation[19] : If there is a path with step (direction 1) and there is an agent with same direction -> 1
+        # observation[20] : If there is a path with step (direction 2) and there is an agent with same direction -> 1
+        # observation[21] : If there is a path with step (direction 3) and there is an agent with same direction -> 1
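+        # -> 22 values in total (observation[0] .. observation[21]); observation_dim is expected to match.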
+
+        observation = np.zeros(self.observation_dim)
+        visited = []
+        agent = self.env.agents[handle]
+
+        agent_done = False
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+            observation[4] = 1
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+            observation[5] = 1
+        else:
+            observation[6] = 1
+            agent_virtual_position = (-1, -1)
+            agent_done = True
+
+        if not agent_done:
+            visited.append(agent_virtual_position)
+            distance_map = self.env.distance_map.get()
+            current_cell_dist = distance_map[handle,
+                                             agent_virtual_position[0], agent_virtual_position[1],
+                                             agent.direction]
+            possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
+            orientation = agent.direction
+            if fast_count_nonzero(possible_transitions) == 1:
+                orientation = np.argmax(possible_transitions)
+
+            for dir_loop, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
+                if possible_transitions[branch_direction]:
+                    new_position = get_new_position(agent_virtual_position, branch_direction)
+                    new_cell_dist = distance_map[handle,
+                                                 new_position[0], new_position[1],
+                                                 branch_direction]
+                    if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)):
+                        observation[dir_loop] = int(new_cell_dist < current_cell_dist)
+
+                    has_opp_agent, has_same_agent, v = self._explore(handle, new_position, branch_direction)
+                    visited.append(v)
+
+                    observation[10 + dir_loop] = 1
+                    observation[14 + dir_loop] = has_opp_agent
+                    observation[18 + dir_loop] = has_same_agent
+
+        agents_on_switch, \
+        agents_near_to_switch, \
+        agents_near_to_switch_all = \
+            self.check_agent_descision(agent_virtual_position, agent.direction)
+        observation[7] = int(agents_on_switch)
+        observation[8] = int(agents_near_to_switch)
+        observation[9] = int(agents_near_to_switch_all)
+
+        self.env.dev_obs_dict.update({handle: visited})
+
+        return observation
+
+    def rl_agent_act(self, observation, info, eps=0.0):
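+        # Epsilon-greedy action selection with the policy loaded via loadAgent(); agents that
+        # do not require an action this step are assigned DO_NOTHING.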
+        self.loadAgent()
+        action_dict = {}
+        for a in range(self.env.get_num_agents()):
+            if info['action_required'][a]:
+                action_dict[a] = self.agent.act(observation[a], eps=eps)
+                # action_dict[a] = np.random.randint(5)
+            else:
+                action_dict[a] = RailEnvActions.DO_NOTHING
+
+        return action_dict
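+
+# Usage sketch (illustrative only; assumes a flatland RailEnv named `env` built with this
+# observation builder and an instance of this class named `obs_builder` -- these names are
+# not taken from this repository):
+#
+#     obs, info = env.reset()
+#     action_dict = obs_builder.rl_agent_act(obs, info, eps=0.05)
+#     obs, all_rewards, done, info = env.step(action_dict)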
diff --git a/src/observations.py b/src/observations.py
index fd9659cd41d72e8c92d2a96ac40a961524fb27ce..ce47e508b21d7f0cf9a9a88ee864d3da18184cf1 100644
--- a/src/observations.py
+++ b/src/observations.py
@@ -1,736 +1,736 @@
-"""
-Collection of environment-specific ObservationBuilder.
-"""
-import collections
-from typing import Optional, List, Dict, Tuple
-
-import numpy as np
-from flatland.core.env import Environment
-from flatland.core.env_observation_builder import ObservationBuilder
-from flatland.core.env_prediction_builder import PredictionBuilder
-from flatland.core.grid.grid4_utils import get_new_position
-from flatland.core.grid.grid_utils import coordinate_to_position
-from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
-from flatland.utils.ordered_set import OrderedSet
-
-
-class MyTreeObsForRailEnv(ObservationBuilder):
-    """
-    TreeObsForRailEnv object.
-
-    This object returns observation vectors for agents in the RailEnv environment.
-    The information is local to each agent and exploits the graph structure of the rail
-    network to simplify the representation of the state of the environment for each agent.
-
-    For details about the features in the tree observation see the get() function.
-    """
-    Node = collections.namedtuple('Node', 'dist_min_to_target '
-                                          'target_encountered '
-                                          'num_agents_same_direction '
-                                          'num_agents_opposite_direction '
-                                          'childs')
-
-    tree_explored_actions_char = ['L', 'F', 'R', 'B']
-
-    def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
-        super().__init__()
-        self.max_depth = max_depth
-        self.observation_dim = 2
-        self.location_has_agent = {}
-        self.predictor = predictor
-        self.location_has_target = None
-
-        self.switches_list = {}
-        self.switches_neighbours_list = []
-        self.check_agent_descision = None
-
-    def reset(self):
-        self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
-
-    def set_switch_and_pre_switch(self, switch_list, pre_switch_list, check_agent_descision):
-        self.switches_list = switch_list
-        self.switches_neighbours_list = pre_switch_list
-        self.check_agent_descision = check_agent_descision
-
-    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:
-        """
-        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
-        in the `handles` list.
-        """
-
-        if handles is None:
-            handles = []
-        if self.predictor:
-            self.max_prediction_depth = 0
-            self.predicted_pos = {}
-            self.predicted_dir = {}
-            self.predictions = self.predictor.get()
-            if self.predictions:
-                for t in range(self.predictor.max_depth + 1):
-                    pos_list = []
-                    dir_list = []
-                    for a in handles:
-                        if self.predictions[a] is None:
-                            continue
-                        pos_list.append(self.predictions[a][t][1:3])
-                        dir_list.append(self.predictions[a][t][3])
-                    self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
-                    self.predicted_dir.update({t: dir_list})
-                self.max_prediction_depth = len(self.predicted_pos)
-        # Update local lookup table for all agents' positions
-        # ignore other agents not in the grid (only status active and done)
-
-        self.location_has_agent = {}
-        self.location_has_agent_direction = {}
-        self.location_has_agent_speed = {}
-        self.location_has_agent_malfunction = {}
-        self.location_has_agent_ready_to_depart = {}
-
-        for _agent in self.env.agents:
-            if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
-                    _agent.position:
-                self.location_has_agent[tuple(_agent.position)] = 1
-                self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
-                self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
-                self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
-                    'malfunction']
-
-            if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
-                    _agent.initial_position:
-                self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
-                    self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
-
-        observations = super().get_many(handles)
-
-        return observations
-
-    def get(self, handle: int = 0) -> Node:
-        """
-        Computes the current observation for agent `handle` in env
-
-        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
-        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
-        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
-        the transitions. The order is::
-
-            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
-
-        Each branch data is organized as::
-
-            [root node information] +
-            [recursive branch data from 'left'] +
-            [... from 'forward'] +
-            [... from 'right] +
-            [... from 'back']
-
-        Each node information is composed of 9 features:
-
-        #1:
-            if own target lies on the explored branch the current distance from the agent in number of cells is stored.
-
-        #2:
-            if another agents target is detected the distance in number of cells from the agents current location\
-            is stored
-
-        #3:
-            if another agent is detected the distance in number of cells from current agent position is stored.
-
-        #4:
-            possible conflict detected
-            tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the \
-             distance in number of cells from current agent position
-
-            0 = No other agent reserve the same cell at similar time
-
-        #5:
-            if an not usable switch (for agent) is detected we store the distance.
-
-        #6:
-            This feature stores the distance in number of cells to the next branching  (current node)
-
-        #7:
-            minimum distance from node to the agent's target given the direction of the agent if this path is chosen
-
-        #8:
-            agent in the same direction
-            n = number of agents present same direction \
-                (possible future use: number of other agents in the same direction in this branch)
-            0 = no agent present same direction
-
-        #9:
-            agent in the opposite direction
-            n = number of agents present other direction than myself (so conflict) \
-                (possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
-            0 = no agent present other direction than myself
-
-        #10:
-            malfunctioning/blokcing agents
-            n = number of time steps the oberved agent remains blocked
-
-        #11:
-            slowest observed speed of an agent in same direction
-            1 if no agent is observed
-
-            min_fractional speed otherwise
-        #12:
-            number of agents ready to depart but no yet active
-
-        Missing/padding nodes are filled in with -inf (truncated).
-        Missing values in present node are filled in with +inf (truncated).
-
-
-        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target, own malfunction, own speed]
-        In case the target node is reached, the values are [0, 0, 0, 0, 0].
-        """
-
-        if handle > len(self.env.agents):
-            print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
-        agent = self.env.agents[handle]  # TODO: handle being treated as index
-
-        if agent.status == RailAgentStatus.READY_TO_DEPART:
-            agent_virtual_position = agent.initial_position
-        elif agent.status == RailAgentStatus.ACTIVE:
-            agent_virtual_position = agent.position
-        elif agent.status == RailAgentStatus.DONE:
-            agent_virtual_position = agent.target
-        else:
-            return None
-
-        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
-        num_transitions = np.count_nonzero(possible_transitions)
-
-        # Here information about the agent itself is stored
-        distance_map = self.env.distance_map.get()
-
-        root_node_observation = MyTreeObsForRailEnv.Node(dist_min_to_target=distance_map[
-            (handle, *agent_virtual_position,
-             agent.direction)],
-                                                         target_encountered=0,
-                                                         num_agents_same_direction=0,
-                                                         num_agents_opposite_direction=0,
-                                                         childs={})
-
-        visited = OrderedSet()
-
-        # Start from the current orientation, and see which transitions are available;
-        # organize them as [left, forward, right, back], relative to the current orientation
-        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
-        orientation = agent.direction
-
-        if num_transitions == 1:
-            orientation = np.argmax(possible_transitions)
-
-        for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
-            if possible_transitions[branch_direction]:
-                new_cell = get_new_position(agent_virtual_position, branch_direction)
-
-                branch_observation, branch_visited = \
-                    self._explore_branch(handle, new_cell, branch_direction, 1, 1)
-                root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation
-
-                visited |= branch_visited
-            else:
-                # add cells filled with infinity if no transition is possible
-                root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf
-        self.env.dev_obs_dict[handle] = visited
-
-        return root_node_observation
-
-    def _explore_branch(self, handle, position, direction, tot_dist, depth):
-        """
-        Utility function to compute tree-based observations.
-        We walk along the branch and collect the information documented in the get() function.
-        If there is a branching point a new node is created and each possible branch is explored.
-        """
-
-        # [Recursive branch opened]
-        if depth >= self.max_depth + 1:
-            return [], []
-
-        # Continue along direction until next switch or
-        # until no transitions are possible along the current direction (i.e., dead-ends)
-        # We treat dead-ends as nodes, instead of going back, to avoid loops
-        exploring = True
-
-        visited = OrderedSet()
-        agent = self.env.agents[handle]
-
-        other_agent_opposite_direction = 0
-        other_agent_same_direction = 0
-
-        dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
-
-        last_is_dead_end = False
-        last_is_a_decision_cell = False
-        target_encountered = 0
-
-        cnt = 0
-        while exploring:
-
-            dist_min_to_target = min(dist_min_to_target, self.env.distance_map.get()[handle, position[0], position[1],
-                                                                                     direction])
-
-            if agent.target == position:
-                target_encountered = 1
-
-            new_direction_me = direction
-            new_cell_me = position
-            a = self.env.agent_positions[new_cell_me]
-            if a != -1 and a != handle:
-                opp_agent = self.env.agents[a]
-                # look one step forward
-                # opp_possible_transitions = self.env.rail.get_transitions(*opp_agent.position, opp_agent.direction)
-                if opp_agent.direction != new_direction_me:  # opp_possible_transitions[new_direction_me] == 0:
-                    other_agent_opposite_direction += 1
-                else:
-                    other_agent_same_direction += 1
-
-            # #############################
-            # #############################
-            if (position[0], position[1], direction) in visited:
-                break
-            visited.add((position[0], position[1], direction))
-
-            # If the target node is encountered, pick that as node. Also, no further branching is possible.
-            if np.array_equal(position, self.env.agents[handle].target):
-                last_is_target = True
-                break
-
-            exploring = False
-
-            # Check number of possible transitions for agent and total number of transitions in cell (type)
-            possible_transitions = self.env.rail.get_transitions(*position, direction)
-            num_transitions = np.count_nonzero(possible_transitions)
-            # cell_transitions = self.env.rail.get_transitions(*position, direction)
-            transition_bit = bin(self.env.rail.get_full_transitions(*position))
-            total_transitions = transition_bit.count("1")
-
-            if num_transitions == 1:
-                # Check if dead-end, or if we can go forward along direction
-                nbits = total_transitions
-                if nbits == 1:
-                    # Dead-end!
-                    last_is_dead_end = True
-
-            if self.check_agent_descision is not None:
-                ret_agents_on_switch, ret_agents_near_to_switch, agents_near_to_switch_all = \
-                    self.check_agent_descision(position,
-                                               direction,
-                                               self.switches_list,
-                                               self.switches_neighbours_list)
-                if ret_agents_on_switch:
-                    last_is_a_decision_cell = True
-                    break
-
-            exploring = True
-            # convert one-hot encoding to 0,1,2,3
-            cell_transitions = self.env.rail.get_transitions(*position, direction)
-            direction = np.argmax(cell_transitions)
-            position = get_new_position(position, direction)
-
-            cnt += 1
-            if cnt > 1000:
-                exploring = False
-
-        # #############################
-        # #############################
-        # Modify here to append new / different features for each visited cell!
-
-        node = MyTreeObsForRailEnv.Node(dist_min_to_target=dist_min_to_target,
-                                        target_encountered=target_encountered,
-                                        num_agents_opposite_direction=other_agent_opposite_direction,
-                                        num_agents_same_direction=other_agent_same_direction,
-                                        childs={})
-
-        # #############################
-        # #############################
-        # Start from the current orientation, and see which transitions are available;
-        # organize them as [left, forward, right, back], relative to the current orientation
-        # Get the possible transitions
-        possible_transitions = self.env.rail.get_transitions(*position, direction)
-
-        for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
-            if last_is_dead_end and self.env.rail.get_transition((*position, direction),
-                                                                 (branch_direction + 2) % 4):
-                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
-                # it back
-                new_cell = get_new_position(position, (branch_direction + 2) % 4)
-                branch_observation, branch_visited = self._explore_branch(handle,
-                                                                          new_cell,
-                                                                          (branch_direction + 2) % 4,
-                                                                          tot_dist + 1,
-                                                                          depth + 1)
-                node.childs[self.tree_explored_actions_char[i]] = branch_observation
-                if len(branch_visited) != 0:
-                    visited |= branch_visited
-            elif last_is_a_decision_cell and possible_transitions[branch_direction]:
-                new_cell = get_new_position(position, branch_direction)
-                branch_observation, branch_visited = self._explore_branch(handle,
-                                                                          new_cell,
-                                                                          branch_direction,
-                                                                          tot_dist + 1,
-                                                                          depth + 1)
-                node.childs[self.tree_explored_actions_char[i]] = branch_observation
-                if len(branch_visited) != 0:
-                    visited |= branch_visited
-            else:
-                # no exploring possible, add just cells with infinity
-                node.childs[self.tree_explored_actions_char[i]] = -np.inf
-
-        if depth == self.max_depth:
-            node.childs.clear()
-        return node, visited
-
-    def util_print_obs_subtree(self, tree: Node):
-        """
-        Utility function to print tree observations returned by this object.
-        """
-        self.print_node_features(tree, "root", "")
-        for direction in self.tree_explored_actions_char:
-            self.print_subtree(tree.childs[direction], direction, "\t")
-
-    @staticmethod
-    def print_node_features(node: Node, label, indent):
-        print(indent, "Direction ", label, ": ", node.num_agents_same_direction,
-              ", ", node.num_agents_opposite_direction)
-
-    def print_subtree(self, node, label, indent):
-        if node == -np.inf or not node:
-            print(indent, "Direction ", label, ": -np.inf")
-            return
-
-        self.print_node_features(node, label, indent)
-
-        if not node.childs:
-            return
-
-        for direction in self.tree_explored_actions_char:
-            self.print_subtree(node.childs[direction], direction, indent + "\t")
-
-    def set_env(self, env: Environment):
-        super().set_env(env)
-        if self.predictor:
-            self.predictor.set_env(self.env)
-
-    def _reverse_dir(self, direction):
-        return int((direction + 2) % 4)
-
-
-class GlobalObsForRailEnv(ObservationBuilder):
-    """
-    Gives a global observation of the entire rail environment.
-    The observation is composed of the following elements:
-
-        - transition map array with dimensions (env.height, env.width, 16),\
-          assuming 16 bits encoding of transitions.
-
-        - obs_agents_state: A 3D array (map_height, map_width, 5) with
-            - first channel containing the agents position and direction
-            - second channel containing the other agents positions and direction
-            - third channel containing agent/other agent malfunctions
-            - fourth channel containing agent/other agent fractional speeds
-            - fifth channel containing number of other agents ready to depart
-
-        - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\
-         target and the positions of the other agents targets (flag only, no counter!).
-    """
-
-    def __init__(self):
-        super(GlobalObsForRailEnv, self).__init__()
-
-    def set_env(self, env: Environment):
-        super().set_env(env)
-
-    def reset(self):
-        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
-        for i in range(self.rail_obs.shape[0]):
-            for j in range(self.rail_obs.shape[1]):
-                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
-                bitlist = [0] * (16 - len(bitlist)) + bitlist
-                self.rail_obs[i, j] = np.array(bitlist)
-
-    def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
-
-        agent = self.env.agents[handle]
-        if agent.status == RailAgentStatus.READY_TO_DEPART:
-            agent_virtual_position = agent.initial_position
-        elif agent.status == RailAgentStatus.ACTIVE:
-            agent_virtual_position = agent.position
-        elif agent.status == RailAgentStatus.DONE:
-            agent_virtual_position = agent.target
-        else:
-            return None
-
-        obs_targets = np.zeros((self.env.height, self.env.width, 2))
-        obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1
-
-        # TODO can we do this more elegantly?
-        # for r in range(self.env.height):
-        #     for c in range(self.env.width):
-        #         obs_agents_state[(r, c)][4] = 0
-        obs_agents_state[:, :, 4] = 0
-
-        obs_agents_state[agent_virtual_position][0] = agent.direction
-        obs_targets[agent.target][0] = 1
-
-        for i in range(len(self.env.agents)):
-            other_agent: EnvAgent = self.env.agents[i]
-
-            # ignore other agents not in the grid any more
-            if other_agent.status == RailAgentStatus.DONE_REMOVED:
-                continue
-
-            obs_targets[other_agent.target][1] = 1
-
-            # second to fourth channel only if in the grid
-            if other_agent.position is not None:
-                # second channel only for other agents
-                if i != handle:
-                    obs_agents_state[other_agent.position][1] = other_agent.direction
-                obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
-                obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
-            # fifth channel: all ready to depart on this position
-            if other_agent.status == RailAgentStatus.READY_TO_DEPART:
-                obs_agents_state[other_agent.initial_position][4] += 1
-        return self.rail_obs, obs_agents_state, obs_targets
-
-
-class LocalObsForRailEnv(ObservationBuilder):
-    """
-    !!!!!!WARNING!!! THIS IS DEPRACTED AND NOT UPDATED TO FLATLAND 2.0!!!!!
-    Gives a local observation of the rail environment around the agent.
-    The observation is composed of the following elements:
-
-        - transition map array of the local environment around the given agent, \
-          with dimensions (view_height,2*view_width+1, 16), \
-          assuming 16 bits encoding of transitions.
-
-        - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, \
-        if they are in the agent's vision range, its target position, the positions of the other targets.
-
-        - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions \
-          of the other agents at their position coordinates, if they are in the agent's vision range.
-
-        - A 4 elements array with one hot encoding of the direction.
-
-    Use the parameters view_width and view_height to define the rectangular view of the agent.
-    The center parameters moves the agent along the height axis of this rectangle. If it is 0 the agent only has
-    observation in front of it.
-
-    .. deprecated:: 2.0.0
-    """
-
-    def __init__(self, view_width, view_height, center):
-
-        super(LocalObsForRailEnv, self).__init__()
-        self.view_width = view_width
-        self.view_height = view_height
-        self.center = center
-        self.max_padding = max(self.view_width, self.view_height - self.center)
-
-    def reset(self):
-        # We build the transition map with a view_radius empty cells expansion on each side.
-        # This helps to collect the local transition map view when the agent is close to a border.
-        self.max_padding = max(self.view_width, self.view_height)
-        self.rail_obs = np.zeros((self.env.height,
-                                  self.env.width, 16))
-        for i in range(self.env.height):
-            for j in range(self.env.width):
-                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
-                bitlist = [0] * (16 - len(bitlist)) + bitlist
-                self.rail_obs[i, j] = np.array(bitlist)
-
-    def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
-        agents = self.env.agents
-        agent = agents[handle]
-
-        # Correct agents position for padding
-        # agent_rel_pos[0] = agent.position[0] + self.max_padding
-        # agent_rel_pos[1] = agent.position[1] + self.max_padding
-
-        # Collect visible cells as set to be plotted
-        visited, rel_coords = self.field_of_view(agent.position, agent.direction, )
-        local_rail_obs = None
-
-        # Add the visible cells to the observed cells
-        self.env.dev_obs_dict[handle] = set(visited)
-
-        # Locate observed agents and their coresponding targets
-        local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16))
-        obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
-        obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
-        _idx = 0
-        for pos in visited:
-            curr_rel_coord = rel_coords[_idx]
-            local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :]
-            if pos == agent.target:
-                obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1
-            else:
-                for tmp_agent in agents:
-                    if pos == tmp_agent.target:
-                        obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1
-            if pos != agent.position:
-                for tmp_agent in agents:
-                    if pos == tmp_agent.position:
-                        obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[
-                            tmp_agent.direction]
-
-            _idx += 1
-
-        direction = np.identity(4)[agent.direction]
-        return local_rail_obs, obs_map_state, obs_other_agents_state, direction
-
-    def get_many(self, handles: Optional[List[int]] = None) -> Dict[
-        int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
-        """
-        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
-        in the `handles` list.
-        """
-
-        return super().get_many(handles)
-
-    def field_of_view(self, position, direction, state=None):
-        # Compute the local field of view for an agent in the environment
-        data_collection = False
-        if state is not None:
-            temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16))
-            data_collection = True
-        if direction == 0:
-            origin = (position[0] + self.center, position[1] - self.view_width)
-        elif direction == 1:
-            origin = (position[0] - self.view_width, position[1] - self.center)
-        elif direction == 2:
-            origin = (position[0] - self.center, position[1] + self.view_width)
-        else:
-            origin = (position[0] + self.view_width, position[1] + self.center)
-        visible = list()
-        rel_coords = list()
-        for h in range(self.view_height):
-            for w in range(2 * self.view_width + 1):
-                if direction == 0:
-                    if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
-                        visible.append((origin[0] - h, origin[1] + w))
-                        rel_coords.append((h, w))
-                    # if data_collection:
-                    #    temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
-                elif direction == 1:
-                    if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
-                        visible.append((origin[0] + w, origin[1] + h))
-                        rel_coords.append((h, w))
-                    # if data_collection:
-                    #    temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
-                elif direction == 2:
-                    if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width:
-                        visible.append((origin[0] + h, origin[1] - w))
-                        rel_coords.append((h, w))
-                    # if data_collection:
-                    #    temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
-                else:
-                    if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width:
-                        visible.append((origin[0] - w, origin[1] - h))
-                        rel_coords.append((h, w))
-                    # if data_collection:
-                    #    temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
-        if data_collection:
-            return temp_visible_data
-        else:
-            return visible, rel_coords
-
-
-def _split_node_into_feature_groups(node: MyTreeObsForRailEnv.Node, dist_min_to_target: int) -> (np.ndarray, np.ndarray,
-                                                                                                 np.ndarray):
-    data = np.zeros(2)
-
-    data[0] = 2.0 * int(node.num_agents_opposite_direction > 0) - 1.0
-    # data[1] = 2.0 * int(node.num_agents_same_direction > 0) - 1.0
-    data[1] = 2.0 * int(node.target_encountered > 0) - 1.0
-
-    return data
-
-
-def _split_subtree_into_feature_groups(node: MyTreeObsForRailEnv.Node, dist_min_to_target: int,
-                                       current_tree_depth: int,
-                                       max_tree_depth: int) -> (
-        np.ndarray, np.ndarray, np.ndarray):
-    if node == -np.inf:
-        remaining_depth = max_tree_depth - current_tree_depth
-        # reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
-        num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1))
-        return [0] * num_remaining_nodes * 2
-
-    data = _split_node_into_feature_groups(node, dist_min_to_target)
-
-    if not node.childs:
-        return data
-
-    for direction in MyTreeObsForRailEnv.tree_explored_actions_char:
-        sub_data = _split_subtree_into_feature_groups(node.childs[direction],
-                                                      node.dist_min_to_target,
-                                                      current_tree_depth + 1,
-                                                      max_tree_depth)
-        data = np.concatenate((data, sub_data))
-    return data
-
-
-def split_tree_into_feature_groups(tree: MyTreeObsForRailEnv.Node, max_tree_depth: int) -> (
-        np.ndarray, np.ndarray, np.ndarray):
-    """
-    This function splits the tree into three difference arrays of values
-    """
-    data = _split_node_into_feature_groups(tree, 1000000.0)
-
-    for direction in MyTreeObsForRailEnv.tree_explored_actions_char:
-        sub_data = _split_subtree_into_feature_groups(tree.childs[direction],
-                                                      1000000.0,
-                                                      1,
-                                                      max_tree_depth)
-        data = np.concatenate((data, sub_data))
-
-    return data
-
-
-def normalize_observation(observation: MyTreeObsForRailEnv.Node, tree_depth: int):
-    """
-    This function normalizes the observation used by the RL algorithm
-    """
-    data = split_tree_into_feature_groups(observation, tree_depth)
-    normalized_obs = data
-
-    # navigate_info
-    navigate_info = np.zeros(4)
-    action_info = np.zeros(4)
-    np.seterr(all='raise')
-    try:
-        dm = observation.dist_min_to_target
-        if observation.childs['L'] != -np.inf:
-            navigate_info[0] = dm - observation.childs['L'].dist_min_to_target
-            action_info[0] = 1
-        if observation.childs['F'] != -np.inf:
-            navigate_info[1] = dm - observation.childs['F'].dist_min_to_target
-            action_info[1] = 1
-        if observation.childs['R'] != -np.inf:
-            navigate_info[2] = dm - observation.childs['R'].dist_min_to_target
-            action_info[2] = 1
-        if observation.childs['B'] != -np.inf:
-            navigate_info[3] = dm - observation.childs['B'].dist_min_to_target
-            action_info[3] = 1
-    except:
-        navigate_info = np.ones(4)
-        normalized_obs = np.zeros(len(normalized_obs))
-
-    # navigate_info_2 = np.copy(navigate_info)
-    # max_v = np.max(navigate_info_2)
-    # navigate_info_2 = navigate_info_2 / max_v
-    # navigate_info_2[navigate_info_2 < 1] = -1
-
-    max_v = np.max(navigate_info)
-    navigate_info = navigate_info / max_v
-    navigate_info[navigate_info < 0] = -1
-    # navigate_info[abs(navigate_info) < 1] = 0
-    # normalized_obs = navigate_info
-
-    # navigate_info = np.concatenate((navigate_info, action_info))
-    normalized_obs = np.concatenate((navigate_info, normalized_obs))
-    # normalized_obs = np.concatenate((navigate_info, navigate_info_2))
-    # print(normalized_obs)
-    return normalized_obs
+"""
+Collection of environment-specific ObservationBuilder.
+"""
+import collections
+from typing import Optional, List, Dict, Tuple
+
+import numpy as np
+from flatland.core.env import Environment
+from flatland.core.env_observation_builder import ObservationBuilder
+from flatland.core.env_prediction_builder import PredictionBuilder
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.core.grid.grid_utils import coordinate_to_position
+from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
+from flatland.utils.ordered_set import OrderedSet
+
+
+class MyTreeObsForRailEnv(ObservationBuilder):
+    """
+    TreeObsForRailEnv object.
+
+    This object returns observation vectors for agents in the RailEnv environment.
+    The information is local to each agent and exploits the graph structure of the rail
+    network to simplify the representation of the state of the environment for each agent.
+
+    For details about the features in the tree observation see the get() function.
+    """
+    Node = collections.namedtuple('Node', 'dist_min_to_target '
+                                          'target_encountered '
+                                          'num_agents_same_direction '
+                                          'num_agents_opposite_direction '
+                                          'childs')
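+    # Slimmed-down tree node: per explored branch only the minimum remaining distance to the
+    # target, a target-encountered flag, the same/opposite-direction agent counts and the
+    # child branches are stored.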
+
+    tree_explored_actions_char = ['L', 'F', 'R', 'B']
+
+    def __init__(self, max_depth: int, predictor: PredictionBuilder = None):
+        super().__init__()
+        self.max_depth = max_depth
+        self.observation_dim = 2
+        self.location_has_agent = {}
+        self.predictor = predictor
+        self.location_has_target = None
+
+        self.switches_list = {}
+        self.switches_neighbours_list = []
+        self.check_agent_descision = None
+
+    def reset(self):
+        self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents}
+
+    def set_switch_and_pre_switch(self, switch_list, pre_switch_list, check_agent_descision):
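+        # Inject the pre-computed switch / pre-switch lookup tables and the decision-check
+        # callback provided by the caller (used in _explore_branch).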
+        self.switches_list = switch_list
+        self.switches_neighbours_list = pre_switch_list
+        self.check_agent_descision = check_agent_descision
+
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]:
+        """
+        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
+        in the `handles` list.
+        """
+
+        if handles is None:
+            handles = []
+        if self.predictor:
+            self.max_prediction_depth = 0
+            self.predicted_pos = {}
+            self.predicted_dir = {}
+            self.predictions = self.predictor.get()
+            if self.predictions:
+                for t in range(self.predictor.max_depth + 1):
+                    pos_list = []
+                    dir_list = []
+                    for a in handles:
+                        if self.predictions[a] is None:
+                            continue
+                        pos_list.append(self.predictions[a][t][1:3])
+                        dir_list.append(self.predictions[a][t][3])
+                    self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
+                    self.predicted_dir.update({t: dir_list})
+                self.max_prediction_depth = len(self.predicted_pos)
+        # Update local lookup table for all agents' positions
+        # ignore other agents not in the grid (only status active and done)
+
+        self.location_has_agent = {}
+        self.location_has_agent_direction = {}
+        self.location_has_agent_speed = {}
+        self.location_has_agent_malfunction = {}
+        self.location_has_agent_ready_to_depart = {}
+
+        for _agent in self.env.agents:
+            if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \
+                    _agent.position:
+                self.location_has_agent[tuple(_agent.position)] = 1
+                self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction
+                self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed']
+                self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[
+                    'malfunction']
+
+            if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \
+                    _agent.initial_position:
+                self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \
+                    self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1
+
+        observations = super().get_many(handles)
+
+        return observations
+
+    def get(self, handle: int = 0) -> Node:
+        """
+        Computes the current observation for agent `handle` in env
+
+        The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible
+        movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv).
+        The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for
+        the transitions. The order is::
+
+            [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back']
+
+        Each branch data is organized as::
+
+            [root node information] +
+            [recursive branch data from 'left'] +
+            [... from 'forward'] +
+            [... from 'right'] +
+            [... from 'back']
+
+        Each node information is composed of 12 features:
+
+        #1:
+            if own target lies on the explored branch the current distance from the agent in number of cells is stored.
+
+        #2:
+            if another agent's target is detected, the distance in number of cells from the agent's current location\
+            is stored
+
+        #3:
+            if another agent is detected the distance in number of cells from current agent position is stored.
+
+        #4:
+            possible conflict detected
+            tot_dist = another agent is predicted to pass along this cell at the same time as the agent; we store the \
+             distance in number of cells from the current agent position
+
+            0 = no other agent reserves the same cell at a similar time
+
+        #5:
+            if a switch that cannot be used by the agent is detected, we store the distance.
+
+        #6:
+            This feature stores the distance in number of cells to the next branching point (current node)
+
+        #7:
+            minimum distance from node to the agent's target given the direction of the agent if this path is chosen
+
+        #8:
+            agent in the same direction
+            n = number of agents present same direction \
+                (possible future use: number of other agents in the same direction in this branch)
+            0 = no agent present same direction
+
+        #9:
+            agent in the opposite direction
+            n = number of agents present in the other direction than myself (so conflict) \
+                (possible future use: number of other agents in the other direction in this branch, i.e. number of conflicts)
+            0 = no agent present in the other direction than myself
+
+        #10:
+            malfunctioning/blocking agents
+            n = number of time steps the observed agent remains blocked
+
+        #11:
+            slowest observed speed of an agent in same direction
+            1 if no agent is observed
+
+            min_fractional speed otherwise
+
+        #12:
+            number of agents ready to depart but not yet active
+
+        Missing/padding nodes are filled in with -inf (truncated).
+        Missing values in present node are filled in with +inf (truncated).
+
+
+        In case of the root node, the values are [0, 0, 0, 0, distance from agent to target, own malfunction, own speed]
+        In case the target node is reached, the values are [0, 0, 0, 0, 0].
+        """
+
+        if handle > len(self.env.agents):
+            print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents))
+        agent = self.env.agents[handle]  # TODO: handle being treated as index
+
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            agent_virtual_position = agent.target
+        else:
+            return None
+
+        possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction)
+        num_transitions = np.count_nonzero(possible_transitions)
+
+        # Here information about the agent itself is stored
+        distance_map = self.env.distance_map.get()
+
+        root_node_observation = MyTreeObsForRailEnv.Node(dist_min_to_target=distance_map[
+            (handle, *agent_virtual_position,
+             agent.direction)],
+                                                         target_encountered=0,
+                                                         num_agents_same_direction=0,
+                                                         num_agents_opposite_direction=0,
+                                                         childs={})
+
+        visited = OrderedSet()
+
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right, back], relative to the current orientation
+        # If only one transition is possible, the tree is oriented with this transition as the forward branch.
+        orientation = agent.direction
+
+        if num_transitions == 1:
+            orientation = np.argmax(possible_transitions)
+
+        for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]):
+            if possible_transitions[branch_direction]:
+                new_cell = get_new_position(agent_virtual_position, branch_direction)
+
+                branch_observation, branch_visited = \
+                    self._explore_branch(handle, new_cell, branch_direction, 1, 1)
+                root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation
+
+                visited |= branch_visited
+            else:
+                # add cells filled with infinity if no transition is possible
+                root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf
+        self.env.dev_obs_dict[handle] = visited
+
+        return root_node_observation
+
+    def _explore_branch(self, handle, position, direction, tot_dist, depth):
+        """
+        Utility function to compute tree-based observations.
+        We walk along the branch and collect the information documented in the get() function.
+        If there is a branching point a new node is created and each possible branch is explored.
+        """
+
+        # [Recursive branch opened]
+        if depth >= self.max_depth + 1:
+            return [], []
+
+        # Continue along direction until next switch or
+        # until no transitions are possible along the current direction (i.e., dead-ends)
+        # We treat dead-ends as nodes, instead of going back, to avoid loops
+        exploring = True
+
+        visited = OrderedSet()
+        agent = self.env.agents[handle]
+
+        other_agent_opposite_direction = 0
+        other_agent_same_direction = 0
+
+        dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction]
+
+        last_is_dead_end = False
+        last_is_a_decision_cell = False
+        target_encountered = 0
+
+        cnt = 0
+        while exploring:
+
+            dist_min_to_target = min(dist_min_to_target, self.env.distance_map.get()[handle, position[0], position[1],
+                                                                                     direction])
+
+            if agent.target == position:
+                target_encountered = 1
+
+            new_direction_me = direction
+            new_cell_me = position
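+            # env.agent_positions maps each grid cell to the handle of the agent
+            # occupying it, or -1 if the cell is free, so this detects another
+            # agent standing on the current cell.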
+            a = self.env.agent_positions[new_cell_me]
+            if a != -1 and a != handle:
+                opp_agent = self.env.agents[a]
+                # look one step forward
+                # opp_possible_transitions = self.env.rail.get_transitions(*opp_agent.position, opp_agent.direction)
+                if opp_agent.direction != new_direction_me:  # opp_possible_transitions[new_direction_me] == 0:
+                    other_agent_opposite_direction += 1
+                else:
+                    other_agent_same_direction += 1
+
+            # #############################
+            # #############################
+            if (position[0], position[1], direction) in visited:
+                break
+            visited.add((position[0], position[1], direction))
+
+            # If the target node is encountered, pick that as node. Also, no further branching is possible.
+            if np.array_equal(position, self.env.agents[handle].target):
+                last_is_target = True
+                break
+
+            exploring = False
+
+            # Check number of possible transitions for agent and total number of transitions in cell (type)
+            possible_transitions = self.env.rail.get_transitions(*position, direction)
+            num_transitions = np.count_nonzero(possible_transitions)
+            # cell_transitions = self.env.rail.get_transitions(*position, direction)
+            transition_bit = bin(self.env.rail.get_full_transitions(*position))
+            total_transitions = transition_bit.count("1")
+
+            if num_transitions == 1:
+                # Check if dead-end, or if we can go forward along direction
+                nbits = total_transitions
+                if nbits == 1:
+                    # Dead-end!
+                    last_is_dead_end = True
+
+            if self.check_agent_descision is not None:
+                ret_agents_on_switch, ret_agents_near_to_switch, agents_near_to_switch_all = \
+                    self.check_agent_descision(position,
+                                               direction,
+                                               self.switches_list,
+                                               self.switches_neighbours_list)
+                if ret_agents_on_switch:
+                    last_is_a_decision_cell = True
+                    break
+
+            exploring = True
+            # convert one-hot encoding to 0,1,2,3
+            cell_transitions = self.env.rail.get_transitions(*position, direction)
+            direction = np.argmax(cell_transitions)
+            position = get_new_position(position, direction)
+
+            cnt += 1
+            if cnt > 1000:
+                exploring = False
+
+        # #############################
+        # #############################
+        # Modify here to append new / different features for each visited cell!
+
+        node = MyTreeObsForRailEnv.Node(dist_min_to_target=dist_min_to_target,
+                                        target_encountered=target_encountered,
+                                        num_agents_opposite_direction=other_agent_opposite_direction,
+                                        num_agents_same_direction=other_agent_same_direction,
+                                        childs={})
+
+        # #############################
+        # #############################
+        # Start from the current orientation, and see which transitions are available;
+        # organize them as [left, forward, right, back], relative to the current orientation
+        # Get the possible transitions
+        possible_transitions = self.env.rail.get_transitions(*position, direction)
+
+        for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]):
+            if last_is_dead_end and self.env.rail.get_transition((*position, direction),
+                                                                 (branch_direction + 2) % 4):
+                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
+                # it back
+                new_cell = get_new_position(position, (branch_direction + 2) % 4)
+                branch_observation, branch_visited = self._explore_branch(handle,
+                                                                          new_cell,
+                                                                          (branch_direction + 2) % 4,
+                                                                          tot_dist + 1,
+                                                                          depth + 1)
+                node.childs[self.tree_explored_actions_char[i]] = branch_observation
+                if len(branch_visited) != 0:
+                    visited |= branch_visited
+            elif last_is_a_decision_cell and possible_transitions[branch_direction]:
+                new_cell = get_new_position(position, branch_direction)
+                branch_observation, branch_visited = self._explore_branch(handle,
+                                                                          new_cell,
+                                                                          branch_direction,
+                                                                          tot_dist + 1,
+                                                                          depth + 1)
+                node.childs[self.tree_explored_actions_char[i]] = branch_observation
+                if len(branch_visited) != 0:
+                    visited |= branch_visited
+            else:
+                # no exploration possible, mark this branch as negative infinity
+                node.childs[self.tree_explored_actions_char[i]] = -np.inf
+
+        if depth == self.max_depth:
+            node.childs.clear()
+        return node, visited
+
+    def util_print_obs_subtree(self, tree: Node):
+        """
+        Utility function to print tree observations returned by this object.
+        """
+        self.print_node_features(tree, "root", "")
+        for direction in self.tree_explored_actions_char:
+            self.print_subtree(tree.childs[direction], direction, "\t")
+
+    @staticmethod
+    def print_node_features(node: Node, label, indent):
+        print(indent, "Direction ", label, ": ", node.num_agents_same_direction,
+              ", ", node.num_agents_opposite_direction)
+
+    def print_subtree(self, node, label, indent):
+        if node == -np.inf or not node:
+            print(indent, "Direction ", label, ": -np.inf")
+            return
+
+        self.print_node_features(node, label, indent)
+
+        if not node.childs:
+            return
+
+        for direction in self.tree_explored_actions_char:
+            self.print_subtree(node.childs[direction], direction, indent + "\t")
+
+    def set_env(self, env: Environment):
+        super().set_env(env)
+        if self.predictor:
+            self.predictor.set_env(self.env)
+
+    def _reverse_dir(self, direction):
+        return int((direction + 2) % 4)
+
+
+class GlobalObsForRailEnv(ObservationBuilder):
+    """
+    Gives a global observation of the entire rail environment.
+    The observation is composed of the following elements:
+
+        - transition map array with dimensions (env.height, env.width, 16),\
+          assuming 16 bits encoding of transitions.
+
+        - obs_agents_state: A 3D array (map_height, map_width, 5) with
+            - first channel containing the agent's position and direction
+            - second channel containing the other agents' positions and directions
+            - third channel containing agent/other agent malfunctions
+            - fourth channel containing agent/other agent fractional speeds
+            - fifth channel containing the number of agents ready to depart from that cell
+
+        - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given\
+          agent's target and the positions of the other agents' targets (flag only, no counter!).
+    """
+
+    def __init__(self):
+        super(GlobalObsForRailEnv, self).__init__()
+
+    def set_env(self, env: Environment):
+        super().set_env(env)
+
+    def reset(self):
+        self.rail_obs = np.zeros((self.env.height, self.env.width, 16))
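+        # Unpack every cell's 16-bit transition code into a binary feature vector,
+        # left-padded with zeros to a fixed length of 16.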
+        for i in range(self.rail_obs.shape[0]):
+            for j in range(self.rail_obs.shape[1]):
+                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
+                bitlist = [0] * (16 - len(bitlist)) + bitlist
+                self.rail_obs[i, j] = np.array(bitlist)
+
+    def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray):
+
+        agent = self.env.agents[handle]
+        if agent.status == RailAgentStatus.READY_TO_DEPART:
+            agent_virtual_position = agent.initial_position
+        elif agent.status == RailAgentStatus.ACTIVE:
+            agent_virtual_position = agent.position
+        elif agent.status == RailAgentStatus.DONE:
+            agent_virtual_position = agent.target
+        else:
+            return None
+
+        obs_targets = np.zeros((self.env.height, self.env.width, 2))
+        obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1
+
+        # TODO can we do this more elegantly?
+        # for r in range(self.env.height):
+        #     for c in range(self.env.width):
+        #         obs_agents_state[(r, c)][4] = 0
+        obs_agents_state[:, :, 4] = 0
+
+        obs_agents_state[agent_virtual_position][0] = agent.direction
+        obs_targets[agent.target][0] = 1
+
+        for i in range(len(self.env.agents)):
+            other_agent: EnvAgent = self.env.agents[i]
+
+            # ignore other agents not in the grid any more
+            if other_agent.status == RailAgentStatus.DONE_REMOVED:
+                continue
+
+            obs_targets[other_agent.target][1] = 1
+
+            # second to fourth channel only if in the grid
+            if other_agent.position is not None:
+                # second channel only for other agents
+                if i != handle:
+                    obs_agents_state[other_agent.position][1] = other_agent.direction
+                obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction']
+                obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed']
+            # fifth channel: all ready to depart on this position
+            if other_agent.status == RailAgentStatus.READY_TO_DEPART:
+                obs_agents_state[other_agent.initial_position][4] += 1
+        return self.rail_obs, obs_agents_state, obs_targets
+
+
+class LocalObsForRailEnv(ObservationBuilder):
+    """
+    !!!!!!WARNING!!! THIS IS DEPRECATED AND NOT UPDATED TO FLATLAND 2.0!!!!!
+    Gives a local observation of the rail environment around the agent.
+    The observation is composed of the following elements:
+
+        - transition map array of the local environment around the given agent, \
+          with dimensions (view_height,2*view_width+1, 16), \
+          assuming 16 bits encoding of transitions.
+
+        - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively the agent's target position \
+          and the positions of the other agents' targets, if they are within the agent's vision range.
+
+        - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions \
+          of the other agents at their position coordinates, if they are in the agent's vision range.
+
+        - A 4 elements array with one hot encoding of the direction.
+
+    Use the parameters view_width and view_height to define the rectangular view of the agent.
+    The center parameter moves the agent along the height axis of this rectangle. If it is 0, the agent only
+    observes cells in front of it.
+
+    .. deprecated:: 2.0.0
+    """
+
+    def __init__(self, view_width, view_height, center):
+
+        super(LocalObsForRailEnv, self).__init__()
+        self.view_width = view_width
+        self.view_height = view_height
+        self.center = center
+        self.max_padding = max(self.view_width, self.view_height - self.center)
+
+    def reset(self):
+        # We build the transition map with a view_radius empty cells expansion on each side.
+        # This helps to collect the local transition map view when the agent is close to a border.
+        self.max_padding = max(self.view_width, self.view_height)
+        self.rail_obs = np.zeros((self.env.height,
+                                  self.env.width, 16))
+        for i in range(self.env.height):
+            for j in range(self.env.width):
+                bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]]
+                bitlist = [0] * (16 - len(bitlist)) + bitlist
+                self.rail_obs[i, j] = np.array(bitlist)
+
+    def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
+        agents = self.env.agents
+        agent = agents[handle]
+
+        # Correct agents position for padding
+        # agent_rel_pos[0] = agent.position[0] + self.max_padding
+        # agent_rel_pos[1] = agent.position[1] + self.max_padding
+
+        # Collect visible cells as set to be plotted
+        visited, rel_coords = self.field_of_view(agent.position, agent.direction)
+        local_rail_obs = None
+
+        # Add the visible cells to the observed cells
+        self.env.dev_obs_dict[handle] = set(visited)
+
+        # Locate observed agents and their corresponding targets
+        local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16))
+        obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2))
+        obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4))
+        _idx = 0
+        for pos in visited:
+            curr_rel_coord = rel_coords[_idx]
+            local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :]
+            if pos == agent.target:
+                obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1
+            else:
+                for tmp_agent in agents:
+                    if pos == tmp_agent.target:
+                        obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1
+            if pos != agent.position:
+                for tmp_agent in agents:
+                    if pos == tmp_agent.position:
+                        obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[
+                            tmp_agent.direction]
+
+            _idx += 1
+
+        direction = np.identity(4)[agent.direction]
+        return local_rail_obs, obs_map_state, obs_other_agents_state, direction
+
+    def get_many(self, handles: Optional[List[int]] = None) -> Dict[
+        int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]:
+        """
+        Called whenever an observation has to be computed for the `env` environment, for each agent with handle
+        in the `handles` list.
+        """
+
+        return super().get_many(handles)
+
+    def field_of_view(self, position, direction, state=None):
+        # Compute the local field of view for an agent in the environment
+        data_collection = False
+        if state is not None:
+            temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16))
+            data_collection = True
+        if direction == 0:
+            origin = (position[0] + self.center, position[1] - self.view_width)
+        elif direction == 1:
+            origin = (position[0] - self.view_width, position[1] - self.center)
+        elif direction == 2:
+            origin = (position[0] - self.center, position[1] + self.view_width)
+        else:
+            origin = (position[0] + self.view_width, position[1] + self.center)
+        visible = list()
+        rel_coords = list()
+        for h in range(self.view_height):
+            for w in range(2 * self.view_width + 1):
+                if direction == 0:
+                    if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width:
+                        visible.append((origin[0] - h, origin[1] + w))
+                        rel_coords.append((h, w))
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :]
+                elif direction == 1:
+                    if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width:
+                        visible.append((origin[0] + w, origin[1] + h))
+                        rel_coords.append((h, w))
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :]
+                elif direction == 2:
+                    if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width:
+                        visible.append((origin[0] + h, origin[1] - w))
+                        rel_coords.append((h, w))
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :]
+                else:
+                    if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width:
+                        visible.append((origin[0] - w, origin[1] - h))
+                        rel_coords.append((h, w))
+                    # if data_collection:
+                    #    temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :]
+        if data_collection:
+            return temp_visible_data
+        else:
+            return visible, rel_coords
+
+
+def _split_node_into_feature_groups(node: MyTreeObsForRailEnv.Node, dist_min_to_target: int) -> np.ndarray:
+    data = np.zeros(2)
+
+    data[0] = 2.0 * int(node.num_agents_opposite_direction > 0) - 1.0
+    # data[1] = 2.0 * int(node.num_agents_same_direction > 0) - 1.0
+    data[1] = 2.0 * int(node.target_encountered > 0) - 1.0
+
+    return data
+
+
+def _split_subtree_into_feature_groups(node: MyTreeObsForRailEnv.Node, dist_min_to_target: int,
+                                       current_tree_depth: int,
+                                       max_tree_depth: int) -> np.ndarray:
+    if node == -np.inf:
+        remaining_depth = max_tree_depth - current_tree_depth
+        # reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure
+        num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1))
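+        # e.g. remaining_depth = 2 -> (4 ** 3 - 1) / 3 = 21 padded nodes, each of
+        # which contributes the 2 features produced by _split_node_into_feature_groups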
+        return [0] * num_remaining_nodes * 2
+
+    data = _split_node_into_feature_groups(node, dist_min_to_target)
+
+    if not node.childs:
+        return data
+
+    for direction in MyTreeObsForRailEnv.tree_explored_actions_char:
+        sub_data = _split_subtree_into_feature_groups(node.childs[direction],
+                                                      node.dist_min_to_target,
+                                                      current_tree_depth + 1,
+                                                      max_tree_depth)
+        data = np.concatenate((data, sub_data))
+    return data
+
+
+def split_tree_into_feature_groups(tree: MyTreeObsForRailEnv.Node, max_tree_depth: int) -> np.ndarray:
+    """
+    This function flattens the tree observation into a single flat feature array.
+    """
+    data = _split_node_into_feature_groups(tree, 1000000.0)
+
+    for direction in MyTreeObsForRailEnv.tree_explored_actions_char:
+        sub_data = _split_subtree_into_feature_groups(tree.childs[direction],
+                                                      1000000.0,
+                                                      1,
+                                                      max_tree_depth)
+        data = np.concatenate((data, sub_data))
+
+    return data
+
+
+def normalize_observation(observation: MyTreeObsForRailEnv.Node, tree_depth: int):
+    """
+    This function normalizes the observation used by the RL algorithm
+    """
+    data = split_tree_into_feature_groups(observation, tree_depth)
+    normalized_obs = data
+
+    # navigate_info
+    navigate_info = np.zeros(4)
+    action_info = np.zeros(4)
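+    # navigate_info[i] holds the decrease in shortest-path distance to the target
+    # when taking branch i (L, F, R, B); action_info[i] flags whether that branch exists.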
+    np.seterr(all='raise')
+    try:
+        dm = observation.dist_min_to_target
+        if observation.childs['L'] != -np.inf:
+            navigate_info[0] = dm - observation.childs['L'].dist_min_to_target
+            action_info[0] = 1
+        if observation.childs['F'] != -np.inf:
+            navigate_info[1] = dm - observation.childs['F'].dist_min_to_target
+            action_info[1] = 1
+        if observation.childs['R'] != -np.inf:
+            navigate_info[2] = dm - observation.childs['R'].dist_min_to_target
+            action_info[2] = 1
+        if observation.childs['B'] != -np.inf:
+            navigate_info[3] = dm - observation.childs['B'].dist_min_to_target
+            action_info[3] = 1
+    except Exception:
+        navigate_info = np.ones(4)
+        normalized_obs = np.zeros(len(normalized_obs))
+
+    # navigate_info_2 = np.copy(navigate_info)
+    # max_v = np.max(navigate_info_2)
+    # navigate_info_2 = navigate_info_2 / max_v
+    # navigate_info_2[navigate_info_2 < 1] = -1
+
+    max_v = np.max(navigate_info)
+    navigate_info = navigate_info / max_v
+    navigate_info[navigate_info < 0] = -1
+    # navigate_info[abs(navigate_info) < 1] = 0
+    # normalized_obs = navigate_info
+
+    # navigate_info = np.concatenate((navigate_info, action_info))
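+    # Final layout: 4 normalized navigation features followed by 2 features per tree
+    # node (+/-1 flags for opposite-direction agents and target encountered; missing
+    # branches are zero-padded).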
+    normalized_obs = np.concatenate((navigate_info, normalized_obs))
+    # normalized_obs = np.concatenate((navigate_info, navigate_info_2))
+    # print(normalized_obs)
+    return normalized_obs
diff --git a/src/ppo/agent.py b/src/ppo/agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..86b210ad388f9d9638c5e1e037cdca27ffaf9e73
--- /dev/null
+++ b/src/ppo/agent.py
@@ -0,0 +1,106 @@
+import pickle
+
+import torch
+# from model import PolicyNetwork
+# from replay_memory import Episode, ReplayBuffer
+from torch.distributions.categorical import Categorical
+
+from src.ppo.model import PolicyNetwork
+from src.ppo.replay_memory import Episode, ReplayBuffer
+
+BUFFER_SIZE = 32_000
+BATCH_SIZE = 4096
+GAMMA = 0.98
+LR = 0.5e-4
+CLIP_FACTOR = .005
+UPDATE_EVERY = 30
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print("device:", device)
+
+
+class Agent:
+    def __init__(self, state_size, action_size, num_agents):
+        self.policy = PolicyNetwork(state_size, action_size).to(device)
+        self.old_policy = PolicyNetwork(state_size, action_size).to(device)
+        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR)
+
+        self.episodes = [Episode() for _ in range(num_agents)]
+        self.memory = ReplayBuffer(BUFFER_SIZE)
+        self.t_step = 0
+
+    def reset(self):
+        self.finished = [False] * len(self.episodes)
+
+    # Decide on an action to take in the environment
+
+    def act(self, state, eps=None):
+        self.policy.eval()
+        with torch.no_grad():
+            output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device))
+            return Categorical(output).sample().item()
+
+    # Record the results of the agent's action and update the model
+
+    def step(self, handle, state, action, next_state, agent_done, episode_done, collision):
+        if not self.finished[handle]:
+            if agent_done:
+                reward = 1
+            elif collision:
+                reward = -.5
+            else:
+                reward = 0
+
+            # Push experience into Episode memory
+            self.episodes[handle].push(state, action, reward, next_state, agent_done or episode_done)
+
+            # When we finish the episode, discount rewards and push the experience into replay memory
+            if agent_done or episode_done:
+                self.episodes[handle].discount_rewards(GAMMA)
+                self.memory.push_episode(self.episodes[handle])
+                self.episodes[handle].reset()
+                self.finished[handle] = True
+
+        # Perform a gradient update every UPDATE_EVERY time steps
+        self.t_step = (self.t_step + 1) % UPDATE_EVERY
+        if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4:
+            self.learn(*self.memory.sample(BATCH_SIZE, device))
+
+    def learn(self, states, actions, rewards, next_state, done):
+        self.policy.train()
+
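+        # PPO-style clipped surrogate objective (simplified): the probability ratio
+        # between the current and the old policy is clamped to
+        # [1 - CLIP_FACTOR, 1 + CLIP_FACTOR], and the discounted returns stored in
+        # the replay memory play the role of the advantage term.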
+        responsible_outputs = torch.gather(self.policy(states), 1, actions)
+        old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach()
+
+        # rewards = rewards - rewards.mean()
+        ratio = responsible_outputs / (old_responsible_outputs + 1e-5)
+        clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR)
+        loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean()
+
+        # Compute loss and perform a gradient step
+        self.old_policy.load_state_dict(self.policy.state_dict())
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+
+    # Checkpointing methods
+
+    def save(self, path, *data):
+        torch.save(self.policy.state_dict(), path / 'ppo/model_checkpoint.policy')
+        torch.save(self.optimizer.state_dict(), path / 'ppo/model_checkpoint.optimizer')
+        with open(path / 'ppo/model_checkpoint.meta', 'wb') as file:
+            pickle.dump(data, file)
+
+    def load(self, path, *defaults):
+        try:
+            print("Loading model from checkpoint...")
+            print(path + 'ppo/model_checkpoint.policy')
+            self.policy.load_state_dict(
+                torch.load(path + 'ppo/model_checkpoint.policy', map_location=torch.device('cpu')))
+            self.optimizer.load_state_dict(
+                torch.load(path + 'ppo/model_checkpoint.optimizer', map_location=torch.device('cpu')))
+            with open(path + 'ppo/model_checkpoint.meta', 'rb') as file:
+                return pickle.load(file)
+        except Exception:
+            print("No checkpoint file was found")
+            return defaults
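+
+
+# ---------------------------------------------------------------------------
+# Minimal usage sketch (illustrative only, not part of the training pipeline):
+# it drives a single agent for a few steps with random observations in place
+# of a Flatland environment; the state/action sizes are arbitrary example values.
+# ---------------------------------------------------------------------------
+if __name__ == "__main__":
+    import numpy as np
+
+    demo_agent = Agent(state_size=16, action_size=5, num_agents=1)
+    demo_agent.reset()
+    state = np.random.rand(16).astype(np.float32)
+    for t in range(3):
+        action = demo_agent.act(state)
+        next_state = np.random.rand(16).astype(np.float32)
+        demo_agent.step(0, state, action, next_state,
+                        agent_done=(t == 2), episode_done=False, collision=False)
+        state = next_state
+    print("sampled action:", action)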
diff --git a/src/ppo/model.py b/src/ppo/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..51b86ff16691c03f6a754405352bb4cf48e4b914
--- /dev/null
+++ b/src/ppo/model.py
@@ -0,0 +1,20 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class PolicyNetwork(nn.Module):
+    def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32):
+        super().__init__()
+        self.fc1 = nn.Linear(state_size, hidsize1)
+        self.fc2 = nn.Linear(hidsize1, hidsize2)
+        # self.fc3 = nn.Linear(hidsize2, hidsize3)
+        self.output = nn.Linear(hidsize2, action_size)
+        self.softmax = nn.Softmax(dim=1)
+        self.bn0 = nn.BatchNorm1d(state_size, affine=False)
+
+    def forward(self, inputs):
+        x = self.bn0(inputs.float())
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        # x = F.relu(self.fc3(x))
+        return self.softmax(self.output(x))
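+
+
+# ---------------------------------------------------------------------------
+# Minimal smoke test (illustrative only): pushes a random batch through the
+# network and checks that each output row is a valid probability distribution.
+# The state/action sizes are arbitrary example values.
+# ---------------------------------------------------------------------------
+if __name__ == "__main__":
+    import torch
+
+    net = PolicyNetwork(state_size=16, action_size=5)
+    net.eval()  # BatchNorm1d uses its running statistics in eval mode
+    probs = net(torch.randn(4, 16))
+    assert probs.shape == (4, 5)
+    assert torch.allclose(probs.sum(dim=1), torch.ones(4))
+    print(probs)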
diff --git a/src/ppo/replay_memory.py b/src/ppo/replay_memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e6619b40169597d7a4b379f4ce2c9ddccd4cd9b
--- /dev/null
+++ b/src/ppo/replay_memory.py
@@ -0,0 +1,53 @@
+import torch
+import random
+import numpy as np
+from collections import namedtuple, deque
+from collections.abc import Iterable
+
+
+Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done"))
+
+
+class Episode:
+    def __init__(self):
+        # Use an instance attribute so that episodes do not share one class-level list
+        self.memory = []
+
+    def reset(self):
+        self.memory = []
+
+    def push(self, *args):
+        self.memory.append(tuple(args))
+
+    def discount_rewards(self, gamma):
+        running_add = 0.
+        for i, (state, action, reward, *rest) in list(enumerate(self.memory))[::-1]:
+            running_add = running_add * gamma + reward
+            self.memory[i] = (state, action, running_add, *rest)
+
+
+class ReplayBuffer:
+    def __init__(self, buffer_size):
+        self.memory = deque(maxlen=buffer_size)
+
+    def push(self, state, action, reward, next_state, done):
+        self.memory.append(Transition(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done))
+
+    def push_episode(self, episode):
+        for step in episode.memory:
+            self.push(*step)
+
+    def sample(self, batch_size, device):
+        experiences = random.sample(self.memory, k=batch_size)
+
+        states      = torch.from_numpy(self.stack([e.state      for e in experiences])).float().to(device)
+        actions     = torch.from_numpy(self.stack([e.action     for e in experiences])).long().to(device)
+        rewards     = torch.from_numpy(self.stack([e.reward     for e in experiences])).float().to(device)
+        next_states = torch.from_numpy(self.stack([e.next_state for e in experiences])).float().to(device)
+        dones       = torch.from_numpy(self.stack([e.done       for e in experiences]).astype(np.uint8)).float().to(device)
+
+        return states, actions, rewards, next_states, dones
+
+    def stack(self, states):
+        sub_dims = states[0].shape[1:] if isinstance(states[0], Iterable) else [1]
+        return np.reshape(np.array(states), (len(states), *sub_dims))
+
+    def __len__(self):
+        return len(self.memory)
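+
+
+# ---------------------------------------------------------------------------
+# Minimal usage sketch (illustrative only): a three-step episode with a single
+# terminal reward of 1.0 and gamma = 0.5 yields the discounted returns
+# [0.25, 0.5, 1.0], which are then pushed into the buffer and sampled back.
+# ---------------------------------------------------------------------------
+if __name__ == "__main__":
+    episode = Episode()
+    for t in range(3):
+        reward = 1.0 if t == 2 else 0.0
+        episode.push(np.zeros(4), 0, reward, np.zeros(4), t == 2)
+    episode.discount_rewards(0.5)
+    print([step[2] for step in episode.memory])  # [0.25, 0.5, 1.0]
+
+    buffer = ReplayBuffer(buffer_size=100)
+    buffer.push_episode(episode)
+    states, actions, rewards, next_states, dones = buffer.sample(3, torch.device("cpu"))
+    print(states.shape, rewards.squeeze())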