diff --git a/.gitignore b/.gitignore index da214e57a3cb40bddeb8e0e2d0b518b3de06f20e..54ef0cba10a309cb13c6077b720208c5e142257b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,127 +1,127 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -.python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# celery beat schedule file -celerybeat-schedule - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -scratch/test-envs/ -scratch/ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. 
+#Pipfile.lock + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +scratch/test-envs/ +scratch/ diff --git a/README.md b/README.md index c4e93dc4ae84e6debde8a2483c7b3f94edbc0d5e..22846ae8c532fa537e79ca807fe71860100f1f23 100644 --- a/README.md +++ b/README.md @@ -1,26 +1,26 @@ - - -# Flatland Challenge Starter Kit - -**[Follow these instructions to submit your solutions!](http://flatland.aicrowd.com/getting-started/first-submission.html)** - - - - -Main links ---- - -* [Flatland documentation](https://flatland.aicrowd.com/) -* [NeurIPS 2020 Challenge](https://www.aicrowd.com/challenges/neurips-2020-flatland-challenge/) - -Communication ---- - -* [Discord Channel](https://discord.com/invite/hCR3CZG) -* [Discussion Forum](https://discourse.aicrowd.com/c/neurips-2020-flatland-challenge) -* [Issue Tracker](https://gitlab.aicrowd.com/flatland/flatland/issues/) - -Author ---- - -- **[Sharada Mohanty](https://twitter.com/MeMohanty)** + + +# Flatland Challenge Starter Kit + +**[Follow these instructions to submit your solutions!](http://flatland.aicrowd.com/getting-started/first-submission.html)** + + + + +Main links +--- + +* [Flatland documentation](https://flatland.aicrowd.com/) +* [NeurIPS 2020 Challenge](https://www.aicrowd.com/challenges/neurips-2020-flatland-challenge/) + +Communication +--- + +* [Discord Channel](https://discord.com/invite/hCR3CZG) +* [Discussion Forum](https://discourse.aicrowd.com/c/neurips-2020-flatland-challenge) +* [Issue Tracker](https://gitlab.aicrowd.com/flatland/flatland/issues/) + +Author +--- + +- **[Sharada Mohanty](https://twitter.com/MeMohanty)** diff --git a/aicrowd.json b/aicrowd.json index 68c76af4fd222127604c5e5e3252429f9795fa4c..de611d36b277f16014795935206b0c9250f2bc37 100644 --- a/aicrowd.json +++ b/aicrowd.json @@ -1,7 +1,7 @@ -{ - "challenge_id": "neurips-2020-flatland-challenge", - "grader_id": "neurips-2020-flatland-challenge", - "debug": false, - "tags": ["RL"] -} - +{ + "challenge_id": "neurips-2020-flatland-challenge", + "grader_id": "neurips-2020-flatland-challenge", + "debug": false, + "tags": ["RL"] +} + diff --git a/apt.txt b/apt.txt index d02172562addd7910a88661afa1bf72348d21d02..9c46ae885a5f3bcf73cf5689c501d39f76550bdc 100644 --- a/apt.txt +++ b/apt.txt @@ -1,6 +1,6 @@ -curl -git -vim -ssh -gcc -python-cairo-dev +curl +git +vim +ssh +gcc +python-cairo-dev diff --git a/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.meta b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..4f703391c1f977f6a63e4d8320ad8cdfb10e9a97 Binary files /dev/null and b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.meta differ diff --git a/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.optimizer b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..d92c6b5ad82ef53eca4332ccc23056a80a3699a2 Binary files /dev/null and b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.policy b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.policy new file mode 100644 index 
0000000000000000000000000000000000000000..1db257665e9b57e80c24c1e1bcc165fbcdb80d71 Binary files /dev/null and b/checkpoints/No_col_20/best_0.6672/ppo/model_checkpoint.policy differ diff --git a/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.meta b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..84ffd934079e5114f881aad914112c35e5b0f777 Binary files /dev/null and b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.meta differ diff --git a/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.optimizer b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..63b4fc065ee61195c655c5aef7b41b6e354bf75d Binary files /dev/null and b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.policy b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.policy new file mode 100644 index 0000000000000000000000000000000000000000..7430a2d9b643bf30f9522c8ef3390bac8009f4c1 Binary files /dev/null and b/checkpoints/No_col_20/best_0.6719/ppo/model_checkpoint.policy differ diff --git a/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.meta b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..8f5843e3dc95213ff24db6375e9a4ba65dfb29ef Binary files /dev/null and b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.meta differ diff --git a/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.optimizer b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..7bc95369b9607e0cd3528f8be5db73eac34c0e17 Binary files /dev/null and b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.policy b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.policy new file mode 100644 index 0000000000000000000000000000000000000000..339210ff779ad9a89efbf3620151939facd12c77 Binary files /dev/null and b/checkpoints/No_col_20/best_0.7089/ppo/model_checkpoint.policy differ diff --git a/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.meta b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..7617876cf3d7031f066a779fde687404b0a1cc6f Binary files /dev/null and b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.meta differ diff --git a/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.optimizer b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..b93d28155360433cbb1574b2e797bf1e293c2f6c Binary files /dev/null and b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.policy b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.policy new file mode 100644 index 0000000000000000000000000000000000000000..bc21bc40897b530a65966b1cbbbaeb41835f7b69 Binary files /dev/null and b/checkpoints/No_col_20/best_0.7526/ppo/model_checkpoint.policy differ diff --git a/checkpoints/No_col_20/model_checkpoint.meta b/checkpoints/No_col_20/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..fb226078f5ccd715028992499f09f6edfcc4857e Binary files /dev/null and 
b/checkpoints/No_col_20/model_checkpoint.meta differ diff --git a/checkpoints/No_col_20/model_checkpoint.optimizer b/checkpoints/No_col_20/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..c725a9ccd7791016f89cf8426746dd549dc23410 Binary files /dev/null and b/checkpoints/No_col_20/model_checkpoint.optimizer differ diff --git a/checkpoints/No_col_20/model_checkpoint.policy b/checkpoints/No_col_20/model_checkpoint.policy new file mode 100644 index 0000000000000000000000000000000000000000..91e0a1abe40721649a98ccc76271f34295f7932c Binary files /dev/null and b/checkpoints/No_col_20/model_checkpoint.policy differ diff --git a/checkpoints/best_0.4757/ppo/model_checkpoint.meta b/checkpoints/best_0.4757/ppo/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..8c645aa2e8794ff4edf9974b297ccca3eed5b013 Binary files /dev/null and b/checkpoints/best_0.4757/ppo/model_checkpoint.meta differ diff --git a/checkpoints/best_0.4757/ppo/model_checkpoint.optimizer b/checkpoints/best_0.4757/ppo/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..bd1ca4d0a9df1ccd8eff82001b542b0edadc3796 Binary files /dev/null and b/checkpoints/best_0.4757/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/best_0.4757/ppo/model_checkpoint.policy b/checkpoints/best_0.4757/ppo/model_checkpoint.policy new file mode 100644 index 0000000000000000000000000000000000000000..d0bcce75eb2b058ef3a63abdfb48d1d78e5f1ef9 Binary files /dev/null and b/checkpoints/best_0.4757/ppo/model_checkpoint.policy differ diff --git a/checkpoints/best_0.4893/ppo/model_checkpoint.meta b/checkpoints/best_0.4893/ppo/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..17b21c3612387374893f2a5ce283dc671e0bb66c Binary files /dev/null and b/checkpoints/best_0.4893/ppo/model_checkpoint.meta differ diff --git a/checkpoints/best_0.4893/ppo/model_checkpoint.optimizer b/checkpoints/best_0.4893/ppo/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..bf6ef14481c4832e7fd8289f775cd1da08d27d25 Binary files /dev/null and b/checkpoints/best_0.4893/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/best_0.4893/ppo/model_checkpoint.policy b/checkpoints/best_0.4893/ppo/model_checkpoint.policy new file mode 100644 index 0000000000000000000000000000000000000000..91722141aa59d4376ad3a7695a08d016908af6ab Binary files /dev/null and b/checkpoints/best_0.4893/ppo/model_checkpoint.policy differ diff --git a/checkpoints/best_0.5003/ppo/model_checkpoint.meta b/checkpoints/best_0.5003/ppo/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..0cbaf636a4b6be40fe1bca6e522239cc733171b4 Binary files /dev/null and b/checkpoints/best_0.5003/ppo/model_checkpoint.meta differ diff --git a/checkpoints/best_0.5003/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5003/ppo/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..87a0386cbdbedd9cdaa4d8de81e7131caa71b815 Binary files /dev/null and b/checkpoints/best_0.5003/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/best_0.5003/ppo/model_checkpoint.policy b/checkpoints/best_0.5003/ppo/model_checkpoint.policy new file mode 100644 index 0000000000000000000000000000000000000000..dfde8badb9c05dc8599bcccf81f27580b4e0ff08 Binary files /dev/null and b/checkpoints/best_0.5003/ppo/model_checkpoint.policy differ diff --git 
a/checkpoints/best_0.5109/ppo/model_checkpoint.meta b/checkpoints/best_0.5109/ppo/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..4079d5979b350916b55315bc305fdaf842b7be27 Binary files /dev/null and b/checkpoints/best_0.5109/ppo/model_checkpoint.meta differ diff --git a/checkpoints/best_0.5109/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5109/ppo/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..a7f7fb013513f9a3dfe716acd7ef49c2fa0d57d9 Binary files /dev/null and b/checkpoints/best_0.5109/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/best_0.5109/ppo/model_checkpoint.policy b/checkpoints/best_0.5109/ppo/model_checkpoint.policy new file mode 100644 index 0000000000000000000000000000000000000000..a613c4b42171de1169cbfa42d68a932457399cba Binary files /dev/null and b/checkpoints/best_0.5109/ppo/model_checkpoint.policy differ diff --git a/checkpoints/best_0.5172/ppo/model_checkpoint.meta b/checkpoints/best_0.5172/ppo/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..40125f2717372d6f7b2d9d3c86cb11c592ba99b5 Binary files /dev/null and b/checkpoints/best_0.5172/ppo/model_checkpoint.meta differ diff --git a/checkpoints/best_0.5172/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5172/ppo/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..7feedff1803645d8b131523fc21f990f96bc681a Binary files /dev/null and b/checkpoints/best_0.5172/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/best_0.5172/ppo/model_checkpoint.policy b/checkpoints/best_0.5172/ppo/model_checkpoint.policy new file mode 100644 index 0000000000000000000000000000000000000000..fe9741cdd7a7abd2d171c84aa95b1f1d0a17841c Binary files /dev/null and b/checkpoints/best_0.5172/ppo/model_checkpoint.policy differ diff --git a/checkpoints/best_0.5355/ppo/model_checkpoint.meta b/checkpoints/best_0.5355/ppo/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..b45d9b8937c572df6febfa2f0ac5a9d4cda4eb0e Binary files /dev/null and b/checkpoints/best_0.5355/ppo/model_checkpoint.meta differ diff --git a/checkpoints/best_0.5355/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5355/ppo/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..915d64ab8e782a47c0de68a8ccdb81842a074c85 Binary files /dev/null and b/checkpoints/best_0.5355/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/best_0.5355/ppo/model_checkpoint.policy b/checkpoints/best_0.5355/ppo/model_checkpoint.policy new file mode 100644 index 0000000000000000000000000000000000000000..3300d40d010b85f3e4395c9aed630f6beea486af Binary files /dev/null and b/checkpoints/best_0.5355/ppo/model_checkpoint.policy differ diff --git a/checkpoints/best_0.5435/ppo/model_checkpoint.meta b/checkpoints/best_0.5435/ppo/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..56a27a0763598ba9748c4b337fcb59e95ccdf612 Binary files /dev/null and b/checkpoints/best_0.5435/ppo/model_checkpoint.meta differ diff --git a/checkpoints/best_0.5435/ppo/model_checkpoint.optimizer b/checkpoints/best_0.5435/ppo/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..1cec17631653a3834677441f541024d582d406b4 Binary files /dev/null and b/checkpoints/best_0.5435/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/best_0.5435/ppo/model_checkpoint.policy 
b/checkpoints/best_0.5435/ppo/model_checkpoint.policy new file mode 100644 index 0000000000000000000000000000000000000000..8a3f4e113a4fd5ca7fe4ef27b5dd81d682939df7 Binary files /dev/null and b/checkpoints/best_0.5435/ppo/model_checkpoint.policy differ diff --git a/checkpoints/best_0.8620/ppo/model_checkpoint.meta b/checkpoints/best_0.8620/ppo/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..8f5843e3dc95213ff24db6375e9a4ba65dfb29ef Binary files /dev/null and b/checkpoints/best_0.8620/ppo/model_checkpoint.meta differ diff --git a/checkpoints/best_0.8620/ppo/model_checkpoint.optimizer b/checkpoints/best_0.8620/ppo/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..13b15ba48dde84af9e70f18cb0e4395737351f00 Binary files /dev/null and b/checkpoints/best_0.8620/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/best_0.8620/ppo/model_checkpoint.policy b/checkpoints/best_0.8620/ppo/model_checkpoint.policy new file mode 100644 index 0000000000000000000000000000000000000000..7049ae616c44eb2c20edc1eb6aaf26f16e839851 Binary files /dev/null and b/checkpoints/best_0.8620/ppo/model_checkpoint.policy differ diff --git a/checkpoints/dqn/README.md b/checkpoints/dqn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7877436bc8fc766ce1409c3810a48b13da31ac39 --- /dev/null +++ b/checkpoints/dqn/README.md @@ -0,0 +1 @@ +DQN checkpoints will be saved here diff --git a/checkpoints/dqn/model_checkpoint.local b/checkpoints/dqn/model_checkpoint.local new file mode 100644 index 0000000000000000000000000000000000000000..cca8687b9d333d2a050d8f6910960ef80f0680b7 Binary files /dev/null and b/checkpoints/dqn/model_checkpoint.local differ diff --git a/checkpoints/dqn/model_checkpoint.meta b/checkpoints/dqn/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..502812ecee5de1aa38caa05cc9766f7d2f04ba7e Binary files /dev/null and b/checkpoints/dqn/model_checkpoint.meta differ diff --git a/checkpoints/dqn/model_checkpoint.optimizer b/checkpoints/dqn/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..5badbab158a8d003b069c5a2bb0b8628f25dd89b Binary files /dev/null and b/checkpoints/dqn/model_checkpoint.optimizer differ diff --git a/checkpoints/dqn/model_checkpoint.target b/checkpoints/dqn/model_checkpoint.target new file mode 100644 index 0000000000000000000000000000000000000000..6a853ac88d2554eef6d0b6a414d87d9993c22998 Binary files /dev/null and b/checkpoints/dqn/model_checkpoint.target differ diff --git a/checkpoints/ppo/README.md b/checkpoints/ppo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fce8cd3428ab5c13ea082dd922054080beae4822 --- /dev/null +++ b/checkpoints/ppo/README.md @@ -0,0 +1 @@ +PPO checkpoints will be saved here diff --git a/checkpoints/ppo/model_checkpoint.meta b/checkpoints/ppo/model_checkpoint.meta new file mode 100644 index 0000000000000000000000000000000000000000..7617876cf3d7031f066a779fde687404b0a1cc6f Binary files /dev/null and b/checkpoints/ppo/model_checkpoint.meta differ diff --git a/checkpoints/ppo/model_checkpoint.optimizer b/checkpoints/ppo/model_checkpoint.optimizer new file mode 100644 index 0000000000000000000000000000000000000000..b93d28155360433cbb1574b2e797bf1e293c2f6c Binary files /dev/null and b/checkpoints/ppo/model_checkpoint.optimizer differ diff --git a/checkpoints/ppo/model_checkpoint.policy b/checkpoints/ppo/model_checkpoint.policy new file mode 
100644 index 0000000000000000000000000000000000000000..bc21bc40897b530a65966b1cbbbaeb41835f7b69 Binary files /dev/null and b/checkpoints/ppo/model_checkpoint.policy differ diff --git a/docker_run.sh b/docker_run.sh index eeec29823b7603c361b65c0752f62a3b328d31c9..f14996e5e254c6d266e2f4d0bb47033b8547aaec 100755 --- a/docker_run.sh +++ b/docker_run.sh @@ -1,18 +1,18 @@ -#!/bin/bash - - -if [ -e environ_secret.sh ] -then - echo "Note: Gathering environment variables from environ_secret.sh" - source environ_secret.sh -else - echo "Note: Gathering environment variables from environ.sh" - source environ.sh -fi - -# Expected Env variables : in environ.sh -sudo docker run \ - --net=host \ - -v ./scratch/test-envs:/flatland_envs:z \ - -it ${IMAGE_NAME}:${IMAGE_TAG} \ - /home/aicrowd/run.sh +#!/bin/bash + + +if [ -e environ_secret.sh ] +then + echo "Note: Gathering environment variables from environ_secret.sh" + source environ_secret.sh +else + echo "Note: Gathering environment variables from environ.sh" + source environ.sh +fi + +# Expected Env variables : in environ.sh +sudo docker run \ + --net=host \ + -v ./scratch/test-envs:/flatland_envs:z \ + -it ${IMAGE_NAME}:${IMAGE_TAG} \ + /home/aicrowd/run.sh diff --git a/environment.yml b/environment.yml index a4109dcac710713829a6589eb06685d464984b7b..e79148acd6e4f97031ea83a2fc01f5f941a8f1e0 100644 --- a/environment.yml +++ b/environment.yml @@ -1,112 +1,112 @@ -name: flatland-rl -channels: - - anaconda - - conda-forge - - defaults -dependencies: - - tk=8.6.8 - - cairo=1.16.0 - - cairocffi=1.1.0 - - cairosvg=2.4.2 - - cffi=1.12.3 - - cssselect2=0.2.1 - - defusedxml=0.6.0 - - fontconfig=2.13.1 - - freetype=2.10.0 - - gettext=0.19.8.1 - - glib=2.58.3 - - icu=64.2 - - jpeg=9c - - libiconv=1.15 - - libpng=1.6.37 - - libtiff=4.0.10 - - libuuid=2.32.1 - - libxcb=1.13 - - libxml2=2.9.9 - - lz4-c=1.8.3 - - olefile=0.46 - - pcre=8.41 - - pillow=5.3.0 - - pixman=0.38.0 - - pthread-stubs=0.4 - - pycairo=1.18.1 - - pycparser=2.19 - - tinycss2=1.0.2 - - webencodings=0.5.1 - - xorg-kbproto=1.0.7 - - xorg-libice=1.0.10 - - xorg-libsm=1.2.3 - - xorg-libx11=1.6.8 - - xorg-libxau=1.0.9 - - xorg-libxdmcp=1.1.3 - - xorg-libxext=1.3.4 - - xorg-libxrender=0.9.10 - - xorg-renderproto=0.11.1 - - xorg-xextproto=7.3.0 - - xorg-xproto=7.0.31 - - zstd=1.4.0 - - _libgcc_mutex=0.1 - - ca-certificates=2019.5.15 - - certifi=2019.6.16 - - libedit=3.1.20181209 - - libffi=3.2.1 - - ncurses=6.1 - - openssl=1.1.1c - - pip=19.1.1 - - python=3.6.8 - - readline=7.0 - - setuptools=41.0.1 - - sqlite=3.29.0 - - wheel=0.33.4 - - xz=5.2.4 - - zlib=1.2.11 - - pip: - - atomicwrites==1.3.0 - - importlib-metadata==0.19 - - importlib-resources==1.0.2 - - attrs==19.1.0 - - chardet==3.0.4 - - click==7.0 - - cloudpickle==1.2.2 - - crowdai-api==0.1.21 - - cycler==0.10.0 - - filelock==3.0.12 - - flatland-rl==2.2.1 - - future==0.17.1 - - gym==0.14.0 - - idna==2.8 - - kiwisolver==1.1.0 - - lxml==4.4.0 - - matplotlib==3.1.1 - - more-itertools==7.2.0 - - msgpack==0.6.1 - - msgpack-numpy==0.4.4.3 - - numpy==1.17.0 - - packaging==19.0 - - pandas==0.25.0 - - pluggy==0.12.0 - - py==1.8.0 - - pyarrow==0.14.1 - - pyglet==1.3.2 - - pyparsing==2.4.1.1 - - pytest==5.0.1 - - pytest-runner==5.1 - - python-dateutil==2.8.0 - - python-gitlab==1.10.0 - - pytz==2019.1 - - torch==1.5.0 - - recordtype==1.3 - - redis==3.3.2 - - requests==2.22.0 - - scipy==1.3.1 - - six==1.12.0 - - svgutils==0.3.1 - - timeout-decorator==0.4.1 - - toml==0.10.0 - - tox==3.13.2 - - urllib3==1.25.3 - - ushlex==0.99.1 - - virtualenv==16.7.2 - - 
wcwidth==0.1.7 - - xarray==0.12.3 - - zipp==0.5.2 +name: flatland-rl +channels: + - anaconda + - conda-forge + - defaults +dependencies: + - tk=8.6.8 + - cairo=1.16.0 + - cairocffi=1.1.0 + - cairosvg=2.4.2 + - cffi=1.12.3 + - cssselect2=0.2.1 + - defusedxml=0.6.0 + - fontconfig=2.13.1 + - freetype=2.10.0 + - gettext=0.19.8.1 + - glib=2.58.3 + - icu=64.2 + - jpeg=9c + - libiconv=1.15 + - libpng=1.6.37 + - libtiff=4.0.10 + - libuuid=2.32.1 + - libxcb=1.13 + - libxml2=2.9.9 + - lz4-c=1.8.3 + - olefile=0.46 + - pcre=8.41 + - pillow=5.3.0 + - pixman=0.38.0 + - pthread-stubs=0.4 + - pycairo=1.18.1 + - pycparser=2.19 + - tinycss2=1.0.2 + - webencodings=0.5.1 + - xorg-kbproto=1.0.7 + - xorg-libice=1.0.10 + - xorg-libsm=1.2.3 + - xorg-libx11=1.6.8 + - xorg-libxau=1.0.9 + - xorg-libxdmcp=1.1.3 + - xorg-libxext=1.3.4 + - xorg-libxrender=0.9.10 + - xorg-renderproto=0.11.1 + - xorg-xextproto=7.3.0 + - xorg-xproto=7.0.31 + - zstd=1.4.0 + - _libgcc_mutex=0.1 + - ca-certificates=2019.5.15 + - certifi=2019.6.16 + - libedit=3.1.20181209 + - libffi=3.2.1 + - ncurses=6.1 + - openssl=1.1.1c + - pip=19.1.1 + - python=3.6.8 + - readline=7.0 + - setuptools=41.0.1 + - sqlite=3.29.0 + - wheel=0.33.4 + - xz=5.2.4 + - zlib=1.2.11 + - pip: + - atomicwrites==1.3.0 + - importlib-metadata==0.19 + - importlib-resources==1.0.2 + - attrs==19.1.0 + - chardet==3.0.4 + - click==7.0 + - cloudpickle==1.2.2 + - crowdai-api==0.1.21 + - cycler==0.10.0 + - filelock==3.0.12 + - flatland-rl==2.2.1 + - future==0.17.1 + - gym==0.14.0 + - idna==2.8 + - kiwisolver==1.1.0 + - lxml==4.4.0 + - matplotlib==3.1.1 + - more-itertools==7.2.0 + - msgpack==0.6.1 + - msgpack-numpy==0.4.4.3 + - numpy==1.17.0 + - packaging==19.0 + - pandas==0.25.0 + - pluggy==0.12.0 + - py==1.8.0 + - pyarrow==0.14.1 + - pyglet==1.3.2 + - pyparsing==2.4.1.1 + - pytest==5.0.1 + - pytest-runner==5.1 + - python-dateutil==2.8.0 + - python-gitlab==1.10.0 + - pytz==2019.1 + - torch==1.5.0 + - recordtype==1.3 + - redis==3.3.2 + - requests==2.22.0 + - scipy==1.3.1 + - six==1.12.0 + - svgutils==0.3.1 + - timeout-decorator==0.4.1 + - toml==0.10.0 + - tox==3.13.2 + - urllib3==1.25.3 + - ushlex==0.99.1 + - virtualenv==16.7.2 + - wcwidth==0.1.7 + - xarray==0.12.3 + - zipp==0.5.2 diff --git a/my_observation_builder.py b/my_observation_builder.py index 915ff839cb6cad922fe6ca7513465a4a9edde705..482eecfc36c53e9390b1600b7d0ab9a02f032d9a 100644 --- a/my_observation_builder.py +++ b/my_observation_builder.py @@ -1,101 +1,101 @@ -#!/usr/bin/env python - -import collections -from typing import Optional, List, Dict, Tuple - -import numpy as np - -from flatland.core.env import Environment -from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.env_prediction_builder import PredictionBuilder -from flatland.envs.agent_utils import RailAgentStatus, EnvAgent - - -class CustomObservationBuilder(ObservationBuilder): - """ - Template for building a custom observation builder for the RailEnv class - - The observation in this case composed of the following elements: - - - transition map array with dimensions (env.height, env.width),\ - where the value at X,Y will represent the 16 bits encoding of transition-map at that point. 
- - - the individual agent object (with position, direction, target information available) - - """ - def __init__(self): - super(CustomObservationBuilder, self).__init__() - - def set_env(self, env: Environment): - super().set_env(env) - # Note : - # The instantiations which depend on parameters of the Env object should be - # done here, as it is only here that the updated self.env instance is available - self.rail_obs = np.zeros((self.env.height, self.env.width)) - - def reset(self): - """ - Called internally on every env.reset() call, - to reset any observation specific variables that are being used - """ - self.rail_obs[:] = 0 - for _x in range(self.env.width): - for _y in range(self.env.height): - # Get the transition map value at location _x, _y - transition_value = self.env.rail.get_full_transitions(_y, _x) - self.rail_obs[_y, _x] = transition_value - - def get(self, handle: int = 0): - """ - Returns the built observation for a single agent with handle : handle - - In this particular case, we return - - the global transition_map of the RailEnv, - - a tuple containing, the current agent's: - - state - - position - - direction - - initial_position - - target - """ - - agent = self.env.agents[handle] - """ - Available information for each agent object : - - - agent.status : [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.ACTIVE, RailAgentStatus.DONE] - - agent.position : Current position of the agent - - agent.direction : Current direction of the agent - - agent.initial_position : Initial Position of the agent - - agent.target : Target position of the agent - """ - - status = agent.status - position = agent.position - direction = agent.direction - initial_position = agent.initial_position - target = agent.target - - - """ - You can also optionally access the states of the rest of the agents by - using something similar to - - for i in range(len(self.env.agents)): - other_agent: EnvAgent = self.env.agents[i] - - # ignore other agents not in the grid any more - if other_agent.status == RailAgentStatus.DONE_REMOVED: - continue - - ## Gather other agent specific params - other_agent_status = other_agent.status - other_agent_position = other_agent.position - other_agent_direction = other_agent.direction - other_agent_initial_position = other_agent.initial_position - other_agent_target = other_agent.target - - ## Do something nice here if you wish - """ - return self.rail_obs, (status, position, direction, initial_position, target) - +#!/usr/bin/env python + +import collections +from typing import Optional, List, Dict, Tuple + +import numpy as np + +from flatland.core.env import Environment +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.env_prediction_builder import PredictionBuilder +from flatland.envs.agent_utils import RailAgentStatus, EnvAgent + + +class CustomObservationBuilder(ObservationBuilder): + """ + Template for building a custom observation builder for the RailEnv class + + The observation in this case composed of the following elements: + + - transition map array with dimensions (env.height, env.width),\ + where the value at X,Y will represent the 16 bits encoding of transition-map at that point. 
+ + - the individual agent object (with position, direction, target information available) + + """ + def __init__(self): + super(CustomObservationBuilder, self).__init__() + + def set_env(self, env: Environment): + super().set_env(env) + # Note : + # The instantiations which depend on parameters of the Env object should be + # done here, as it is only here that the updated self.env instance is available + self.rail_obs = np.zeros((self.env.height, self.env.width)) + + def reset(self): + """ + Called internally on every env.reset() call, + to reset any observation specific variables that are being used + """ + self.rail_obs[:] = 0 + for _x in range(self.env.width): + for _y in range(self.env.height): + # Get the transition map value at location _x, _y + transition_value = self.env.rail.get_full_transitions(_y, _x) + self.rail_obs[_y, _x] = transition_value + + def get(self, handle: int = 0): + """ + Returns the built observation for a single agent with handle : handle + + In this particular case, we return + - the global transition_map of the RailEnv, + - a tuple containing, the current agent's: + - state + - position + - direction + - initial_position + - target + """ + + agent = self.env.agents[handle] + """ + Available information for each agent object : + + - agent.status : [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.ACTIVE, RailAgentStatus.DONE] + - agent.position : Current position of the agent + - agent.direction : Current direction of the agent + - agent.initial_position : Initial Position of the agent + - agent.target : Target position of the agent + """ + + status = agent.status + position = agent.position + direction = agent.direction + initial_position = agent.initial_position + target = agent.target + + + """ + You can also optionally access the states of the rest of the agents by + using something similar to + + for i in range(len(self.env.agents)): + other_agent: EnvAgent = self.env.agents[i] + + # ignore other agents not in the grid any more + if other_agent.status == RailAgentStatus.DONE_REMOVED: + continue + + ## Gather other agent specific params + other_agent_status = other_agent.status + other_agent_position = other_agent.position + other_agent_direction = other_agent.direction + other_agent_initial_position = other_agent.initial_position + other_agent_target = other_agent.target + + ## Do something nice here if you wish + """ + return self.rail_obs, (status, position, direction, initial_position, target) + diff --git a/run.py b/run.py index 855dac19dcfdd6d4a971a8e74ee36c57591b45fe..27c3107896fde363315d78eadd778ffe2ac99f7e 100644 --- a/run.py +++ b/run.py @@ -1,174 +1,175 @@ -import time - -import numpy as np -from flatland.envs.agent_utils import RailAgentStatus -from flatland.evaluators.client import FlatlandRemoteClient - -##################################################################### -# Instantiate a Remote Client -##################################################################### -from src.extra import Extra -from src.observations import MyTreeObsForRailEnv - -remote_client = FlatlandRemoteClient() - - -##################################################################### -# Define your custom controller -# -# which can take an observation, and the number of agents and -# compute the necessary action for this step for all (or even some) -# of the agents -##################################################################### -def my_controller(extra: Extra, observation, my_observation_builder): - return extra.rl_agent_act(observation, 
my_observation_builder.max_depth) - - -##################################################################### -# Instantiate your custom Observation Builder -# -# You can build your own Observation Builder by following -# the example here : -# https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14 -##################################################################### -my_observation_builder = MyTreeObsForRailEnv(max_depth=3) - -# Or if you want to use your own approach to build the observation from the env_step, -# please feel free to pass a DummyObservationBuilder() object as mentioned below, -# and that will just return a placeholder True for all observation, and you -# can build your own Observation for all the agents as your please. -# my_observation_builder = DummyObservationBuilder() - - -##################################################################### -# Main evaluation loop -# -# This iterates over an arbitrary number of env evaluations -##################################################################### -evaluation_number = 0 -while True: - - evaluation_number += 1 - # Switch to a new evaluation environemnt - # - # a remote_client.env_create is similar to instantiating a - # RailEnv and then doing a env.reset() - # hence it returns the first observation from the - # env.reset() - # - # You can also pass your custom observation_builder object - # to allow you to have as much control as you wish - # over the observation of your choice. - time_start = time.time() - observation, info = remote_client.env_create( - obs_builder_object=my_observation_builder - ) - if not observation: - # - # If the remote_client returns False on a `env_create` call, - # then it basically means that your agent has already been - # evaluated on all the required evaluation environments, - # and hence its safe to break out of the main evaluation loop - break - - print("Evaluation Number : {}".format(evaluation_number)) - - ##################################################################### - # Access to a local copy of the environment - # - ##################################################################### - # Note: You can access a local copy of the environment - # by using : - # remote_client.env - # - # But please ensure to not make any changes (or perform any action) on - # the local copy of the env, as then it will diverge from - # the state of the remote copy of the env, and the observations and - # rewards, etc will behave unexpectedly - # - # You can however probe the local_env instance to get any information - # you need from the environment. It is a valid RailEnv instance. 
- local_env = remote_client.env - number_of_agents = len(local_env.agents) - - # Now we enter into another infinite loop where we - # compute the actions for all the individual steps in this episode - # until the episode is `done` - # - # An episode is considered done when either all the agents have - # reached their target destination - # or when the number of time steps has exceed max_time_steps, which - # is defined by : - # - # max_time_steps = int(4 * 2 * (env.width + env.height + 20)) - # - time_taken_by_controller = [] - time_taken_per_step = [] - steps = 0 - - extra = Extra(local_env) - env_creation_time = time.time() - time_start - print("Env Creation Time : ", env_creation_time) - print("Agents : ", extra.env.get_num_agents()) - print("w : ", extra.env.width) - print("h : ", extra.env.height) - - while True: - ##################################################################### - # Evaluation of a single episode - # - ##################################################################### - # Compute the action for this step by using the previously - # defined controller - time_start = time.time() - action = my_controller(extra, observation, my_observation_builder) - time_taken = time.time() - time_start - time_taken_by_controller.append(time_taken) - - # Perform the chosen action on the environment. - # The action gets applied to both the local and the remote copy - # of the environment instance, and the observation is what is - # returned by the local copy of the env, and the rewards, and done and info - # are returned by the remote copy of the env - time_start = time.time() - observation, all_rewards, done, info = remote_client.env_step(action) - steps += 1 - time_taken = time.time() - time_start - time_taken_per_step.append(time_taken) - - total_done = 0 - for a in range(local_env.get_num_agents()): - x = (local_env.agents[a].status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]) - total_done += int(x) - print("total_done:", total_done) - - if done['__all__']: - print("Reward : ", sum(list(all_rewards.values()))) - # - # When done['__all__'] == True, then the evaluation of this - # particular Env instantiation is complete, and we can break out - # of this loop, and move onto the next Env evaluation - break - - np_time_taken_by_controller = np.array(time_taken_by_controller) - np_time_taken_per_step = np.array(time_taken_per_step) - print("=" * 100) - print("=" * 100) - print("Evaluation Number : ", evaluation_number) - print("Current Env Path : ", remote_client.current_env_path) - print("Env Creation Time : ", env_creation_time) - print("Number of Steps : ", steps) - print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), - np_time_taken_by_controller.std()) - print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std()) - print("=" * 100) - -print("Evaluation of all environments complete...") -######################################################################## -# Submit your Results -# -# Please do not forget to include this call, as this triggers the -# final computation of the score statistics, video generation, etc -# and is necesaary to have your submission marked as successfully evaluated -######################################################################## -print(remote_client.submit()) +import time + +import numpy as np +from flatland.envs.agent_utils import RailAgentStatus +from flatland.evaluators.client import FlatlandRemoteClient + 
+##################################################################### +# Instantiate a Remote Client +##################################################################### +from src.extra import Extra + +remote_client = FlatlandRemoteClient() + + +##################################################################### +# Define your custom controller +# +# which can take an observation, and the number of agents and +# compute the necessary action for this step for all (or even some) +# of the agents +##################################################################### +def my_controller(extra: Extra, observation, info): + return extra.rl_agent_act(observation, info) + + +##################################################################### +# Instantiate your custom Observation Builder +# +# You can build your own Observation Builder by following +# the example here : +# https://gitlab.aicrowd.com/flatland/flatland/blob/master/flatland/envs/observations.py#L14 +##################################################################### +my_observation_builder = Extra(max_depth=2) + +# Or if you want to use your own approach to build the observation from the env_step, +# please feel free to pass a DummyObservationBuilder() object as mentioned below, +# and that will just return a placeholder True for all observation, and you +# can build your own Observation for all the agents as your please. +# my_observation_builder = DummyObservationBuilder() + + +##################################################################### +# Main evaluation loop +# +# This iterates over an arbitrary number of env evaluations +##################################################################### +evaluation_number = 0 +while True: + + evaluation_number += 1 + # Switch to a new evaluation environemnt + # + # a remote_client.env_create is similar to instantiating a + # RailEnv and then doing a env.reset() + # hence it returns the first observation from the + # env.reset() + # + # You can also pass your custom observation_builder object + # to allow you to have as much control as you wish + # over the observation of your choice. + time_start = time.time() + observation, info = remote_client.env_create( + obs_builder_object=my_observation_builder + ) + if not observation: + # + # If the remote_client returns False on a `env_create` call, + # then it basically means that your agent has already been + # evaluated on all the required evaluation environments, + # and hence its safe to break out of the main evaluation loop + break + + print("Evaluation Number : {}".format(evaluation_number)) + + ##################################################################### + # Access to a local copy of the environment + # + ##################################################################### + # Note: You can access a local copy of the environment + # by using : + # remote_client.env + # + # But please ensure to not make any changes (or perform any action) on + # the local copy of the env, as then it will diverge from + # the state of the remote copy of the env, and the observations and + # rewards, etc will behave unexpectedly + # + # You can however probe the local_env instance to get any information + # you need from the environment. It is a valid RailEnv instance. 
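Editor's note, a minimal sketch (not part of the patch) of what the comment block above means by probing, but never mutating, the local copy; it assumes the `remote_client` created earlier in this file, and the name `local_copy` is illustrative only:

    local_copy = remote_client.env                        # a plain RailEnv instance
    print(local_copy.height, local_copy.width)            # static properties are safe to read
    print([agent.target for agent in local_copy.agents])  # so is per-agent metadata
    # Do not call local_copy.step(...) or modify its agents: only remote_client.env_step(...)
    # keeps the local and remote copies of the environment in sync.
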
+ local_env = remote_client.env + number_of_agents = len(local_env.agents) + + # Now we enter into another infinite loop where we + # compute the actions for all the individual steps in this episode + # until the episode is `done` + # + # An episode is considered done when either all the agents have + # reached their target destination + # or when the number of time steps has exceed max_time_steps, which + # is defined by : + # + # max_time_steps = int(4 * 2 * (env.width + env.height + 20)) + # + time_taken_by_controller = [] + time_taken_per_step = [] + steps = 0 + + extra = my_observation_builder + env_creation_time = time.time() - time_start + print("Env Creation Time : ", env_creation_time) + print("Agents : ", extra.env.get_num_agents()) + print("w : ", extra.env.width) + print("h : ", extra.env.height) + + while True: + ##################################################################### + # Evaluation of a single episode + # + ##################################################################### + # Compute the action for this step by using the previously + # defined controller + time_start = time.time() + action = my_controller(extra, observation, info) + time_taken = time.time() - time_start + time_taken_by_controller.append(time_taken) + + # Perform the chosen action on the environment. + # The action gets applied to both the local and the remote copy + # of the environment instance, and the observation is what is + # returned by the local copy of the env, and the rewards, and done and info + # are returned by the remote copy of the env + time_start = time.time() + observation, all_rewards, done, info = remote_client.env_step(action) + steps += 1 + time_taken = time.time() - time_start + time_taken_per_step.append(time_taken) + + total_done = 0 + total_active = 0 + for a in range(local_env.get_num_agents()): + x = (local_env.agents[a].status in [RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED]) + total_done += int(x) + total_active += int(local_env.agents[a].status == RailAgentStatus.ACTIVE) + # print("total_done:", total_done, "\ttotal_active", total_active, "\t num agents", local_env.get_num_agents()) + + if done['__all__']: + print("Reward : ", sum(list(all_rewards.values()))) + # + # When done['__all__'] == True, then the evaluation of this + # particular Env instantiation is complete, and we can break out + # of this loop, and move onto the next Env evaluation + break + + np_time_taken_by_controller = np.array(time_taken_by_controller) + np_time_taken_per_step = np.array(time_taken_per_step) + print("=" * 100) + print("=" * 100) + print("Evaluation Number : ", evaluation_number) + print("Current Env Path : ", remote_client.current_env_path) + print("Env Creation Time : ", env_creation_time) + print("Number of Steps : ", steps) + print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(), + np_time_taken_by_controller.std()) + print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std()) + print("=" * 100) + +print("Evaluation of all environments complete...") +######################################################################## +# Submit your Results +# +# Please do not forget to include this call, as this triggers the +# final computation of the score statistics, video generation, etc +# and is necesaary to have your submission marked as successfully evaluated +######################################################################## +print(remote_client.submit()) diff --git a/run.sh b/run.sh index 
953c1660c6abafcc0a474c526ef7ffcedef6a5d8..c6d89fe35ae407a87904f15b601cc14974e8ef6a 100755 --- a/run.sh +++ b/run.sh @@ -1,3 +1,3 @@ -#!/bin/bash - -python ./run.py +#!/bin/bash + +python ./run.py diff --git a/src/agent/dueling_double_dqn.py b/src/agent/dueling_double_dqn.py index 6d82fded3552e4b9d72c22dc4f85fa1fde574e48..f08e17602db84265079fa5f6f7d8422d5a5d4c89 100644 --- a/src/agent/dueling_double_dqn.py +++ b/src/agent/dueling_double_dqn.py @@ -1,512 +1,512 @@ -import torch -import torch.optim as optim - -BUFFER_SIZE = int(1e5) # replay buffer size -BATCH_SIZE = 512 # minibatch size -GAMMA = 0.99 # discount factor 0.99 -TAU = 0.5e-3 # for soft update of target parameters -LR = 0.5e-4 # learning rate 0.5e-4 works - -# how often to update the network -UPDATE_EVERY = 20 -UPDATE_EVERY_FINAL = 10 -UPDATE_EVERY_AGENT_CANT_CHOOSE = 200 - - -double_dqn = True # If using double dqn algorithm -input_channels = 5 # Number of Input channels - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -device = torch.device("cpu") -print(device) - -USE_OPTIMIZER = optim.Adam -# USE_OPTIMIZER = optim.RMSprop -print(USE_OPTIMIZER) - - -class Agent: - """Interacts with and learns from the environment.""" - - def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5): - """Initialize an Agent object. - - Params - ====== - state_size (int): dimension of each state - action_size (int): dimension of each action - seed (int): random seed - """ - self.state_size = state_size - self.action_size = action_size - self.seed = random.seed(seed) - self.version = net_type - self.double_dqn = double_dqn - # Q-Network - if self.version == "Conv": - self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device) - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) - else: - self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) - - self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR) - - # Replay memory - self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) - self.memory_final = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) - self.memory_agent_can_not_choose = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) - - self.final_step = {} - - # Initialize time step (for updating every UPDATE_EVERY steps) - self.t_step = 0 - self.t_step_final = 0 - self.t_step_agent_can_not_choose = 0 - - def save(self, filename): - torch.save(self.qnetwork_local.state_dict(), filename + ".local") - torch.save(self.qnetwork_target.state_dict(), filename + ".target") - - def load(self, filename): - if os.path.exists(filename + ".local"): - self.qnetwork_local.load_state_dict(torch.load(filename + ".local")) - print(filename + ".local -> ok") - if os.path.exists(filename + ".target"): - self.qnetwork_target.load_state_dict(torch.load(filename + ".target")) - print(filename + ".target -> ok") - self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR) - - def _update_model(self, switch=0): - # Learn every UPDATE_EVERY time steps. 
- # If enough samples are available in memory, get random subset and learn - if switch == 0: - self.t_step = (self.t_step + 1) % UPDATE_EVERY - if self.t_step == 0: - if len(self.memory) > BATCH_SIZE: - experiences = self.memory.sample() - self.learn(experiences, GAMMA) - elif switch == 1: - self.t_step_final = (self.t_step_final + 1) % UPDATE_EVERY_FINAL - if self.t_step_final == 0: - if len(self.memory_final) > BATCH_SIZE: - experiences = self.memory_final.sample() - self.learn(experiences, GAMMA) - else: - # If enough samples are available in memory_agent_can_not_choose, get random subset and learn - self.t_step_agent_can_not_choose = (self.t_step_agent_can_not_choose + 1) % UPDATE_EVERY_AGENT_CANT_CHOOSE - if self.t_step_agent_can_not_choose == 0: - if len(self.memory_agent_can_not_choose) > BATCH_SIZE: - experiences = self.memory_agent_can_not_choose.sample() - self.learn(experiences, GAMMA) - - def step(self, state, action, reward, next_state, done): - # Save experience in replay memory - self.memory.add(state, action, reward, next_state, done) - self._update_model(0) - - def step_agent_can_not_choose(self, state, action, reward, next_state, done): - # Save experience in replay memory_agent_can_not_choose - self.memory_agent_can_not_choose.add(state, action, reward, next_state, done) - self._update_model(2) - - def add_final_step(self, agent_handle, state, action, reward, next_state, done): - if self.final_step.get(agent_handle) is None: - self.final_step.update({agent_handle: [state, action, reward, next_state, done]}) - - def make_final_step(self, additional_reward=0): - for _, item in self.final_step.items(): - state = item[0] - action = item[1] - reward = item[2] + additional_reward - next_state = item[3] - done = item[4] - self.memory_final.add(state, action, reward, next_state, done) - self._update_model(1) - self._reset_final_step() - - def _reset_final_step(self): - self.final_step = {} - - def act(self, state, eps=0.): - """Returns actions for given state as per current policy. - - Params - ====== - state (array_like): current state - eps (float): epsilon, for epsilon-greedy action selection - """ - state = torch.from_numpy(state).float().unsqueeze(0).to(device) - self.qnetwork_local.eval() - with torch.no_grad(): - action_values = self.qnetwork_local(state) - self.qnetwork_local.train() - - # Epsilon-greedy action selection - if random.random() > eps: - return np.argmax(action_values.cpu().data.numpy()) - else: - return random.choice(np.arange(self.action_size)) - - def learn(self, experiences, gamma): - - """Update value parameters using given batch of experience tuples. 
- - Params - ====== - experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples - gamma (float): discount factor - """ - states, actions, rewards, next_states, dones = experiences - - # Get expected Q values from local model - Q_expected = self.qnetwork_local(states).gather(1, actions) - - if self.double_dqn: - # Double DQN - q_best_action = self.qnetwork_local(next_states).max(1)[1] - Q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1)) - else: - # DQN - Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1) - - # Compute Q targets for current states - - Q_targets = rewards + (gamma * Q_targets_next * (1 - dones)) - - # Compute loss - loss = F.mse_loss(Q_expected, Q_targets) - # Minimize the loss - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - # ------------------- update target network ------------------- # - self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU) - - def soft_update(self, local_model, target_model, tau): - """Soft update model parameters. - θ_target = τ*θ_local + (1 - τ)*θ_target - - Params - ====== - local_model (PyTorch model): weights will be copied from - target_model (PyTorch model): weights will be copied to - tau (float): interpolation parameter - """ - for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): - target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data) - - -class ReplayBuffer: - """Fixed-size buffer to store experience tuples.""" - - def __init__(self, action_size, buffer_size, batch_size, seed): - """Initialize a ReplayBuffer object. - - Params - ====== - action_size (int): dimension of each action - buffer_size (int): maximum size of buffer - batch_size (int): size of each training batch - seed (int): random seed - """ - self.action_size = action_size - self.memory = deque(maxlen=buffer_size) - self.batch_size = batch_size - self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) - self.seed = random.seed(seed) - - def add(self, state, action, reward, next_state, done): - """Add a new experience to memory.""" - e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done) - self.memory.append(e) - - def sample(self): - """Randomly sample a batch of experiences from memory.""" - experiences = random.sample(self.memory, k=self.batch_size) - - states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \ - .float().to(device) - actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \ - .long().to(device) - rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \ - .float().to(device) - next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \ - .float().to(device) - dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \ - .float().to(device) - - return (states, actions, rewards, next_states, dones) - - def __len__(self): - """Return the current size of internal memory.""" - return len(self.memory) - - def __v_stack_impr(self, states): - sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1 - np_states = np.reshape(np.array(states), (len(states), sub_dim)) - return np_states - - -import copy -import os -import random -from collections import 
namedtuple, deque, Iterable - -import numpy as np -import torch -import torch.nn.functional as F -import torch.optim as optim - -from src.agent.model import QNetwork2, QNetwork - -BUFFER_SIZE = int(1e5) # replay buffer size -BATCH_SIZE = 512 # minibatch size -GAMMA = 0.95 # discount factor 0.99 -TAU = 0.5e-4 # for soft update of target parameters -LR = 0.5e-3 # learning rate 0.5e-4 works - -# how often to update the network -UPDATE_EVERY = 40 -UPDATE_EVERY_FINAL = 1000 -UPDATE_EVERY_AGENT_CANT_CHOOSE = 200 - -double_dqn = True # If using double dqn algorithm -input_channels = 5 # Number of Input channels - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -device = torch.device("cpu") -print(device) - -USE_OPTIMIZER = optim.Adam -# USE_OPTIMIZER = optim.RMSprop -print(USE_OPTIMIZER) - - -class Agent: - """Interacts with and learns from the environment.""" - - def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5): - """Initialize an Agent object. - - Params - ====== - state_size (int): dimension of each state - action_size (int): dimension of each action - seed (int): random seed - """ - self.state_size = state_size - self.action_size = action_size - self.seed = random.seed(seed) - self.version = net_type - self.double_dqn = double_dqn - # Q-Network - if self.version == "Conv": - self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device) - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) - else: - self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) - self.qnetwork_target = copy.deepcopy(self.qnetwork_local) - - self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR) - - # Replay memory - self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) - self.memory_final = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) - self.memory_agent_can_not_choose = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) - - self.final_step = {} - - # Initialize time step (for updating every UPDATE_EVERY steps) - self.t_step = 0 - self.t_step_final = 0 - self.t_step_agent_can_not_choose = 0 - - def save(self, filename): - torch.save(self.qnetwork_local.state_dict(), filename + ".local") - torch.save(self.qnetwork_target.state_dict(), filename + ".target") - - def load(self, filename): - print("try to load: " + filename) - if os.path.exists(filename + ".local"): - self.qnetwork_local.load_state_dict(torch.load(filename + ".local")) - print(filename + ".local -> ok") - if os.path.exists(filename + ".target"): - self.qnetwork_target.load_state_dict(torch.load(filename + ".target")) - print(filename + ".target -> ok") - self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR) - - def _update_model(self, switch=0): - # Learn every UPDATE_EVERY time steps. 
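
The `_update_model` logic that follows gates learning per replay buffer: a step counter wraps around every `UPDATE_EVERY` (or `UPDATE_EVERY_FINAL` / `UPDATE_EVERY_AGENT_CANT_CHOOSE`) calls, and a learning step only fires on the wrap and only once the buffer holds more than `BATCH_SIZE` transitions. A minimal, self-contained sketch of that gating pattern; `GatedUpdater` and the toy integer buffer are illustrative stand-ins, not part of this repository:

import random
from collections import deque

class GatedUpdater:
    """Fire a training step at most once every `update_every` calls, and only
    once the buffer holds more than `batch_size` transitions (the same gating
    used by _update_model for each of its three replay buffers)."""

    def __init__(self, buffer, train_step, update_every, batch_size):
        self.buffer = buffer
        self.train_step = train_step
        self.update_every = update_every
        self.batch_size = batch_size
        self.t = 0

    def __call__(self):
        self.t = (self.t + 1) % self.update_every
        if self.t == 0 and len(self.buffer) > self.batch_size:
            self.train_step(random.sample(list(self.buffer), k=self.batch_size))

# Toy usage: one gate per buffer, mirroring switch == 0 / 1 / 2.
memory = deque(maxlen=100)
update_main = GatedUpdater(memory, lambda batch: print("learn on", len(batch), "samples"),
                           update_every=40, batch_size=4)
for step in range(200):
    memory.append(step)  # stand-in for an (s, a, r, s', done) tuple
    update_main()
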
- # If enough samples are available in memory, get random subset and learn - if switch == 0: - self.t_step = (self.t_step + 1) % UPDATE_EVERY - if self.t_step == 0: - if len(self.memory) > BATCH_SIZE: - experiences = self.memory.sample() - self.learn(experiences, GAMMA) - elif switch == 1: - self.t_step_final = (self.t_step_final + 1) % UPDATE_EVERY_FINAL - if self.t_step_final == 0: - if len(self.memory_final) > BATCH_SIZE: - experiences = self.memory_final.sample() - self.learn(experiences, GAMMA) - else: - # If enough samples are available in memory_agent_can_not_choose, get random subset and learn - self.t_step_agent_can_not_choose = (self.t_step_agent_can_not_choose + 1) % UPDATE_EVERY_AGENT_CANT_CHOOSE - if self.t_step_agent_can_not_choose == 0: - if len(self.memory_agent_can_not_choose) > BATCH_SIZE: - experiences = self.memory_agent_can_not_choose.sample() - self.learn(experiences, GAMMA) - - def step(self, state, action, reward, next_state, done): - # Save experience in replay memory - self.memory.add(state, action, reward, next_state, done) - self._update_model(0) - - def step_agent_can_not_choose(self, state, action, reward, next_state, done): - # Save experience in replay memory_agent_can_not_choose - self.memory_agent_can_not_choose.add(state, action, reward, next_state, done) - self._update_model(2) - - def add_final_step(self, agent_handle, state, action, reward, next_state, done): - if self.final_step.get(agent_handle) is None: - self.final_step.update({agent_handle: [state, action, reward, next_state, done]}) - return True - else: - return False - - def make_final_step(self, additional_reward=0): - for _, item in self.final_step.items(): - state = item[0] - action = item[1] - reward = item[2] + additional_reward - next_state = item[3] - done = item[4] - self.memory_final.add(state, action, reward, next_state, done) - self._update_model(1) - self._reset_final_step() - - def _reset_final_step(self): - self.final_step = {} - - def act(self, state, eps=0.): - """Returns actions for given state as per current policy. - - Params - ====== - state (array_like): current state - eps (float): epsilon, for epsilon-greedy action selection - """ - state = torch.from_numpy(state).float().unsqueeze(0).to(device) - self.qnetwork_local.eval() - with torch.no_grad(): - action_values = self.qnetwork_local(state) - self.qnetwork_local.train() - - # Epsilon-greedy action selection - if random.random() > eps: - return np.argmax(action_values.cpu().data.numpy()), False - else: - return random.choice(np.arange(self.action_size)), True - - def learn(self, experiences, gamma): - - """Update value parameters using given batch of experience tuples. 
- - Params - ====== - experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples - gamma (float): discount factor - """ - states, actions, rewards, next_states, dones = experiences - - # Get expected Q values from local model - Q_expected = self.qnetwork_local(states).gather(1, actions) - - if self.double_dqn: - # Double DQN - q_best_action = self.qnetwork_local(next_states).max(1)[1] - Q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1)) - else: - # DQN - Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1) - - # Compute Q targets for current states - - Q_targets = rewards + (gamma * Q_targets_next * (1 - dones)) - - # Compute loss - loss = F.mse_loss(Q_expected, Q_targets) - # Minimize the loss - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - # ------------------- update target network ------------------- # - self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU) - - def soft_update(self, local_model, target_model, tau): - """Soft update model parameters. - θ_target = τ*θ_local + (1 - τ)*θ_target - - Params - ====== - local_model (PyTorch model): weights will be copied from - target_model (PyTorch model): weights will be copied to - tau (float): interpolation parameter - """ - for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): - target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data) - - -class ReplayBuffer: - """Fixed-size buffer to store experience tuples.""" - - def __init__(self, action_size, buffer_size, batch_size, seed): - """Initialize a ReplayBuffer object. - - Params - ====== - action_size (int): dimension of each action - buffer_size (int): maximum size of buffer - batch_size (int): size of each training batch - seed (int): random seed - """ - self.action_size = action_size - self.memory = deque(maxlen=buffer_size) - self.batch_size = batch_size - self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) - self.seed = random.seed(seed) - - def add(self, state, action, reward, next_state, done): - """Add a new experience to memory.""" - e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done) - self.memory.append(e) - - def sample(self): - """Randomly sample a batch of experiences from memory.""" - experiences = random.sample(self.memory, k=self.batch_size) - - states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \ - .float().to(device) - actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \ - .long().to(device) - rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \ - .float().to(device) - next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \ - .float().to(device) - dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \ - .float().to(device) - - return (states, actions, rewards, next_states, dones) - - def __len__(self): - """Return the current size of internal memory.""" - return len(self.memory) - - def __v_stack_impr(self, states): - sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1 - np_states = np.reshape(np.array(states), (len(states), sub_dim)) - return np_states +import torch +import torch.optim as optim + +BUFFER_SIZE = int(1e5) # replay 
buffer size +BATCH_SIZE = 512 # minibatch size +GAMMA = 0.99 # discount factor 0.99 +TAU = 0.5e-3 # for soft update of target parameters +LR = 0.5e-4 # learning rate 0.5e-4 works + +# how often to update the network +UPDATE_EVERY = 20 +UPDATE_EVERY_FINAL = 10 +UPDATE_EVERY_AGENT_CANT_CHOOSE = 200 + + +double_dqn = True # If using double dqn algorithm +input_channels = 5 # Number of Input channels + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cpu") +print(device) + +USE_OPTIMIZER = optim.Adam +# USE_OPTIMIZER = optim.RMSprop +print(USE_OPTIMIZER) + + +class Agent: + """Interacts with and learns from the environment.""" + + def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5): + """Initialize an Agent object. + + Params + ====== + state_size (int): dimension of each state + action_size (int): dimension of each action + seed (int): random seed + """ + self.state_size = state_size + self.action_size = action_size + self.seed = random.seed(seed) + self.version = net_type + self.double_dqn = double_dqn + # Q-Network + if self.version == "Conv": + self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + else: + self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + + self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR) + + # Replay memory + self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + self.memory_final = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + self.memory_agent_can_not_choose = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + + self.final_step = {} + + # Initialize time step (for updating every UPDATE_EVERY steps) + self.t_step = 0 + self.t_step_final = 0 + self.t_step_agent_can_not_choose = 0 + + def save(self, filename): + torch.save(self.qnetwork_local.state_dict(), filename + ".local") + torch.save(self.qnetwork_target.state_dict(), filename + ".target") + + def load(self, filename): + if os.path.exists(filename + ".local"): + self.qnetwork_local.load_state_dict(torch.load(filename + ".local")) + print(filename + ".local -> ok") + if os.path.exists(filename + ".target"): + self.qnetwork_target.load_state_dict(torch.load(filename + ".target")) + print(filename + ".target -> ok") + self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR) + + def _update_model(self, switch=0): + # Learn every UPDATE_EVERY time steps. 
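
With `double_dqn` enabled, the `learn()` method further below picks the greedy next action with the local network and evaluates it with the target network, rather than taking the target network's own maximum as plain DQN does. A minimal sketch of just that target computation on random tensors; the variable names and shapes are illustrative only:

import torch

batch_size, n_actions, gamma = 4, 5, 0.99
q_local_next = torch.randn(batch_size, n_actions)   # local net Q(s', .)
q_target_next = torch.randn(batch_size, n_actions)  # target net Q(s', .)
rewards = torch.randn(batch_size, 1)
dones = torch.zeros(batch_size, 1)

# Plain DQN: take the maximum straight from the target network.
dqn_next = q_target_next.max(1)[0].unsqueeze(-1)

# Double DQN: pick the argmax action with the local network,
# then evaluate that action with the target network.
best_action = q_local_next.max(1)[1]
ddqn_next = q_target_next.gather(1, best_action.unsqueeze(-1))

q_targets = rewards + gamma * ddqn_next * (1 - dones)
print(dqn_next.shape, ddqn_next.shape, q_targets.shape)  # all torch.Size([4, 1])
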
+ # If enough samples are available in memory, get random subset and learn + if switch == 0: + self.t_step = (self.t_step + 1) % UPDATE_EVERY + if self.t_step == 0: + if len(self.memory) > BATCH_SIZE: + experiences = self.memory.sample() + self.learn(experiences, GAMMA) + elif switch == 1: + self.t_step_final = (self.t_step_final + 1) % UPDATE_EVERY_FINAL + if self.t_step_final == 0: + if len(self.memory_final) > BATCH_SIZE: + experiences = self.memory_final.sample() + self.learn(experiences, GAMMA) + else: + # If enough samples are available in memory_agent_can_not_choose, get random subset and learn + self.t_step_agent_can_not_choose = (self.t_step_agent_can_not_choose + 1) % UPDATE_EVERY_AGENT_CANT_CHOOSE + if self.t_step_agent_can_not_choose == 0: + if len(self.memory_agent_can_not_choose) > BATCH_SIZE: + experiences = self.memory_agent_can_not_choose.sample() + self.learn(experiences, GAMMA) + + def step(self, state, action, reward, next_state, done): + # Save experience in replay memory + self.memory.add(state, action, reward, next_state, done) + self._update_model(0) + + def step_agent_can_not_choose(self, state, action, reward, next_state, done): + # Save experience in replay memory_agent_can_not_choose + self.memory_agent_can_not_choose.add(state, action, reward, next_state, done) + self._update_model(2) + + def add_final_step(self, agent_handle, state, action, reward, next_state, done): + if self.final_step.get(agent_handle) is None: + self.final_step.update({agent_handle: [state, action, reward, next_state, done]}) + + def make_final_step(self, additional_reward=0): + for _, item in self.final_step.items(): + state = item[0] + action = item[1] + reward = item[2] + additional_reward + next_state = item[3] + done = item[4] + self.memory_final.add(state, action, reward, next_state, done) + self._update_model(1) + self._reset_final_step() + + def _reset_final_step(self): + self.final_step = {} + + def act(self, state, eps=0.): + """Returns actions for given state as per current policy. + + Params + ====== + state (array_like): current state + eps (float): epsilon, for epsilon-greedy action selection + """ + state = torch.from_numpy(state).float().unsqueeze(0).to(device) + self.qnetwork_local.eval() + with torch.no_grad(): + action_values = self.qnetwork_local(state) + self.qnetwork_local.train() + + # Epsilon-greedy action selection + if random.random() > eps: + return np.argmax(action_values.cpu().data.numpy()) + else: + return random.choice(np.arange(self.action_size)) + + def learn(self, experiences, gamma): + + """Update value parameters using given batch of experience tuples. 
+ + Params + ====== + experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples + gamma (float): discount factor + """ + states, actions, rewards, next_states, dones = experiences + + # Get expected Q values from local model + Q_expected = self.qnetwork_local(states).gather(1, actions) + + if self.double_dqn: + # Double DQN + q_best_action = self.qnetwork_local(next_states).max(1)[1] + Q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1)) + else: + # DQN + Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1) + + # Compute Q targets for current states + + Q_targets = rewards + (gamma * Q_targets_next * (1 - dones)) + + # Compute loss + loss = F.mse_loss(Q_expected, Q_targets) + # Minimize the loss + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # ------------------- update target network ------------------- # + self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU) + + def soft_update(self, local_model, target_model, tau): + """Soft update model parameters. + θ_target = τ*θ_local + (1 - τ)*θ_target + + Params + ====== + local_model (PyTorch model): weights will be copied from + target_model (PyTorch model): weights will be copied to + tau (float): interpolation parameter + """ + for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): + target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data) + + +class ReplayBuffer: + """Fixed-size buffer to store experience tuples.""" + + def __init__(self, action_size, buffer_size, batch_size, seed): + """Initialize a ReplayBuffer object. + + Params + ====== + action_size (int): dimension of each action + buffer_size (int): maximum size of buffer + batch_size (int): size of each training batch + seed (int): random seed + """ + self.action_size = action_size + self.memory = deque(maxlen=buffer_size) + self.batch_size = batch_size + self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) + self.seed = random.seed(seed) + + def add(self, state, action, reward, next_state, done): + """Add a new experience to memory.""" + e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done) + self.memory.append(e) + + def sample(self): + """Randomly sample a batch of experiences from memory.""" + experiences = random.sample(self.memory, k=self.batch_size) + + states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \ + .float().to(device) + actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \ + .long().to(device) + rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \ + .float().to(device) + next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \ + .float().to(device) + dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \ + .float().to(device) + + return (states, actions, rewards, next_states, dones) + + def __len__(self): + """Return the current size of internal memory.""" + return len(self.memory) + + def __v_stack_impr(self, states): + sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1 + np_states = np.reshape(np.array(states), (len(states), sub_dim)) + return np_states + + +import copy +import os +import random +from collections import 
namedtuple, deque, Iterable + +import numpy as np +import torch +import torch.nn.functional as F +import torch.optim as optim + +from src.agent.model import QNetwork2, QNetwork + +BUFFER_SIZE = int(1e5) # replay buffer size +BATCH_SIZE = 512 # minibatch size +GAMMA = 0.95 # discount factor 0.99 +TAU = 0.5e-4 # for soft update of target parameters +LR = 0.5e-3 # learning rate 0.5e-4 works + +# how often to update the network +UPDATE_EVERY = 40 +UPDATE_EVERY_FINAL = 1000 +UPDATE_EVERY_AGENT_CANT_CHOOSE = 200 + +double_dqn = True # If using double dqn algorithm +input_channels = 5 # Number of Input channels + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +device = torch.device("cpu") +print(device) + +USE_OPTIMIZER = optim.Adam +# USE_OPTIMIZER = optim.RMSprop +print(USE_OPTIMIZER) + + +class Agent: + """Interacts with and learns from the environment.""" + + def __init__(self, state_size, action_size, net_type, seed, double_dqn=True, input_channels=5): + """Initialize an Agent object. + + Params + ====== + state_size (int): dimension of each state + action_size (int): dimension of each action + seed (int): random seed + """ + self.state_size = state_size + self.action_size = action_size + self.seed = random.seed(seed) + self.version = net_type + self.double_dqn = double_dqn + # Q-Network + if self.version == "Conv": + self.qnetwork_local = QNetwork2(state_size, action_size, seed, input_channels).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + else: + self.qnetwork_local = QNetwork(state_size, action_size, seed).to(device) + self.qnetwork_target = copy.deepcopy(self.qnetwork_local) + + self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR) + + # Replay memory + self.memory = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + self.memory_final = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + self.memory_agent_can_not_choose = ReplayBuffer(action_size, BUFFER_SIZE, BATCH_SIZE, seed) + + self.final_step = {} + + # Initialize time step (for updating every UPDATE_EVERY steps) + self.t_step = 0 + self.t_step_final = 0 + self.t_step_agent_can_not_choose = 0 + + def save(self, filename): + torch.save(self.qnetwork_local.state_dict(), filename + ".local") + torch.save(self.qnetwork_target.state_dict(), filename + ".target") + + def load(self, filename): + print("try to load: " + filename) + if os.path.exists(filename + ".local"): + self.qnetwork_local.load_state_dict(torch.load(filename + ".local")) + print(filename + ".local -> ok") + if os.path.exists(filename + ".target"): + self.qnetwork_target.load_state_dict(torch.load(filename + ".target")) + print(filename + ".target -> ok") + self.optimizer = USE_OPTIMIZER(self.qnetwork_local.parameters(), lr=LR) + + def _update_model(self, switch=0): + # Learn every UPDATE_EVERY time steps. 
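
`soft_update`, defined a little further below, nudges the target network towards the local network with θ_target = τ*θ_local + (1 - τ)*θ_target, so with TAU = 0.5e-4 the target network trails the local one very slowly. A small self-contained check of that formula on a throwaway linear layer (not the repository's Q-networks), using a large τ so the effect is visible:

import copy
import torch
import torch.nn as nn

def soft_update(local_model, target_model, tau):
    # theta_target = tau * theta_local + (1 - tau) * theta_target
    for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
        target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)

local = nn.Linear(3, 2)
target = copy.deepcopy(local)
with torch.no_grad():
    local.weight.add_(1.0)  # pretend the local network has drifted away

soft_update(local, target, tau=0.5)
# The target weights end up halfway between the old target and the new local weights.
print(torch.allclose(target.weight, local.weight - 0.5))  # True
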
+ # If enough samples are available in memory, get random subset and learn + if switch == 0: + self.t_step = (self.t_step + 1) % UPDATE_EVERY + if self.t_step == 0: + if len(self.memory) > BATCH_SIZE: + experiences = self.memory.sample() + self.learn(experiences, GAMMA) + elif switch == 1: + self.t_step_final = (self.t_step_final + 1) % UPDATE_EVERY_FINAL + if self.t_step_final == 0: + if len(self.memory_final) > BATCH_SIZE: + experiences = self.memory_final.sample() + self.learn(experiences, GAMMA) + else: + # If enough samples are available in memory_agent_can_not_choose, get random subset and learn + self.t_step_agent_can_not_choose = (self.t_step_agent_can_not_choose + 1) % UPDATE_EVERY_AGENT_CANT_CHOOSE + if self.t_step_agent_can_not_choose == 0: + if len(self.memory_agent_can_not_choose) > BATCH_SIZE: + experiences = self.memory_agent_can_not_choose.sample() + self.learn(experiences, GAMMA) + + def step(self, state, action, reward, next_state, done): + # Save experience in replay memory + self.memory.add(state, action, reward, next_state, done) + self._update_model(0) + + def step_agent_can_not_choose(self, state, action, reward, next_state, done): + # Save experience in replay memory_agent_can_not_choose + self.memory_agent_can_not_choose.add(state, action, reward, next_state, done) + self._update_model(2) + + def add_final_step(self, agent_handle, state, action, reward, next_state, done): + if self.final_step.get(agent_handle) is None: + self.final_step.update({agent_handle: [state, action, reward, next_state, done]}) + return True + else: + return False + + def make_final_step(self, additional_reward=0): + for _, item in self.final_step.items(): + state = item[0] + action = item[1] + reward = item[2] + additional_reward + next_state = item[3] + done = item[4] + self.memory_final.add(state, action, reward, next_state, done) + self._update_model(1) + self._reset_final_step() + + def _reset_final_step(self): + self.final_step = {} + + def act(self, state, eps=0.): + """Returns actions for given state as per current policy. + + Params + ====== + state (array_like): current state + eps (float): epsilon, for epsilon-greedy action selection + """ + state = torch.from_numpy(state).float().unsqueeze(0).to(device) + self.qnetwork_local.eval() + with torch.no_grad(): + action_values = self.qnetwork_local(state) + self.qnetwork_local.train() + + # Epsilon-greedy action selection + if random.random() > eps: + return np.argmax(action_values.cpu().data.numpy()), False + else: + return random.choice(np.arange(self.action_size)), True + + def learn(self, experiences, gamma): + + """Update value parameters using given batch of experience tuples. 
+ + Params + ====== + experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples + gamma (float): discount factor + """ + states, actions, rewards, next_states, dones = experiences + + # Get expected Q values from local model + Q_expected = self.qnetwork_local(states).gather(1, actions) + + if self.double_dqn: + # Double DQN + q_best_action = self.qnetwork_local(next_states).max(1)[1] + Q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1)) + else: + # DQN + Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1) + + # Compute Q targets for current states + + Q_targets = rewards + (gamma * Q_targets_next * (1 - dones)) + + # Compute loss + loss = F.mse_loss(Q_expected, Q_targets) + # Minimize the loss + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # ------------------- update target network ------------------- # + self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU) + + def soft_update(self, local_model, target_model, tau): + """Soft update model parameters. + θ_target = τ*θ_local + (1 - τ)*θ_target + + Params + ====== + local_model (PyTorch model): weights will be copied from + target_model (PyTorch model): weights will be copied to + tau (float): interpolation parameter + """ + for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): + target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data) + + +class ReplayBuffer: + """Fixed-size buffer to store experience tuples.""" + + def __init__(self, action_size, buffer_size, batch_size, seed): + """Initialize a ReplayBuffer object. + + Params + ====== + action_size (int): dimension of each action + buffer_size (int): maximum size of buffer + batch_size (int): size of each training batch + seed (int): random seed + """ + self.action_size = action_size + self.memory = deque(maxlen=buffer_size) + self.batch_size = batch_size + self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) + self.seed = random.seed(seed) + + def add(self, state, action, reward, next_state, done): + """Add a new experience to memory.""" + e = self.experience(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done) + self.memory.append(e) + + def sample(self): + """Randomly sample a batch of experiences from memory.""" + experiences = random.sample(self.memory, k=self.batch_size) + + states = torch.from_numpy(self.__v_stack_impr([e.state for e in experiences if e is not None])) \ + .float().to(device) + actions = torch.from_numpy(self.__v_stack_impr([e.action for e in experiences if e is not None])) \ + .long().to(device) + rewards = torch.from_numpy(self.__v_stack_impr([e.reward for e in experiences if e is not None])) \ + .float().to(device) + next_states = torch.from_numpy(self.__v_stack_impr([e.next_state for e in experiences if e is not None])) \ + .float().to(device) + dones = torch.from_numpy(self.__v_stack_impr([e.done for e in experiences if e is not None]).astype(np.uint8)) \ + .float().to(device) + + return (states, actions, rewards, next_states, dones) + + def __len__(self): + """Return the current size of internal memory.""" + return len(self.memory) + + def __v_stack_impr(self, states): + sub_dim = len(states[0][0]) if isinstance(states[0], Iterable) else 1 + np_states = np.reshape(np.array(states), (len(states), sub_dim)) + return np_states diff --git a/src/extra.py b/src/extra.py index 
044c32f757a397ec20387ce8f0f707418848c37e..48a3249377f6ef3ac8b5f6f0d3c756c3130f1ed4 100644 --- a/src/extra.py +++ b/src/extra.py @@ -1,234 +1,409 @@ -import numpy as np -from flatland.core.grid.grid4_utils import get_new_position -from flatland.envs.agent_utils import RailAgentStatus -from flatland.envs.rail_env import RailEnv -from flatland.envs.rail_env import RailEnvActions - -from src.agent.dueling_double_dqn import Agent -from src.observations import normalize_observation - -state_size = 179 -action_size = 5 -print("state_size: ", state_size) -print("action_size: ", action_size) -# Now we load a Double dueling DQN agent -global_rl_agent = Agent(state_size, action_size, "FC", 0) -global_rl_agent.load('./nets/training_best_0.626_agents_5276.pth') - - -class Extra: - global_rl_agent = None - - def __init__(self, env: RailEnv): - self.env = env - self.rl_agent = global_rl_agent - self.switches = {} - self.switches_neighbours = {} - self.find_all_cell_where_agent_can_choose() - self.steps_counter = 0 - - self.debug_render_list = [] - self.debug_render_path_list = [] - - def rl_agent_act(self, observation, max_depth, eps=0.0): - - self.steps_counter += 1 - print(self.steps_counter, self.env.get_num_agents()) - - agent_obs = [None] * self.env.get_num_agents() - for a in range(self.env.get_num_agents()): - if observation[a]: - agent_obs[a] = self.generate_state(a, observation, max_depth) - - action_dict = {} - # estimate whether the agent(s) can freely choose an action - agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \ - self.required_agent_descision() - - for a in range(self.env.get_num_agents()): - if agent_obs[a] is not None: - if agents_can_choose[a]: - act, agent_rnd = self.rl_agent.act(agent_obs[a], eps=eps) - - l = len(agent_obs[a]) - if agent_obs[a][l - 3] > 0 and agents_near_to_switch_all[a]: - act = RailEnvActions.STOP_MOVING - - action_dict.update({a: act}) - else: - act = RailEnvActions.MOVE_FORWARD - action_dict.update({a: act}) - else: - action_dict.update({a: RailEnvActions.DO_NOTHING}) - return action_dict - - def find_all_cell_where_agent_can_choose(self): - switches = {} - for h in range(self.env.height): - for w in range(self.env.width): - pos = (h, w) - for dir in range(4): - possible_transitions = self.env.rail.get_transitions(*pos, dir) - num_transitions = np.count_nonzero(possible_transitions) - if num_transitions > 1: - if pos not in switches.keys(): - switches.update({pos: [dir]}) - else: - switches[pos].append(dir) - - switches_neighbours = {} - for h in range(self.env.height): - for w in range(self.env.width): - # look one step forward - for dir in range(4): - pos = (h, w) - possible_transitions = self.env.rail.get_transitions(*pos, dir) - for d in range(4): - if possible_transitions[d] == 1: - new_cell = get_new_position(pos, d) - if new_cell in switches.keys() and pos not in switches.keys(): - if pos not in switches_neighbours.keys(): - switches_neighbours.update({pos: [dir]}) - else: - switches_neighbours[pos].append(dir) - - self.switches = switches - self.switches_neighbours = switches_neighbours - - def check_agent_descision(self, position, direction, switches, switches_neighbours): - agents_on_switch = False - agents_near_to_switch = False - agents_near_to_switch_all = False - if position in switches.keys(): - agents_on_switch = direction in switches[position] - - if position in switches_neighbours.keys(): - new_cell = get_new_position(position, direction) - if new_cell in switches.keys(): - if not direction in 
switches[new_cell]: - agents_near_to_switch = direction in switches_neighbours[position] - else: - agents_near_to_switch = direction in switches_neighbours[position] - - agents_near_to_switch_all = direction in switches_neighbours[position] - - return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all - - def required_agent_descision(self): - agents_can_choose = {} - agents_on_switch = {} - agents_near_to_switch = {} - agents_near_to_switch_all = {} - for a in range(self.env.get_num_agents()): - ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all = \ - self.check_agent_descision( - self.env.agents[a].position, - self.env.agents[a].direction, - self.switches, - self.switches_neighbours) - agents_on_switch.update({a: ret_agents_on_switch}) - ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART - agents_near_to_switch.update({a: (ret_agents_near_to_switch or ready_to_depart)}) - - agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]}) - - agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all or ready_to_depart)}) - - return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all - - def check_deadlock(self, only_next_cell_check=False, handle=None): - agents_with_deadlock = [] - agents = range(self.env.get_num_agents()) - if handle is not None: - agents = [handle] - for a in agents: - if self.env.agents[a].status < RailAgentStatus.DONE: - position = self.env.agents[a].position - first_step = True - if position is None: - position = self.env.agents[a].initial_position - first_step = True - direction = self.env.agents[a].direction - cnt = 0 - while position is not None: # and position != self.env.agents[a].target: - possible_transitions = self.env.rail.get_transitions(*position, direction) - # num_transitions = np.count_nonzero(possible_transitions) - agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = self.check_agent_descision( - position, - direction, - self.switches, - self.switches_neighbours) - - if not agents_on_switch or first_step: - first_step = False - new_direction_me = np.argmax(possible_transitions) - new_cell_me = get_new_position(position, new_direction_me) - opp_agent = self.env.agent_positions[new_cell_me] - if opp_agent != -1: - opp_position = self.env.agents[opp_agent].position - opp_direction = self.env.agents[opp_agent].direction - opp_agents_on_switch, opp_agents_near_to_switch, agents_near_to_switch_all = \ - self.check_agent_descision(opp_position, - opp_direction, - self.switches, - self.switches_neighbours) - - # opp_possible_transitions = self.env.rail.get_transitions(*opp_position, opp_direction) - # opp_num_transitions = np.count_nonzero(opp_possible_transitions) - if not opp_agents_on_switch: - if opp_direction != direction: - agents_with_deadlock.append(a) - position = None - else: - if only_next_cell_check: - position = None - else: - position = new_cell_me - direction = new_direction_me - else: - if only_next_cell_check: - position = None - else: - position = new_cell_me - direction = new_direction_me - else: - if only_next_cell_check: - position = None - else: - position = new_cell_me - direction = new_direction_me - else: - position = None - - cnt += 1 - if cnt > 100: - position = None - - return agents_with_deadlock - - def generate_state(self, handle: int, root, max_depth: int): - n_obs = normalize_observation(root[handle], max_depth) - - position = self.env.agents[handle].position - direction = 
self.env.agents[handle].direction - cell_free_4_first_step = -1 - deadlock_agents = [] - if self.env.agents[handle].status == RailAgentStatus.READY_TO_DEPART: - if self.env.agent_positions[self.env.agents[handle].initial_position] == -1: - cell_free_4_first_step = 1 - position = self.env.agents[handle].initial_position - else: - deadlock_agents = self.check_deadlock(only_next_cell_check=False, handle=handle) - agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = self.check_agent_descision(position, - direction, - self.switches, - self.switches_neighbours) - - append_obs = [self.env.agents[handle].status - RailAgentStatus.ACTIVE, - cell_free_4_first_step, - 2.0 * int(len(deadlock_agents)) - 1.0, - 2.0 * int(agents_on_switch) - 1.0, - 2.0 * int(agents_near_to_switch) - 1.0] - n_obs = np.append(n_obs, append_obs) - - return n_obs +import numpy as np +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.agent_utils import RailAgentStatus +# Adrian Egli performance fix (the fast methods brings more than 50%) +from flatland.envs.rail_env import RailEnvActions + +from src.ppo.agent import Agent + + +def fast_isclose(a, b, rtol): + return (a < (b + rtol)) or (a < (b - rtol)) + + +def fast_clip(position: (int, int), min_value: (int, int), max_value: (int, int)) -> bool: + return ( + max(min_value[0], min(position[0], max_value[0])), + max(min_value[1], min(position[1], max_value[1])) + ) + + +def fast_argmax(possible_transitions: (int, int, int, int)) -> bool: + if possible_transitions[0] == 1: + return 0 + if possible_transitions[1] == 1: + return 1 + if possible_transitions[2] == 1: + return 2 + return 3 + + +def fast_position_equal(pos_1: (int, int), pos_2: (int, int)) -> bool: + return pos_1[0] == pos_2[0] and pos_1[1] == pos_2[1] + + +def fast_count_nonzero(possible_transitions: (int, int, int, int)): + return possible_transitions[0] + possible_transitions[1] + possible_transitions[2] + possible_transitions[3] + + +class Extra(ObservationBuilder): + + def __init__(self, max_depth): + self.max_depth = max_depth + self.observation_dim = 22 + self.agent = None + + def loadAgent(self): + if self.agent is not None: + return + self.state_size = self.env.obs_builder.observation_dim + self.action_size = 5 + print("action_size: ", self.action_size) + print("state_size: ", self.state_size) + self.agent = Agent(self.state_size, self.action_size, 0) + self.agent.load('./checkpoints/', 0, 1.0) + + def build_data(self): + if self.env is not None: + self.env.dev_obs_dict = {} + self.switches = {} + self.switches_neighbours = {} + self.debug_render_list = [] + self.debug_render_path_list = [] + if self.env is not None: + self.find_all_cell_where_agent_can_choose() + + def find_all_cell_where_agent_can_choose(self): + + switches = {} + for h in range(self.env.height): + for w in range(self.env.width): + pos = (h, w) + for dir in range(4): + possible_transitions = self.env.rail.get_transitions(*pos, dir) + num_transitions = fast_count_nonzero(possible_transitions) + if num_transitions > 1: + if pos not in switches.keys(): + switches.update({pos: [dir]}) + else: + switches[pos].append(dir) + + switches_neighbours = {} + for h in range(self.env.height): + for w in range(self.env.width): + # look one step forward + for dir in range(4): + pos = (h, w) + possible_transitions = self.env.rail.get_transitions(*pos, dir) + for d in range(4): + if possible_transitions[d] == 1: + new_cell = 
get_new_position(pos, d) + if new_cell in switches.keys() and pos not in switches.keys(): + if pos not in switches_neighbours.keys(): + switches_neighbours.update({pos: [dir]}) + else: + switches_neighbours[pos].append(dir) + + self.switches = switches + self.switches_neighbours = switches_neighbours + + def check_agent_descision(self, position, direction): + switches = self.switches + switches_neighbours = self.switches_neighbours + agents_on_switch = False + agents_near_to_switch = False + agents_near_to_switch_all = False + if position in switches.keys(): + agents_on_switch = direction in switches[position] + + if position in switches_neighbours.keys(): + new_cell = get_new_position(position, direction) + if new_cell in switches.keys(): + if not direction in switches[new_cell]: + agents_near_to_switch = direction in switches_neighbours[position] + else: + agents_near_to_switch = direction in switches_neighbours[position] + + agents_near_to_switch_all = direction in switches_neighbours[position] + + return agents_on_switch, agents_near_to_switch, agents_near_to_switch_all + + def required_agent_descision(self): + agents_can_choose = {} + agents_on_switch = {} + agents_near_to_switch = {} + agents_near_to_switch_all = {} + for a in range(self.env.get_num_agents()): + ret_agents_on_switch, ret_agents_near_to_switch, ret_agents_near_to_switch_all = \ + self.check_agent_descision( + self.env.agents[a].position, + self.env.agents[a].direction) + agents_on_switch.update({a: ret_agents_on_switch}) + ready_to_depart = self.env.agents[a].status == RailAgentStatus.READY_TO_DEPART + agents_near_to_switch.update({a: (ret_agents_near_to_switch and not ready_to_depart)}) + + agents_can_choose.update({a: agents_on_switch[a] or agents_near_to_switch[a]}) + + agents_near_to_switch_all.update({a: (ret_agents_near_to_switch_all and not ready_to_depart)}) + + return agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all + + def debug_render(self, env_renderer): + agents_can_choose, agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = \ + self.required_agent_descision() + self.env.dev_obs_dict = {} + for a in range(max(3, self.env.get_num_agents())): + self.env.dev_obs_dict.update({a: []}) + + selected_agent = None + if agents_can_choose[0]: + if self.env.agents[0].position is not None: + self.debug_render_list.append(self.env.agents[0].position) + else: + self.debug_render_list.append(self.env.agents[0].initial_position) + + if self.env.agents[0].position is not None: + self.debug_render_path_list.append(self.env.agents[0].position) + else: + self.debug_render_path_list.append(self.env.agents[0].initial_position) + + env_renderer.gl.agent_colors[0] = env_renderer.gl.rgb_s2i("FF0000") + env_renderer.gl.agent_colors[1] = env_renderer.gl.rgb_s2i("666600") + env_renderer.gl.agent_colors[2] = env_renderer.gl.rgb_s2i("006666") + env_renderer.gl.agent_colors[3] = env_renderer.gl.rgb_s2i("550000") + + self.env.dev_obs_dict[0] = self.debug_render_list + self.env.dev_obs_dict[1] = self.switches.keys() + self.env.dev_obs_dict[2] = self.switches_neighbours.keys() + self.env.dev_obs_dict[3] = self.debug_render_path_list + + def normalize_observation(self, obsData): + return obsData + + def check_deadlock(self, only_next_cell_check=True, handle=None): + agents_with_deadlock = [] + agents = range(self.env.get_num_agents()) + if handle is not None: + agents = [handle] + for a in agents: + if self.env.agents[a].status < RailAgentStatus.DONE: + position = 
self.env.agents[a].position + first_step = True + if position is None: + position = self.env.agents[a].initial_position + first_step = True + direction = self.env.agents[a].direction + while position is not None: # and position != self.env.agents[a].target: + possible_transitions = self.env.rail.get_transitions(*position, direction) + # num_transitions = np.count_nonzero(possible_transitions) + agents_on_switch, agents_near_to_switch, agents_near_to_switch_all = self.check_agent_descision( + position, + direction) + + if not agents_on_switch or first_step: + first_step = False + new_direction_me = np.argmax(possible_transitions) + new_cell_me = get_new_position(position, new_direction_me) + opp_agent = self.env.agent_positions[new_cell_me] + if opp_agent != -1: + opp_position = self.env.agents[opp_agent].position + opp_direction = self.env.agents[opp_agent].direction + opp_agents_on_switch, opp_agents_near_to_switch, agents_near_to_switch_all = \ + self.check_agent_descision(opp_position, + opp_direction) + + # opp_possible_transitions = self.env.rail.get_transitions(*opp_position, opp_direction) + # opp_num_transitions = np.count_nonzero(opp_possible_transitions) + if not opp_agents_on_switch: + if opp_direction != direction: + agents_with_deadlock.append(a) + position = None + else: + if only_next_cell_check: + position = None + else: + position = new_cell_me + direction = new_direction_me + else: + if only_next_cell_check: + position = None + else: + position = new_cell_me + direction = new_direction_me + else: + if only_next_cell_check: + position = None + else: + position = new_cell_me + direction = new_direction_me + else: + position = None + + return agents_with_deadlock + + def is_collision(self, obsData): + if obsData[4] == 1: + # Agent is READY_TO_DEPART + return False + if obsData[6] == 1: + # Agent is DONE / DONE_REMOVED + return False + + same_dir = obsData[18] + obsData[19] + obsData[20] + obsData[21] + if same_dir > 0: + # The agent detects another agent moving in the same direction, and every cell between the two agents is + # unoccupied. 
(Follows the agents) + return False + freedom = obsData[10] + obsData[11] + obsData[12] + obsData[13] + blocked = obsData[14] + obsData[15] + obsData[16] + obsData[17] + # if the Agent has equal freedom or less then the agent can not avoid the agent travelling towards + # (opposite) direction -> this can cause a deadlock (locally tested) + return freedom <= blocked and freedom > 0 + + def reset(self): + self.build_data() + return + + def fast_argmax(self, array): + if array[0] == 1: + return 0 + if array[1] == 1: + return 1 + if array[2] == 1: + return 2 + return 3 + + def _explore(self, handle, new_position, new_direction, depth=0): + + has_opp_agent = 0 + has_same_agent = 0 + visited = [] + + # stop exploring (max_depth reached) + if depth >= self.max_depth: + return has_opp_agent, has_same_agent, visited + + # max_explore_steps = 100 + cnt = 0 + while cnt < 100: + cnt += 1 + + visited.append(new_position) + opp_a = self.env.agent_positions[new_position] + if opp_a != -1 and opp_a != handle: + if self.env.agents[opp_a].direction != new_direction: + # opp agent found + has_opp_agent = 1 + return has_opp_agent, has_same_agent, visited + else: + has_same_agent = 1 + return has_opp_agent, has_same_agent, visited + + # convert one-hot encoding to 0,1,2,3 + possible_transitions = self.env.rail.get_transitions(*new_position, new_direction) + agents_on_switch, \ + agents_near_to_switch, \ + agents_near_to_switch_all = \ + self.check_agent_descision(new_position, new_direction) + if agents_near_to_switch: + return has_opp_agent, has_same_agent, visited + + if agents_on_switch: + for dir_loop in range(4): + if possible_transitions[dir_loop] == 1: + hoa, hsa, v = self._explore(handle, new_position, new_direction, depth + 1) + visited.append(v) + has_opp_agent = 0.5 * (has_opp_agent + hoa) + has_same_agent = 0.5 * (has_same_agent + hsa) + return has_opp_agent, has_same_agent, visited + else: + new_direction = fast_argmax(possible_transitions) + new_position = get_new_position(new_position, new_direction) + return has_opp_agent, has_same_agent, visited + + def get(self, handle): + # all values are [0,1] + # observation[0] : 1 path towards target (direction 0) / otherwise 0 -> path is longer or there is no path + # observation[1] : 1 path towards target (direction 1) / otherwise 0 -> path is longer or there is no path + # observation[2] : 1 path towards target (direction 2) / otherwise 0 -> path is longer or there is no path + # observation[3] : 1 path towards target (direction 3) / otherwise 0 -> path is longer or there is no path + # observation[4] : int(agent.status == RailAgentStatus.READY_TO_DEPART) + # observation[5] : int(agent.status == RailAgentStatus.ACTIVE) + # observation[6] : int(agent.status == RailAgentStatus.DONE or agent.status == RailAgentStatus.DONE_REMOVED) + # observation[7] : current agent is located at a switch, where it can take a routing decision + # observation[8] : current agent is located at a cell, where it has to take a stop-or-go decision + # observation[9] : current agent is located one step before/after a switch + # observation[10] : 1 if there is a path (track/branch) otherwise 0 (direction 0) + # observation[11] : 1 if there is a path (track/branch) otherwise 0 (direction 1) + # observation[12] : 1 if there is a path (track/branch) otherwise 0 (direction 2) + # observation[13] : 1 if there is a path (track/branch) otherwise 0 (direction 3) + # observation[14] : If there is a path with step (direction 0) and there is a agent with opposite direction -> 1 + # 
observation[15] : If there is a path with step (direction 1) and there is a agent with opposite direction -> 1 + # observation[16] : If there is a path with step (direction 2) and there is a agent with opposite direction -> 1 + # observation[17] : If there is a path with step (direction 3) and there is a agent with opposite direction -> 1 + # observation[18] : If there is a path with step (direction 0) and there is a agent with same direction -> 1 + # observation[19] : If there is a path with step (direction 1) and there is a agent with same direction -> 1 + # observation[20] : If there is a path with step (direction 2) and there is a agent with same direction -> 1 + # observation[21] : If there is a path with step (direction 3) and there is a agent with same direction -> 1 + + observation = np.zeros(self.observation_dim) + visited = [] + agent = self.env.agents[handle] + + agent_done = False + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + observation[4] = 1 + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + observation[5] = 1 + else: + observation[6] = 1 + agent_virtual_position = (-1, -1) + agent_done = True + + if not agent_done: + visited.append(agent_virtual_position) + distance_map = self.env.distance_map.get() + current_cell_dist = distance_map[handle, + agent_virtual_position[0], agent_virtual_position[1], + agent.direction] + possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction) + orientation = agent.direction + if fast_count_nonzero(possible_transitions) == 1: + orientation = np.argmax(possible_transitions) + + for dir_loop, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]): + if possible_transitions[branch_direction]: + new_position = get_new_position(agent_virtual_position, branch_direction) + new_cell_dist = distance_map[handle, + new_position[0], new_position[1], + branch_direction] + if not (np.math.isinf(new_cell_dist) and np.math.isinf(current_cell_dist)): + observation[dir_loop] = int(new_cell_dist < current_cell_dist) + + has_opp_agent, has_same_agent, v = self._explore(handle, new_position, branch_direction) + visited.append(v) + + observation[10 + dir_loop] = 1 + observation[14 + dir_loop] = has_opp_agent + observation[18 + dir_loop] = has_same_agent + + agents_on_switch, \ + agents_near_to_switch, \ + agents_near_to_switch_all = \ + self.check_agent_descision(agent_virtual_position, agent.direction) + observation[7] = int(agents_on_switch) + observation[8] = int(agents_near_to_switch) + observation[9] = int(agents_near_to_switch_all) + + self.env.dev_obs_dict.update({handle: visited}) + + return observation + + def rl_agent_act(self, observation, info, eps=0.0): + self.loadAgent() + action_dict = {} + for a in range(self.env.get_num_agents()): + if info['action_required'][a]: + action_dict[a] = self.agent.act(observation[a], eps=eps) + # action_dict[a] = np.random.randint(5) + else: + action_dict[a] = RailEnvActions.DO_NOTHING + + return action_dict diff --git a/src/observations.py b/src/observations.py index fd9659cd41d72e8c92d2a96ac40a961524fb27ce..ce47e508b21d7f0cf9a9a88ee864d3da18184cf1 100644 --- a/src/observations.py +++ b/src/observations.py @@ -1,736 +1,736 @@ -""" -Collection of environment-specific ObservationBuilder. 
-""" -import collections -from typing import Optional, List, Dict, Tuple - -import numpy as np -from flatland.core.env import Environment -from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.env_prediction_builder import PredictionBuilder -from flatland.core.grid.grid4_utils import get_new_position -from flatland.core.grid.grid_utils import coordinate_to_position -from flatland.envs.agent_utils import RailAgentStatus, EnvAgent -from flatland.utils.ordered_set import OrderedSet - - -class MyTreeObsForRailEnv(ObservationBuilder): - """ - TreeObsForRailEnv object. - - This object returns observation vectors for agents in the RailEnv environment. - The information is local to each agent and exploits the graph structure of the rail - network to simplify the representation of the state of the environment for each agent. - - For details about the features in the tree observation see the get() function. - """ - Node = collections.namedtuple('Node', 'dist_min_to_target ' - 'target_encountered ' - 'num_agents_same_direction ' - 'num_agents_opposite_direction ' - 'childs') - - tree_explored_actions_char = ['L', 'F', 'R', 'B'] - - def __init__(self, max_depth: int, predictor: PredictionBuilder = None): - super().__init__() - self.max_depth = max_depth - self.observation_dim = 2 - self.location_has_agent = {} - self.predictor = predictor - self.location_has_target = None - - self.switches_list = {} - self.switches_neighbours_list = [] - self.check_agent_descision = None - - def reset(self): - self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents} - - def set_switch_and_pre_switch(self, switch_list, pre_switch_list, check_agent_descision): - self.switches_list = switch_list - self.switches_neighbours_list = pre_switch_list - self.check_agent_descision = check_agent_descision - - def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]: - """ - Called whenever an observation has to be computed for the `env` environment, for each agent with handle - in the `handles` list. 
- """ - - if handles is None: - handles = [] - if self.predictor: - self.max_prediction_depth = 0 - self.predicted_pos = {} - self.predicted_dir = {} - self.predictions = self.predictor.get() - if self.predictions: - for t in range(self.predictor.max_depth + 1): - pos_list = [] - dir_list = [] - for a in handles: - if self.predictions[a] is None: - continue - pos_list.append(self.predictions[a][t][1:3]) - dir_list.append(self.predictions[a][t][3]) - self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) - self.predicted_dir.update({t: dir_list}) - self.max_prediction_depth = len(self.predicted_pos) - # Update local lookup table for all agents' positions - # ignore other agents not in the grid (only status active and done) - - self.location_has_agent = {} - self.location_has_agent_direction = {} - self.location_has_agent_speed = {} - self.location_has_agent_malfunction = {} - self.location_has_agent_ready_to_depart = {} - - for _agent in self.env.agents: - if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \ - _agent.position: - self.location_has_agent[tuple(_agent.position)] = 1 - self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction - self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed'] - self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[ - 'malfunction'] - - if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \ - _agent.initial_position: - self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \ - self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1 - - observations = super().get_many(handles) - - return observations - - def get(self, handle: int = 0) -> Node: - """ - Computes the current observation for agent `handle` in env - - The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible - movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv). - The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for - the transitions. The order is:: - - [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] - - Each branch data is organized as:: - - [root node information] + - [recursive branch data from 'left'] + - [... from 'forward'] + - [... from 'right] + - [... from 'back'] - - Each node information is composed of 9 features: - - #1: - if own target lies on the explored branch the current distance from the agent in number of cells is stored. - - #2: - if another agents target is detected the distance in number of cells from the agents current location\ - is stored - - #3: - if another agent is detected the distance in number of cells from current agent position is stored. - - #4: - possible conflict detected - tot_dist = Other agent predicts to pass along this cell at the same time as the agent, we store the \ - distance in number of cells from current agent position - - 0 = No other agent reserve the same cell at similar time - - #5: - if an not usable switch (for agent) is detected we store the distance. 
- - #6: - This feature stores the distance in number of cells to the next branching (current node) - - #7: - minimum distance from node to the agent's target given the direction of the agent if this path is chosen - - #8: - agent in the same direction - n = number of agents present same direction \ - (possible future use: number of other agents in the same direction in this branch) - 0 = no agent present same direction - - #9: - agent in the opposite direction - n = number of agents present other direction than myself (so conflict) \ - (possible future use: number of other agents in other direction in this branch, ie. number of conflicts) - 0 = no agent present other direction than myself - - #10: - malfunctioning/blokcing agents - n = number of time steps the oberved agent remains blocked - - #11: - slowest observed speed of an agent in same direction - 1 if no agent is observed - - min_fractional speed otherwise - #12: - number of agents ready to depart but no yet active - - Missing/padding nodes are filled in with -inf (truncated). - Missing values in present node are filled in with +inf (truncated). - - - In case of the root node, the values are [0, 0, 0, 0, distance from agent to target, own malfunction, own speed] - In case the target node is reached, the values are [0, 0, 0, 0, 0]. - """ - - if handle > len(self.env.agents): - print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) - agent = self.env.agents[handle] # TODO: handle being treated as index - - if agent.status == RailAgentStatus.READY_TO_DEPART: - agent_virtual_position = agent.initial_position - elif agent.status == RailAgentStatus.ACTIVE: - agent_virtual_position = agent.position - elif agent.status == RailAgentStatus.DONE: - agent_virtual_position = agent.target - else: - return None - - possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction) - num_transitions = np.count_nonzero(possible_transitions) - - # Here information about the agent itself is stored - distance_map = self.env.distance_map.get() - - root_node_observation = MyTreeObsForRailEnv.Node(dist_min_to_target=distance_map[ - (handle, *agent_virtual_position, - agent.direction)], - target_encountered=0, - num_agents_same_direction=0, - num_agents_opposite_direction=0, - childs={}) - - visited = OrderedSet() - - # Start from the current orientation, and see which transitions are available; - # organize them as [left, forward, right, back], relative to the current orientation - # If only one transition is possible, the tree is oriented with this transition as the forward branch. - orientation = agent.direction - - if num_transitions == 1: - orientation = np.argmax(possible_transitions) - - for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]): - if possible_transitions[branch_direction]: - new_cell = get_new_position(agent_virtual_position, branch_direction) - - branch_observation, branch_visited = \ - self._explore_branch(handle, new_cell, branch_direction, 1, 1) - root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation - - visited |= branch_visited - else: - # add cells filled with infinity if no transition is possible - root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf - self.env.dev_obs_dict[handle] = visited - - return root_node_observation - - def _explore_branch(self, handle, position, direction, tot_dist, depth): - """ - Utility function to compute tree-based observations. 
- We walk along the branch and collect the information documented in the get() function. - If there is a branching point a new node is created and each possible branch is explored. - """ - - # [Recursive branch opened] - if depth >= self.max_depth + 1: - return [], [] - - # Continue along direction until next switch or - # until no transitions are possible along the current direction (i.e., dead-ends) - # We treat dead-ends as nodes, instead of going back, to avoid loops - exploring = True - - visited = OrderedSet() - agent = self.env.agents[handle] - - other_agent_opposite_direction = 0 - other_agent_same_direction = 0 - - dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction] - - last_is_dead_end = False - last_is_a_decision_cell = False - target_encountered = 0 - - cnt = 0 - while exploring: - - dist_min_to_target = min(dist_min_to_target, self.env.distance_map.get()[handle, position[0], position[1], - direction]) - - if agent.target == position: - target_encountered = 1 - - new_direction_me = direction - new_cell_me = position - a = self.env.agent_positions[new_cell_me] - if a != -1 and a != handle: - opp_agent = self.env.agents[a] - # look one step forward - # opp_possible_transitions = self.env.rail.get_transitions(*opp_agent.position, opp_agent.direction) - if opp_agent.direction != new_direction_me: # opp_possible_transitions[new_direction_me] == 0: - other_agent_opposite_direction += 1 - else: - other_agent_same_direction += 1 - - # ############################# - # ############################# - if (position[0], position[1], direction) in visited: - break - visited.add((position[0], position[1], direction)) - - # If the target node is encountered, pick that as node. Also, no further branching is possible. - if np.array_equal(position, self.env.agents[handle].target): - last_is_target = True - break - - exploring = False - - # Check number of possible transitions for agent and total number of transitions in cell (type) - possible_transitions = self.env.rail.get_transitions(*position, direction) - num_transitions = np.count_nonzero(possible_transitions) - # cell_transitions = self.env.rail.get_transitions(*position, direction) - transition_bit = bin(self.env.rail.get_full_transitions(*position)) - total_transitions = transition_bit.count("1") - - if num_transitions == 1: - # Check if dead-end, or if we can go forward along direction - nbits = total_transitions - if nbits == 1: - # Dead-end! - last_is_dead_end = True - - if self.check_agent_descision is not None: - ret_agents_on_switch, ret_agents_near_to_switch, agents_near_to_switch_all = \ - self.check_agent_descision(position, - direction, - self.switches_list, - self.switches_neighbours_list) - if ret_agents_on_switch: - last_is_a_decision_cell = True - break - - exploring = True - # convert one-hot encoding to 0,1,2,3 - cell_transitions = self.env.rail.get_transitions(*position, direction) - direction = np.argmax(cell_transitions) - position = get_new_position(position, direction) - - cnt += 1 - if cnt > 1000: - exploring = False - - # ############################# - # ############################# - # Modify here to append new / different features for each visited cell! 
- - node = MyTreeObsForRailEnv.Node(dist_min_to_target=dist_min_to_target, - target_encountered=target_encountered, - num_agents_opposite_direction=other_agent_opposite_direction, - num_agents_same_direction=other_agent_same_direction, - childs={}) - - # ############################# - # ############################# - # Start from the current orientation, and see which transitions are available; - # organize them as [left, forward, right, back], relative to the current orientation - # Get the possible transitions - possible_transitions = self.env.rail.get_transitions(*position, direction) - - for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]): - if last_is_dead_end and self.env.rail.get_transition((*position, direction), - (branch_direction + 2) % 4): - # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes - # it back - new_cell = get_new_position(position, (branch_direction + 2) % 4) - branch_observation, branch_visited = self._explore_branch(handle, - new_cell, - (branch_direction + 2) % 4, - tot_dist + 1, - depth + 1) - node.childs[self.tree_explored_actions_char[i]] = branch_observation - if len(branch_visited) != 0: - visited |= branch_visited - elif last_is_a_decision_cell and possible_transitions[branch_direction]: - new_cell = get_new_position(position, branch_direction) - branch_observation, branch_visited = self._explore_branch(handle, - new_cell, - branch_direction, - tot_dist + 1, - depth + 1) - node.childs[self.tree_explored_actions_char[i]] = branch_observation - if len(branch_visited) != 0: - visited |= branch_visited - else: - # no exploring possible, add just cells with infinity - node.childs[self.tree_explored_actions_char[i]] = -np.inf - - if depth == self.max_depth: - node.childs.clear() - return node, visited - - def util_print_obs_subtree(self, tree: Node): - """ - Utility function to print tree observations returned by this object. - """ - self.print_node_features(tree, "root", "") - for direction in self.tree_explored_actions_char: - self.print_subtree(tree.childs[direction], direction, "\t") - - @staticmethod - def print_node_features(node: Node, label, indent): - print(indent, "Direction ", label, ": ", node.num_agents_same_direction, - ", ", node.num_agents_opposite_direction) - - def print_subtree(self, node, label, indent): - if node == -np.inf or not node: - print(indent, "Direction ", label, ": -np.inf") - return - - self.print_node_features(node, label, indent) - - if not node.childs: - return - - for direction in self.tree_explored_actions_char: - self.print_subtree(node.childs[direction], direction, indent + "\t") - - def set_env(self, env: Environment): - super().set_env(env) - if self.predictor: - self.predictor.set_env(self.env) - - def _reverse_dir(self, direction): - return int((direction + 2) % 4) - - -class GlobalObsForRailEnv(ObservationBuilder): - """ - Gives a global observation of the entire rail environment. - The observation is composed of the following elements: - - - transition map array with dimensions (env.height, env.width, 16),\ - assuming 16 bits encoding of transitions. 
- - - obs_agents_state: A 3D array (map_height, map_width, 5) with - - first channel containing the agents position and direction - - second channel containing the other agents positions and direction - - third channel containing agent/other agent malfunctions - - fourth channel containing agent/other agent fractional speeds - - fifth channel containing number of other agents ready to depart - - - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ - target and the positions of the other agents targets (flag only, no counter!). - """ - - def __init__(self): - super(GlobalObsForRailEnv, self).__init__() - - def set_env(self, env: Environment): - super().set_env(env) - - def reset(self): - self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) - for i in range(self.rail_obs.shape[0]): - for j in range(self.rail_obs.shape[1]): - bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] - bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i, j] = np.array(bitlist) - - def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray): - - agent = self.env.agents[handle] - if agent.status == RailAgentStatus.READY_TO_DEPART: - agent_virtual_position = agent.initial_position - elif agent.status == RailAgentStatus.ACTIVE: - agent_virtual_position = agent.position - elif agent.status == RailAgentStatus.DONE: - agent_virtual_position = agent.target - else: - return None - - obs_targets = np.zeros((self.env.height, self.env.width, 2)) - obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1 - - # TODO can we do this more elegantly? - # for r in range(self.env.height): - # for c in range(self.env.width): - # obs_agents_state[(r, c)][4] = 0 - obs_agents_state[:, :, 4] = 0 - - obs_agents_state[agent_virtual_position][0] = agent.direction - obs_targets[agent.target][0] = 1 - - for i in range(len(self.env.agents)): - other_agent: EnvAgent = self.env.agents[i] - - # ignore other agents not in the grid any more - if other_agent.status == RailAgentStatus.DONE_REMOVED: - continue - - obs_targets[other_agent.target][1] = 1 - - # second to fourth channel only if in the grid - if other_agent.position is not None: - # second channel only for other agents - if i != handle: - obs_agents_state[other_agent.position][1] = other_agent.direction - obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] - obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed'] - # fifth channel: all ready to depart on this position - if other_agent.status == RailAgentStatus.READY_TO_DEPART: - obs_agents_state[other_agent.initial_position][4] += 1 - return self.rail_obs, obs_agents_state, obs_targets - - -class LocalObsForRailEnv(ObservationBuilder): - """ - !!!!!!WARNING!!! THIS IS DEPRACTED AND NOT UPDATED TO FLATLAND 2.0!!!!! - Gives a local observation of the rail environment around the agent. - The observation is composed of the following elements: - - - transition map array of the local environment around the given agent, \ - with dimensions (view_height,2*view_width+1, 16), \ - assuming 16 bits encoding of transitions. - - - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, \ - if they are in the agent's vision range, its target position, the positions of the other targets. 
- - - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions \ - of the other agents at their position coordinates, if they are in the agent's vision range. - - - A 4 elements array with one hot encoding of the direction. - - Use the parameters view_width and view_height to define the rectangular view of the agent. - The center parameters moves the agent along the height axis of this rectangle. If it is 0 the agent only has - observation in front of it. - - .. deprecated:: 2.0.0 - """ - - def __init__(self, view_width, view_height, center): - - super(LocalObsForRailEnv, self).__init__() - self.view_width = view_width - self.view_height = view_height - self.center = center - self.max_padding = max(self.view_width, self.view_height - self.center) - - def reset(self): - # We build the transition map with a view_radius empty cells expansion on each side. - # This helps to collect the local transition map view when the agent is close to a border. - self.max_padding = max(self.view_width, self.view_height) - self.rail_obs = np.zeros((self.env.height, - self.env.width, 16)) - for i in range(self.env.height): - for j in range(self.env.width): - bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] - bitlist = [0] * (16 - len(bitlist)) + bitlist - self.rail_obs[i, j] = np.array(bitlist) - - def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray): - agents = self.env.agents - agent = agents[handle] - - # Correct agents position for padding - # agent_rel_pos[0] = agent.position[0] + self.max_padding - # agent_rel_pos[1] = agent.position[1] + self.max_padding - - # Collect visible cells as set to be plotted - visited, rel_coords = self.field_of_view(agent.position, agent.direction, ) - local_rail_obs = None - - # Add the visible cells to the observed cells - self.env.dev_obs_dict[handle] = set(visited) - - # Locate observed agents and their coresponding targets - local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16)) - obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2)) - obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4)) - _idx = 0 - for pos in visited: - curr_rel_coord = rel_coords[_idx] - local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :] - if pos == agent.target: - obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1 - else: - for tmp_agent in agents: - if pos == tmp_agent.target: - obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1 - if pos != agent.position: - for tmp_agent in agents: - if pos == tmp_agent.position: - obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[ - tmp_agent.direction] - - _idx += 1 - - direction = np.identity(4)[agent.direction] - return local_rail_obs, obs_map_state, obs_other_agents_state, direction - - def get_many(self, handles: Optional[List[int]] = None) -> Dict[ - int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]: - """ - Called whenever an observation has to be computed for the `env` environment, for each agent with handle - in the `handles` list. 
- """ - - return super().get_many(handles) - - def field_of_view(self, position, direction, state=None): - # Compute the local field of view for an agent in the environment - data_collection = False - if state is not None: - temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16)) - data_collection = True - if direction == 0: - origin = (position[0] + self.center, position[1] - self.view_width) - elif direction == 1: - origin = (position[0] - self.view_width, position[1] - self.center) - elif direction == 2: - origin = (position[0] - self.center, position[1] + self.view_width) - else: - origin = (position[0] + self.view_width, position[1] + self.center) - visible = list() - rel_coords = list() - for h in range(self.view_height): - for w in range(2 * self.view_width + 1): - if direction == 0: - if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: - visible.append((origin[0] - h, origin[1] + w)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :] - elif direction == 1: - if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width: - visible.append((origin[0] + w, origin[1] + h)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :] - elif direction == 2: - if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width: - visible.append((origin[0] + h, origin[1] - w)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :] - else: - if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width: - visible.append((origin[0] - w, origin[1] - h)) - rel_coords.append((h, w)) - # if data_collection: - # temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :] - if data_collection: - return temp_visible_data - else: - return visible, rel_coords - - -def _split_node_into_feature_groups(node: MyTreeObsForRailEnv.Node, dist_min_to_target: int) -> (np.ndarray, np.ndarray, - np.ndarray): - data = np.zeros(2) - - data[0] = 2.0 * int(node.num_agents_opposite_direction > 0) - 1.0 - # data[1] = 2.0 * int(node.num_agents_same_direction > 0) - 1.0 - data[1] = 2.0 * int(node.target_encountered > 0) - 1.0 - - return data - - -def _split_subtree_into_feature_groups(node: MyTreeObsForRailEnv.Node, dist_min_to_target: int, - current_tree_depth: int, - max_tree_depth: int) -> ( - np.ndarray, np.ndarray, np.ndarray): - if node == -np.inf: - remaining_depth = max_tree_depth - current_tree_depth - # reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure - num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1)) - return [0] * num_remaining_nodes * 2 - - data = _split_node_into_feature_groups(node, dist_min_to_target) - - if not node.childs: - return data - - for direction in MyTreeObsForRailEnv.tree_explored_actions_char: - sub_data = _split_subtree_into_feature_groups(node.childs[direction], - node.dist_min_to_target, - current_tree_depth + 1, - max_tree_depth) - data = np.concatenate((data, sub_data)) - return data - - -def split_tree_into_feature_groups(tree: MyTreeObsForRailEnv.Node, max_tree_depth: int) -> ( - np.ndarray, np.ndarray, np.ndarray): - """ - This function splits the tree into three difference arrays of values - """ - data = _split_node_into_feature_groups(tree, 1000000.0) - - for direction in 
MyTreeObsForRailEnv.tree_explored_actions_char: - sub_data = _split_subtree_into_feature_groups(tree.childs[direction], - 1000000.0, - 1, - max_tree_depth) - data = np.concatenate((data, sub_data)) - - return data - - -def normalize_observation(observation: MyTreeObsForRailEnv.Node, tree_depth: int): - """ - This function normalizes the observation used by the RL algorithm - """ - data = split_tree_into_feature_groups(observation, tree_depth) - normalized_obs = data - - # navigate_info - navigate_info = np.zeros(4) - action_info = np.zeros(4) - np.seterr(all='raise') - try: - dm = observation.dist_min_to_target - if observation.childs['L'] != -np.inf: - navigate_info[0] = dm - observation.childs['L'].dist_min_to_target - action_info[0] = 1 - if observation.childs['F'] != -np.inf: - navigate_info[1] = dm - observation.childs['F'].dist_min_to_target - action_info[1] = 1 - if observation.childs['R'] != -np.inf: - navigate_info[2] = dm - observation.childs['R'].dist_min_to_target - action_info[2] = 1 - if observation.childs['B'] != -np.inf: - navigate_info[3] = dm - observation.childs['B'].dist_min_to_target - action_info[3] = 1 - except: - navigate_info = np.ones(4) - normalized_obs = np.zeros(len(normalized_obs)) - - # navigate_info_2 = np.copy(navigate_info) - # max_v = np.max(navigate_info_2) - # navigate_info_2 = navigate_info_2 / max_v - # navigate_info_2[navigate_info_2 < 1] = -1 - - max_v = np.max(navigate_info) - navigate_info = navigate_info / max_v - navigate_info[navigate_info < 0] = -1 - # navigate_info[abs(navigate_info) < 1] = 0 - # normalized_obs = navigate_info - - # navigate_info = np.concatenate((navigate_info, action_info)) - normalized_obs = np.concatenate((navigate_info, normalized_obs)) - # normalized_obs = np.concatenate((navigate_info, navigate_info_2)) - # print(normalized_obs) - return normalized_obs +""" +Collection of environment-specific ObservationBuilder. +""" +import collections +from typing import Optional, List, Dict, Tuple + +import numpy as np +from flatland.core.env import Environment +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.env_prediction_builder import PredictionBuilder +from flatland.core.grid.grid4_utils import get_new_position +from flatland.core.grid.grid_utils import coordinate_to_position +from flatland.envs.agent_utils import RailAgentStatus, EnvAgent +from flatland.utils.ordered_set import OrderedSet + + +class MyTreeObsForRailEnv(ObservationBuilder): + """ + TreeObsForRailEnv object. + + This object returns observation vectors for agents in the RailEnv environment. + The information is local to each agent and exploits the graph structure of the rail + network to simplify the representation of the state of the environment for each agent. + + For details about the features in the tree observation see the get() function. 
+ """ + Node = collections.namedtuple('Node', 'dist_min_to_target ' + 'target_encountered ' + 'num_agents_same_direction ' + 'num_agents_opposite_direction ' + 'childs') + + tree_explored_actions_char = ['L', 'F', 'R', 'B'] + + def __init__(self, max_depth: int, predictor: PredictionBuilder = None): + super().__init__() + self.max_depth = max_depth + self.observation_dim = 2 + self.location_has_agent = {} + self.predictor = predictor + self.location_has_target = None + + self.switches_list = {} + self.switches_neighbours_list = [] + self.check_agent_descision = None + + def reset(self): + self.location_has_target = {tuple(agent.target): 1 for agent in self.env.agents} + + def set_switch_and_pre_switch(self, switch_list, pre_switch_list, check_agent_descision): + self.switches_list = switch_list + self.switches_neighbours_list = pre_switch_list + self.check_agent_descision = check_agent_descision + + def get_many(self, handles: Optional[List[int]] = None) -> Dict[int, Node]: + """ + Called whenever an observation has to be computed for the `env` environment, for each agent with handle + in the `handles` list. + """ + + if handles is None: + handles = [] + if self.predictor: + self.max_prediction_depth = 0 + self.predicted_pos = {} + self.predicted_dir = {} + self.predictions = self.predictor.get() + if self.predictions: + for t in range(self.predictor.max_depth + 1): + pos_list = [] + dir_list = [] + for a in handles: + if self.predictions[a] is None: + continue + pos_list.append(self.predictions[a][t][1:3]) + dir_list.append(self.predictions[a][t][3]) + self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) + self.predicted_dir.update({t: dir_list}) + self.max_prediction_depth = len(self.predicted_pos) + # Update local lookup table for all agents' positions + # ignore other agents not in the grid (only status active and done) + + self.location_has_agent = {} + self.location_has_agent_direction = {} + self.location_has_agent_speed = {} + self.location_has_agent_malfunction = {} + self.location_has_agent_ready_to_depart = {} + + for _agent in self.env.agents: + if _agent.status in [RailAgentStatus.ACTIVE, RailAgentStatus.DONE] and \ + _agent.position: + self.location_has_agent[tuple(_agent.position)] = 1 + self.location_has_agent_direction[tuple(_agent.position)] = _agent.direction + self.location_has_agent_speed[tuple(_agent.position)] = _agent.speed_data['speed'] + self.location_has_agent_malfunction[tuple(_agent.position)] = _agent.malfunction_data[ + 'malfunction'] + + if _agent.status in [RailAgentStatus.READY_TO_DEPART] and \ + _agent.initial_position: + self.location_has_agent_ready_to_depart[tuple(_agent.initial_position)] = \ + self.location_has_agent_ready_to_depart.get(tuple(_agent.initial_position), 0) + 1 + + observations = super().get_many(handles) + + return observations + + def get(self, handle: int = 0) -> Node: + """ + Computes the current observation for agent `handle` in env + + The observation vector is composed of 4 sequential parts, corresponding to data from the up to 4 possible + movements in a RailEnv (up to because only a subset of possible transitions are allowed in RailEnv). + The possible movements are sorted relative to the current orientation of the agent, rather than NESW as for + the transitions. The order is:: + + [data from 'left'] + [data from 'forward'] + [data from 'right'] + [data from 'back'] + + Each branch data is organized as:: + + [root node information] + + [recursive branch data from 'left'] + + [... 
from 'forward'] + + [... from 'right'] + + [... from 'back'] + + Each node's information is composed of 12 features: + + #1: + if own target lies on the explored branch the current distance from the agent in number of cells is stored. + + #2: + if another agent's target is detected the distance in number of cells from the agent's current location\ + is stored + + #3: + if another agent is detected the distance in number of cells from current agent position is stored. + + #4: + possible conflict detected + tot_dist = another agent is predicted to pass along this cell at the same time as the agent; we store the \ + distance in number of cells from current agent position + + 0 = no other agent reserves the same cell at a similar time + + #5: + if a switch that is not usable by the agent is detected we store the distance. + + #6: + This feature stores the distance in number of cells to the next branching (current node) + + #7: + minimum distance from node to the agent's target given the direction of the agent if this path is chosen + + #8: + agent in the same direction + n = number of agents present same direction \ + (possible future use: number of other agents in the same direction in this branch) + 0 = no agent present same direction + + #9: + agent in the opposite direction + n = number of agents present other direction than myself (so conflict) \ + (possible future use: number of other agents in other direction in this branch, i.e. number of conflicts) + 0 = no agent present other direction than myself + + #10: + malfunctioning/blocking agents + n = number of time steps the observed agent remains blocked + + #11: + slowest observed speed of an agent in same direction + 1 if no agent is observed + + min_fractional speed otherwise + #12: + number of agents ready to depart but not yet active + + Missing/padding nodes are filled in with -inf (truncated). + Missing values in present node are filled in with +inf (truncated). + + + In case of the root node, the values are [0, 0, 0, 0, distance from agent to target, own malfunction, own speed] + In case the target node is reached, the values are [0, 0, 0, 0, 0]. + """ + + if handle >= len(self.env.agents): + print("ERROR: obs _get - handle ", handle, " len(agents)", len(self.env.agents)) + agent = self.env.agents[handle] # TODO: handle being treated as index + + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_virtual_position = agent.target + else: + return None + + possible_transitions = self.env.rail.get_transitions(*agent_virtual_position, agent.direction) + num_transitions = np.count_nonzero(possible_transitions) + + # Here information about the agent itself is stored + distance_map = self.env.distance_map.get() + + root_node_observation = MyTreeObsForRailEnv.Node(dist_min_to_target=distance_map[ + (handle, *agent_virtual_position, + agent.direction)], + target_encountered=0, + num_agents_same_direction=0, + num_agents_opposite_direction=0, + childs={}) + + visited = OrderedSet() + + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right, back], relative to the current orientation + # If only one transition is possible, the tree is oriented with this transition as the forward branch.
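# A minimal illustrative sketch (not part of the patched file) of the relative
# branch ordering used just below, assuming the usual Flatland direction
# encoding 0=North, 1=East, 2=South, 3=West.  For an agent facing East,
# [(orientation + i) % 4 for i in range(-1, 3)] enumerates [North, East,
# South, West], i.e. [left, forward, right, back], matching the child keys
# tree_explored_actions_char = ['L', 'F', 'R', 'B'].
example_orientation = 1  # hypothetical value: agent facing East
example_branch_order = [(example_orientation + i) % 4 for i in range(-1, 3)]
assert example_branch_order == [0, 1, 2, 3]  # [left, forward, right, back]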
+ orientation = agent.direction + + if num_transitions == 1: + orientation = np.argmax(possible_transitions) + + for i, branch_direction in enumerate([(orientation + i) % 4 for i in range(-1, 3)]): + if possible_transitions[branch_direction]: + new_cell = get_new_position(agent_virtual_position, branch_direction) + + branch_observation, branch_visited = \ + self._explore_branch(handle, new_cell, branch_direction, 1, 1) + root_node_observation.childs[self.tree_explored_actions_char[i]] = branch_observation + + visited |= branch_visited + else: + # add cells filled with infinity if no transition is possible + root_node_observation.childs[self.tree_explored_actions_char[i]] = -np.inf + self.env.dev_obs_dict[handle] = visited + + return root_node_observation + + def _explore_branch(self, handle, position, direction, tot_dist, depth): + """ + Utility function to compute tree-based observations. + We walk along the branch and collect the information documented in the get() function. + If there is a branching point a new node is created and each possible branch is explored. + """ + + # [Recursive branch opened] + if depth >= self.max_depth + 1: + return [], [] + + # Continue along direction until next switch or + # until no transitions are possible along the current direction (i.e., dead-ends) + # We treat dead-ends as nodes, instead of going back, to avoid loops + exploring = True + + visited = OrderedSet() + agent = self.env.agents[handle] + + other_agent_opposite_direction = 0 + other_agent_same_direction = 0 + + dist_min_to_target = self.env.distance_map.get()[handle, position[0], position[1], direction] + + last_is_dead_end = False + last_is_a_decision_cell = False + target_encountered = 0 + + cnt = 0 + while exploring: + + dist_min_to_target = min(dist_min_to_target, self.env.distance_map.get()[handle, position[0], position[1], + direction]) + + if agent.target == position: + target_encountered = 1 + + new_direction_me = direction + new_cell_me = position + a = self.env.agent_positions[new_cell_me] + if a != -1 and a != handle: + opp_agent = self.env.agents[a] + # look one step forward + # opp_possible_transitions = self.env.rail.get_transitions(*opp_agent.position, opp_agent.direction) + if opp_agent.direction != new_direction_me: # opp_possible_transitions[new_direction_me] == 0: + other_agent_opposite_direction += 1 + else: + other_agent_same_direction += 1 + + # ############################# + # ############################# + if (position[0], position[1], direction) in visited: + break + visited.add((position[0], position[1], direction)) + + # If the target node is encountered, pick that as node. Also, no further branching is possible. + if np.array_equal(position, self.env.agents[handle].target): + last_is_target = True + break + + exploring = False + + # Check number of possible transitions for agent and total number of transitions in cell (type) + possible_transitions = self.env.rail.get_transitions(*position, direction) + num_transitions = np.count_nonzero(possible_transitions) + # cell_transitions = self.env.rail.get_transitions(*position, direction) + transition_bit = bin(self.env.rail.get_full_transitions(*position)) + total_transitions = transition_bit.count("1") + + if num_transitions == 1: + # Check if dead-end, or if we can go forward along direction + nbits = total_transitions + if nbits == 1: + # Dead-end! 
+ last_is_dead_end = True + + if self.check_agent_descision is not None: + ret_agents_on_switch, ret_agents_near_to_switch, agents_near_to_switch_all = \ + self.check_agent_descision(position, + direction, + self.switches_list, + self.switches_neighbours_list) + if ret_agents_on_switch: + last_is_a_decision_cell = True + break + + exploring = True + # convert one-hot encoding to 0,1,2,3 + cell_transitions = self.env.rail.get_transitions(*position, direction) + direction = np.argmax(cell_transitions) + position = get_new_position(position, direction) + + cnt += 1 + if cnt > 1000: + exploring = False + + # ############################# + # ############################# + # Modify here to append new / different features for each visited cell! + + node = MyTreeObsForRailEnv.Node(dist_min_to_target=dist_min_to_target, + target_encountered=target_encountered, + num_agents_opposite_direction=other_agent_opposite_direction, + num_agents_same_direction=other_agent_same_direction, + childs={}) + + # ############################# + # ############################# + # Start from the current orientation, and see which transitions are available; + # organize them as [left, forward, right, back], relative to the current orientation + # Get the possible transitions + possible_transitions = self.env.rail.get_transitions(*position, direction) + + for i, branch_direction in enumerate([(direction + 4 + i) % 4 for i in range(-1, 3)]): + if last_is_dead_end and self.env.rail.get_transition((*position, direction), + (branch_direction + 2) % 4): + # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes + # it back + new_cell = get_new_position(position, (branch_direction + 2) % 4) + branch_observation, branch_visited = self._explore_branch(handle, + new_cell, + (branch_direction + 2) % 4, + tot_dist + 1, + depth + 1) + node.childs[self.tree_explored_actions_char[i]] = branch_observation + if len(branch_visited) != 0: + visited |= branch_visited + elif last_is_a_decision_cell and possible_transitions[branch_direction]: + new_cell = get_new_position(position, branch_direction) + branch_observation, branch_visited = self._explore_branch(handle, + new_cell, + branch_direction, + tot_dist + 1, + depth + 1) + node.childs[self.tree_explored_actions_char[i]] = branch_observation + if len(branch_visited) != 0: + visited |= branch_visited + else: + # no exploring possible, add just cells with infinity + node.childs[self.tree_explored_actions_char[i]] = -np.inf + + if depth == self.max_depth: + node.childs.clear() + return node, visited + + def util_print_obs_subtree(self, tree: Node): + """ + Utility function to print tree observations returned by this object. 
+ """ + self.print_node_features(tree, "root", "") + for direction in self.tree_explored_actions_char: + self.print_subtree(tree.childs[direction], direction, "\t") + + @staticmethod + def print_node_features(node: Node, label, indent): + print(indent, "Direction ", label, ": ", node.num_agents_same_direction, + ", ", node.num_agents_opposite_direction) + + def print_subtree(self, node, label, indent): + if node == -np.inf or not node: + print(indent, "Direction ", label, ": -np.inf") + return + + self.print_node_features(node, label, indent) + + if not node.childs: + return + + for direction in self.tree_explored_actions_char: + self.print_subtree(node.childs[direction], direction, indent + "\t") + + def set_env(self, env: Environment): + super().set_env(env) + if self.predictor: + self.predictor.set_env(self.env) + + def _reverse_dir(self, direction): + return int((direction + 2) % 4) + + +class GlobalObsForRailEnv(ObservationBuilder): + """ + Gives a global observation of the entire rail environment. + The observation is composed of the following elements: + + - transition map array with dimensions (env.height, env.width, 16),\ + assuming 16 bits encoding of transitions. + + - obs_agents_state: A 3D array (map_height, map_width, 5) with + - first channel containing the agents position and direction + - second channel containing the other agents positions and direction + - third channel containing agent/other agent malfunctions + - fourth channel containing agent/other agent fractional speeds + - fifth channel containing number of other agents ready to depart + + - obs_targets: Two 2D arrays (map_height, map_width, 2) containing respectively the position of the given agent\ + target and the positions of the other agents targets (flag only, no counter!). + """ + + def __init__(self): + super(GlobalObsForRailEnv, self).__init__() + + def set_env(self, env: Environment): + super().set_env(env) + + def reset(self): + self.rail_obs = np.zeros((self.env.height, self.env.width, 16)) + for i in range(self.rail_obs.shape[0]): + for j in range(self.rail_obs.shape[1]): + bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] + bitlist = [0] * (16 - len(bitlist)) + bitlist + self.rail_obs[i, j] = np.array(bitlist) + + def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray): + + agent = self.env.agents[handle] + if agent.status == RailAgentStatus.READY_TO_DEPART: + agent_virtual_position = agent.initial_position + elif agent.status == RailAgentStatus.ACTIVE: + agent_virtual_position = agent.position + elif agent.status == RailAgentStatus.DONE: + agent_virtual_position = agent.target + else: + return None + + obs_targets = np.zeros((self.env.height, self.env.width, 2)) + obs_agents_state = np.zeros((self.env.height, self.env.width, 5)) - 1 + + # TODO can we do this more elegantly? 
+ # for r in range(self.env.height): + # for c in range(self.env.width): + # obs_agents_state[(r, c)][4] = 0 + obs_agents_state[:, :, 4] = 0 + + obs_agents_state[agent_virtual_position][0] = agent.direction + obs_targets[agent.target][0] = 1 + + for i in range(len(self.env.agents)): + other_agent: EnvAgent = self.env.agents[i] + + # ignore other agents not in the grid any more + if other_agent.status == RailAgentStatus.DONE_REMOVED: + continue + + obs_targets[other_agent.target][1] = 1 + + # second to fourth channel only if in the grid + if other_agent.position is not None: + # second channel only for other agents + if i != handle: + obs_agents_state[other_agent.position][1] = other_agent.direction + obs_agents_state[other_agent.position][2] = other_agent.malfunction_data['malfunction'] + obs_agents_state[other_agent.position][3] = other_agent.speed_data['speed'] + # fifth channel: all ready to depart on this position + if other_agent.status == RailAgentStatus.READY_TO_DEPART: + obs_agents_state[other_agent.initial_position][4] += 1 + return self.rail_obs, obs_agents_state, obs_targets + + +class LocalObsForRailEnv(ObservationBuilder): + """ + !!!!!!WARNING!!! THIS IS DEPRECATED AND NOT UPDATED TO FLATLAND 2.0!!!!! + Gives a local observation of the rail environment around the agent. + The observation is composed of the following elements: + + - transition map array of the local environment around the given agent, \ + with dimensions (view_height,2*view_width+1, 16), \ + assuming 16 bits encoding of transitions. + + - Two 2D arrays (view_height,2*view_width+1, 2) containing respectively, \ + if they are in the agent's vision range, its target position, the positions of the other targets. + + - A 2D array (view_height,2*view_width+1, 4) containing the one hot encoding of directions \ + of the other agents at their position coordinates, if they are in the agent's vision range. + + - A 4-element array with one hot encoding of the direction. + + Use the parameters view_width and view_height to define the rectangular view of the agent. + The center parameter moves the agent along the height axis of this rectangle. If it is 0, the agent only has + observation in front of it. + + .. deprecated:: 2.0.0 + """ + + def __init__(self, view_width, view_height, center): + + super(LocalObsForRailEnv, self).__init__() + self.view_width = view_width + self.view_height = view_height + self.center = center + self.max_padding = max(self.view_width, self.view_height - self.center) + + def reset(self): + # We build the transition map with a view_radius empty cells expansion on each side. + # This helps to collect the local transition map view when the agent is close to a border.
+ self.max_padding = max(self.view_width, self.view_height) + self.rail_obs = np.zeros((self.env.height, + self.env.width, 16)) + for i in range(self.env.height): + for j in range(self.env.width): + bitlist = [int(digit) for digit in bin(self.env.rail.get_full_transitions(i, j))[2:]] + bitlist = [0] * (16 - len(bitlist)) + bitlist + self.rail_obs[i, j] = np.array(bitlist) + + def get(self, handle: int = 0) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray): + agents = self.env.agents + agent = agents[handle] + + # Correct agents position for padding + # agent_rel_pos[0] = agent.position[0] + self.max_padding + # agent_rel_pos[1] = agent.position[1] + self.max_padding + + # Collect visible cells as set to be plotted + visited, rel_coords = self.field_of_view(agent.position, agent.direction, ) + local_rail_obs = None + + # Add the visible cells to the observed cells + self.env.dev_obs_dict[handle] = set(visited) + + # Locate observed agents and their coresponding targets + local_rail_obs = np.zeros((self.view_height, 2 * self.view_width + 1, 16)) + obs_map_state = np.zeros((self.view_height, 2 * self.view_width + 1, 2)) + obs_other_agents_state = np.zeros((self.view_height, 2 * self.view_width + 1, 4)) + _idx = 0 + for pos in visited: + curr_rel_coord = rel_coords[_idx] + local_rail_obs[curr_rel_coord[0], curr_rel_coord[1], :] = self.rail_obs[pos[0], pos[1], :] + if pos == agent.target: + obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 0] = 1 + else: + for tmp_agent in agents: + if pos == tmp_agent.target: + obs_map_state[curr_rel_coord[0], curr_rel_coord[1], 1] = 1 + if pos != agent.position: + for tmp_agent in agents: + if pos == tmp_agent.position: + obs_other_agents_state[curr_rel_coord[0], curr_rel_coord[1], :] = np.identity(4)[ + tmp_agent.direction] + + _idx += 1 + + direction = np.identity(4)[agent.direction] + return local_rail_obs, obs_map_state, obs_other_agents_state, direction + + def get_many(self, handles: Optional[List[int]] = None) -> Dict[ + int, Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]: + """ + Called whenever an observation has to be computed for the `env` environment, for each agent with handle + in the `handles` list. 
+ """ + + return super().get_many(handles) + + def field_of_view(self, position, direction, state=None): + # Compute the local field of view for an agent in the environment + data_collection = False + if state is not None: + temp_visible_data = np.zeros(shape=(self.view_height, 2 * self.view_width + 1, 16)) + data_collection = True + if direction == 0: + origin = (position[0] + self.center, position[1] - self.view_width) + elif direction == 1: + origin = (position[0] - self.view_width, position[1] - self.center) + elif direction == 2: + origin = (position[0] - self.center, position[1] + self.view_width) + else: + origin = (position[0] + self.view_width, position[1] + self.center) + visible = list() + rel_coords = list() + for h in range(self.view_height): + for w in range(2 * self.view_width + 1): + if direction == 0: + if 0 <= origin[0] - h < self.env.height and 0 <= origin[1] + w < self.env.width: + visible.append((origin[0] - h, origin[1] + w)) + rel_coords.append((h, w)) + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] - h, origin[1] + w, :] + elif direction == 1: + if 0 <= origin[0] + w < self.env.height and 0 <= origin[1] + h < self.env.width: + visible.append((origin[0] + w, origin[1] + h)) + rel_coords.append((h, w)) + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] + w, origin[1] + h, :] + elif direction == 2: + if 0 <= origin[0] + h < self.env.height and 0 <= origin[1] - w < self.env.width: + visible.append((origin[0] + h, origin[1] - w)) + rel_coords.append((h, w)) + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] + h, origin[1] - w, :] + else: + if 0 <= origin[0] - w < self.env.height and 0 <= origin[1] - h < self.env.width: + visible.append((origin[0] - w, origin[1] - h)) + rel_coords.append((h, w)) + # if data_collection: + # temp_visible_data[h, w, :] = state[origin[0] - w, origin[1] - h, :] + if data_collection: + return temp_visible_data + else: + return visible, rel_coords + + +def _split_node_into_feature_groups(node: MyTreeObsForRailEnv.Node, dist_min_to_target: int) -> (np.ndarray, np.ndarray, + np.ndarray): + data = np.zeros(2) + + data[0] = 2.0 * int(node.num_agents_opposite_direction > 0) - 1.0 + # data[1] = 2.0 * int(node.num_agents_same_direction > 0) - 1.0 + data[1] = 2.0 * int(node.target_encountered > 0) - 1.0 + + return data + + +def _split_subtree_into_feature_groups(node: MyTreeObsForRailEnv.Node, dist_min_to_target: int, + current_tree_depth: int, + max_tree_depth: int) -> ( + np.ndarray, np.ndarray, np.ndarray): + if node == -np.inf: + remaining_depth = max_tree_depth - current_tree_depth + # reference: https://stackoverflow.com/questions/515214/total-number-of-nodes-in-a-tree-data-structure + num_remaining_nodes = int((4 ** (remaining_depth + 1) - 1) / (4 - 1)) + return [0] * num_remaining_nodes * 2 + + data = _split_node_into_feature_groups(node, dist_min_to_target) + + if not node.childs: + return data + + for direction in MyTreeObsForRailEnv.tree_explored_actions_char: + sub_data = _split_subtree_into_feature_groups(node.childs[direction], + node.dist_min_to_target, + current_tree_depth + 1, + max_tree_depth) + data = np.concatenate((data, sub_data)) + return data + + +def split_tree_into_feature_groups(tree: MyTreeObsForRailEnv.Node, max_tree_depth: int) -> ( + np.ndarray, np.ndarray, np.ndarray): + """ + This function splits the tree into three difference arrays of values + """ + data = _split_node_into_feature_groups(tree, 1000000.0) + + for direction in 
MyTreeObsForRailEnv.tree_explored_actions_char: + sub_data = _split_subtree_into_feature_groups(tree.childs[direction], + 1000000.0, + 1, + max_tree_depth) + data = np.concatenate((data, sub_data)) + + return data + + +def normalize_observation(observation: MyTreeObsForRailEnv.Node, tree_depth: int): + """ + This function normalizes the observation used by the RL algorithm + """ + data = split_tree_into_feature_groups(observation, tree_depth) + normalized_obs = data + + # navigate_info + navigate_info = np.zeros(4) + action_info = np.zeros(4) + np.seterr(all='raise') + try: + dm = observation.dist_min_to_target + if observation.childs['L'] != -np.inf: + navigate_info[0] = dm - observation.childs['L'].dist_min_to_target + action_info[0] = 1 + if observation.childs['F'] != -np.inf: + navigate_info[1] = dm - observation.childs['F'].dist_min_to_target + action_info[1] = 1 + if observation.childs['R'] != -np.inf: + navigate_info[2] = dm - observation.childs['R'].dist_min_to_target + action_info[2] = 1 + if observation.childs['B'] != -np.inf: + navigate_info[3] = dm - observation.childs['B'].dist_min_to_target + action_info[3] = 1 + except: + navigate_info = np.ones(4) + normalized_obs = np.zeros(len(normalized_obs)) + + # navigate_info_2 = np.copy(navigate_info) + # max_v = np.max(navigate_info_2) + # navigate_info_2 = navigate_info_2 / max_v + # navigate_info_2[navigate_info_2 < 1] = -1 + + max_v = np.max(navigate_info) + navigate_info = navigate_info / max_v + navigate_info[navigate_info < 0] = -1 + # navigate_info[abs(navigate_info) < 1] = 0 + # normalized_obs = navigate_info + + # navigate_info = np.concatenate((navigate_info, action_info)) + normalized_obs = np.concatenate((navigate_info, normalized_obs)) + # normalized_obs = np.concatenate((navigate_info, navigate_info_2)) + # print(normalized_obs) + return normalized_obs diff --git a/src/ppo/agent.py b/src/ppo/agent.py new file mode 100644 index 0000000000000000000000000000000000000000..86b210ad388f9d9638c5e1e037cdca27ffaf9e73 --- /dev/null +++ b/src/ppo/agent.py @@ -0,0 +1,106 @@ +import pickle + +import torch +# from model import PolicyNetwork +# from replay_memory import Episode, ReplayBuffer +from torch.distributions.categorical import Categorical + +from src.ppo.model import PolicyNetwork +from src.ppo.replay_memory import Episode, ReplayBuffer + +BUFFER_SIZE = 32_000 +BATCH_SIZE = 4096 +GAMMA = 0.98 +LR = 0.5e-4 +CLIP_FACTOR = .005 +UPDATE_EVERY = 30 + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +print("device:", device) + + +class Agent: + def __init__(self, state_size, action_size, num_agents): + self.policy = PolicyNetwork(state_size, action_size).to(device) + self.old_policy = PolicyNetwork(state_size, action_size).to(device) + self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=LR) + + self.episodes = [Episode() for _ in range(num_agents)] + self.memory = ReplayBuffer(BUFFER_SIZE) + self.t_step = 0 + + def reset(self): + self.finished = [False] * len(self.episodes) + + # Decide on an action to take in the environment + + def act(self, state, eps=None): + self.policy.eval() + with torch.no_grad(): + output = self.policy(torch.from_numpy(state).float().unsqueeze(0).to(device)) + return Categorical(output).sample().item() + + # Record the results of the agent's action and update the model + + def step(self, handle, state, action, next_state, agent_done, episode_done, collision): + if not self.finished[handle]: + if agent_done: + reward = 1 + elif collision: + reward = -.5 + else: + reward = 
0 + + # Push experience into Episode memory + self.episodes[handle].push(state, action, reward, next_state, agent_done or episode_done) + + # When we finish the episode, discount rewards and push the experience into replay memory + if agent_done or episode_done: + self.episodes[handle].discount_rewards(GAMMA) + self.memory.push_episode(self.episodes[handle]) + self.episodes[handle].reset() + self.finished[handle] = True + + # Perform a gradient update every UPDATE_EVERY time steps + self.t_step = (self.t_step + 1) % UPDATE_EVERY + if self.t_step == 0 and len(self.memory) > BATCH_SIZE * 4: + self.learn(*self.memory.sample(BATCH_SIZE, device)) + + def learn(self, states, actions, rewards, next_state, done): + self.policy.train() + + responsible_outputs = torch.gather(self.policy(states), 1, actions) + old_responsible_outputs = torch.gather(self.old_policy(states), 1, actions).detach() + + # rewards = rewards - rewards.mean() + ratio = responsible_outputs / (old_responsible_outputs + 1e-5) + clamped_ratio = torch.clamp(ratio, 1. - CLIP_FACTOR, 1. + CLIP_FACTOR) + loss = -torch.min(ratio * rewards, clamped_ratio * rewards).mean() + + # Compute loss and perform a gradient step + self.old_policy.load_state_dict(self.policy.state_dict()) + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + + # Checkpointing methods + + def save(self, path, *data): + torch.save(self.policy.state_dict(), path / 'ppo/model_checkpoint.policy') + torch.save(self.optimizer.state_dict(), path / 'ppo/model_checkpoint.optimizer') + with open(path / 'ppo/model_checkpoint.meta', 'wb') as file: + pickle.dump(data, file) + + def load(self, path, *defaults): + try: + print("Loading model from checkpoint...") + print(path + 'ppo/model_checkpoint.policy') + self.policy.load_state_dict( + torch.load(path + 'ppo/model_checkpoint.policy', map_location=torch.device('cpu'))) + self.optimizer.load_state_dict( + torch.load(path + 'ppo/model_checkpoint.optimizer', map_location=torch.device('cpu'))) + with open(path + 'ppo/model_checkpoint.meta', 'rb') as file: + return pickle.load(file) + except: + print("No checkpoint file was found") + return defaults diff --git a/src/ppo/model.py b/src/ppo/model.py new file mode 100644 index 0000000000000000000000000000000000000000..51b86ff16691c03f6a754405352bb4cf48e4b914 --- /dev/null +++ b/src/ppo/model.py @@ -0,0 +1,20 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class PolicyNetwork(nn.Module): + def __init__(self, state_size, action_size, hidsize1=128, hidsize2=128, hidsize3=32): + super().__init__() + self.fc1 = nn.Linear(state_size, hidsize1) + self.fc2 = nn.Linear(hidsize1, hidsize2) + # self.fc3 = nn.Linear(hidsize2, hidsize3) + self.output = nn.Linear(hidsize2, action_size) + self.softmax = nn.Softmax(dim=1) + self.bn0 = nn.BatchNorm1d(state_size, affine=False) + + def forward(self, inputs): + x = self.bn0(inputs.float()) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + # x = F.relu(self.fc3(x)) + return self.softmax(self.output(x)) diff --git a/src/ppo/replay_memory.py b/src/ppo/replay_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..3e6619b40169597d7a4b379f4ce2c9ddccd4cd9b --- /dev/null +++ b/src/ppo/replay_memory.py @@ -0,0 +1,53 @@ +import torch +import random +import numpy as np +from collections import namedtuple, deque, Iterable + + +Transition = namedtuple("Experience", ("state", "action", "reward", "next_state", "done")) + + +class Episode: + memory = [] + + def reset(self): + self.memory = [] + + def 
push(self, *args): + self.memory.append(tuple(args)) + + def discount_rewards(self, gamma): + running_add = 0. + for i, (state, action, reward, *rest) in list(enumerate(self.memory))[::-1]: + running_add = running_add * gamma + reward + self.memory[i] = (state, action, running_add, *rest) + + +class ReplayBuffer: + def __init__(self, buffer_size): + self.memory = deque(maxlen=buffer_size) + + def push(self, state, action, reward, next_state, done): + self.memory.append(Transition(np.expand_dims(state, 0), action, reward, np.expand_dims(next_state, 0), done)) + + def push_episode(self, episode): + for step in episode.memory: + self.push(*step) + + def sample(self, batch_size, device): + experiences = random.sample(self.memory, k=batch_size) + + states = torch.from_numpy(self.stack([e.state for e in experiences])).float().to(device) + actions = torch.from_numpy(self.stack([e.action for e in experiences])).long().to(device) + rewards = torch.from_numpy(self.stack([e.reward for e in experiences])).float().to(device) + next_states = torch.from_numpy(self.stack([e.next_state for e in experiences])).float().to(device) + dones = torch.from_numpy(self.stack([e.done for e in experiences]).astype(np.uint8)).float().to(device) + + return states, actions, rewards, next_states, dones + + def stack(self, states): + sub_dims = states[0].shape[1:] if isinstance(states[0], Iterable) else [1] + return np.reshape(np.array(states), (len(states), *sub_dims)) + + def __len__(self): + return len(self.memory)
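# A hedged, end-to-end sketch of how the pieces in this patch could fit
# together: MyTreeObsForRailEnv builds the per-agent tree, normalize_observation
# flattens it into a vector, and the PPO Agent from src/ppo/agent.py consumes
# that vector.  The environment setup assumes the standard Flatland 2.x API
# (RailEnv, sparse_rail_generator); the generator parameters, the observation
# module path and the collision=False flag are illustrative assumptions, not
# values taken from this repository.
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator

from src.ppo.agent import Agent
from src.observations import MyTreeObsForRailEnv, normalize_observation  # assumed module path

TREE_DEPTH = 2
obs_builder = MyTreeObsForRailEnv(max_depth=TREE_DEPTH)

env = RailEnv(width=25, height=25,
              rail_generator=sparse_rail_generator(max_num_cities=2),
              number_of_agents=3,
              obs_builder_object=obs_builder)
obs, info = env.reset()

# Flattened observation size: 4 navigation entries plus 2 features for each of
# the (4 ** (TREE_DEPTH + 1) - 1) / 3 nodes of a full 4-ary tree
# (see normalize_observation above).
num_nodes = (4 ** (TREE_DEPTH + 1) - 1) // 3
state_size = 4 + 2 * num_nodes

agent = Agent(state_size, action_size=5, num_agents=env.get_num_agents())
agent.reset()

# One interaction step: act on the normalized observations, then feed the
# transition back into the agent's replay memory.
states = {h: normalize_observation(obs[h], TREE_DEPTH) for h in obs}
action_dict = {h: agent.act(states[h]) for h in states}
next_obs, rewards, dones, info = env.step(action_dict)

for h in states:
    next_state = normalize_observation(next_obs[h], TREE_DEPTH)
    agent.step(h, states[h], action_dict[h], next_state,
               agent_done=dones[h], episode_done=dones['__all__'],
               collision=False)  # collision detection is not provided here; False is an assumption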