# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
scratch/test-envs/
scratch/
# Checkpoints and replay buffers
!checkpoints/.gitkeep
replay_buffers/*
!replay_buffers/.gitkeep
MIT License
Copyright (c) 2020 Flatland
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
![AIcrowd-Logo](https://raw.githubusercontent.com/AIcrowd/AIcrowd/master/app/assets/images/misc/aicrowd-horizontal.png)
# Flatland Challenge Starter Kit
**[Follow these instructions to submit your solutions!](http://flatland.aicrowd.com/getting-started/first-submission.html)**
![flatland](https://i.imgur.com/0rnbSLY.gif)
# Round 1 - 3rd best RL solution
## Used agent
* [PPO Agent -> Mitchell Goff](https://github.com/mitchellgoffpc/flatland-training)
## LICENCE for the Observation EXTRA.py
The observation can be used freely and reused for further submissions. The author only needs to be credited
/mentioned in any submission that uses the entire observation, parts of it, or its main idea.
Author: Adrian Egli (adrian.egli@gmail.com)
[LinkedIn](https://www.linkedin.com/in/adrian-egli-733a9544/)
[ResearchGate](https://www.researchgate.net/profile/Adrian_Egli2)
Main links
---
* [Submit in 10 minutes](https://flatland.aicrowd.com/getting-started/first-submission.html?_ga=2.175036450.1456714032.1596434204-43124944.1552486604)
* [Flatland documentation](https://flatland.aicrowd.com/)
* [NeurIPS 2020 Challenge](https://www.aicrowd.com/challenges/neurips-2020-flatland-challenge/)
Communication
---
* [Discord Channel](https://discord.com/invite/hCR3CZG)
* [Discussion Forum](https://discourse.aicrowd.com/c/neurips-2020-flatland-challenge)
* [Issue Tracker](https://gitlab.aicrowd.com/flatland/flatland/issues/)
Author
---
- **[Sharada Mohanty](https://twitter.com/MeMohanty)**
🚂 This code is based on the official starter kit - NeurIPS 2020 Flatland Challenge
---
You can use either the full or the reduced action space for your own experiments.
```python
def map_action(action):
    # if full action space is used -> no mapping required
    if get_action_size() == get_flatland_full_action_size():
        return action

    # if reduced action space is used -> the action has to be mapped to real flatland actions
    # The reduced action space removes the DO_NOTHING action from Flatland.
    if action == 0:
        return RailEnvActions.MOVE_LEFT
    if action == 1:
        return RailEnvActions.MOVE_FORWARD
    if action == 2:
        return RailEnvActions.MOVE_RIGHT
    if action == 3:
        return RailEnvActions.STOP_MOVING
```
Select the action space by calling
```python
set_action_size_full()
```
or
```python
set_action_size_reduced()
```
The reduced action space just removes the DO_NOTHING action.
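For illustration, a minimal evaluation loop with the reduced action space might look like the sketch below. `env`, `obs`, and `policy` are assumed to already exist, and `map_action` / `set_action_size_reduced` are assumed to be importable from `utils/agent_action_config.py` (the module the agents in this repo import their action helpers from).
```python
# Minimal sketch (not taken verbatim from run.py): with the reduced action
# space the policy outputs indices 0..3, which map_action() translates back
# to real Flatland actions before env.step().
set_action_size_reduced()

action_dict = {}
for handle in range(env.get_num_agents()):
    action = policy.act(handle, obs[handle], eps=0.0)  # index in [0, 3]
    action_dict[handle] = map_action(action)           # -> RailEnvActions.*

obs, all_rewards, done, info = env.step(action_dict)
```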
---
The policy used here is based on the FastTreeObs from the official starter kit of the NeurIPS 2020 Flatland Challenge,
but the FastTreeObs in this repo is an extended version.
[fast_tree_obs.py](./utils/fast_tree_obs.py)
---
Have a look at the [run.py](./run.py) file. There you can select whether PPO or DDDQN is used as the RL agent.
```python
####################################################
# EVALUATION PARAMETERS
set_action_size_full()

# Print per-step logs
VERBOSE = True
USE_FAST_TREEOBS = True

if False:
    # -------------------------------------------------------------------------------------------------------
    # RL solution
    # -------------------------------------------------------------------------------------------------------
    # 116591 adrian_egli
    # graded 71.305 0.633 RL Successfully Graded ! More details about this submission can be found at:
    # http://gitlab.aicrowd.com/adrian_egli/neurips2020-flatland-starter-kit/issues/51
    # Fri, 22 Jan 2021 23:37:56
    set_action_size_reduced()
    load_policy = "DDDQN"
    checkpoint = "./checkpoints/210122120236-3000.pth"  # 17.011131341978228
    EPSILON = 0.0

if False:
    # -------------------------------------------------------------------------------------------------------
    # RL solution
    # -------------------------------------------------------------------------------------------------------
    # 116658 adrian_egli
    # graded 73.821 0.655 RL Successfully Graded ! More details about this submission can be found at:
    # http://gitlab.aicrowd.com/adrian_egli/neurips2020-flatland-starter-kit/issues/52
    # Sat, 23 Jan 2021 07:41:35
    set_action_size_reduced()
    load_policy = "PPO"
    checkpoint = "./checkpoints/210122235754-5000.pth"  # 16.00113400887389
    EPSILON = 0.0

if True:
    # -------------------------------------------------------------------------------------------------------
    # RL solution
    # -------------------------------------------------------------------------------------------------------
    # 116659 adrian_egli
    # graded 80.579 0.715 RL Successfully Graded ! More details about this submission can be found at:
    # http://gitlab.aicrowd.com/adrian_egli/neurips2020-flatland-starter-kit/issues/53
    # Sat, 23 Jan 2021 07:45:49
    set_action_size_reduced()
    load_policy = "DDDQN"
    checkpoint = "./checkpoints/210122165109-5000.pth"  # 17.993750197899438
    EPSILON = 0.0

if False:
    # -------------------------------------------------------------------------------------------------------
    # !! This is not a RL solution !!!!
    # -------------------------------------------------------------------------------------------------------
    # 116727 adrian_egli
    # graded 106.786 0.768 RL Successfully Graded ! More details about this submission can be found at:
    # http://gitlab.aicrowd.com/adrian_egli/neurips2020-flatland-starter-kit/issues/54
    # Sat, 23 Jan 2021 14:31:50
    set_action_size_reduced()
    load_policy = "DeadLockAvoidance"
    checkpoint = None
    EPSILON = 0.0
```
---
A deadlock avoidance agent is implemented. It only lets each train take its shortest route and tries to avoid as many deadlocks as possible; a short usage sketch follows the link below.
* [dead_lock_avoidance_agent.py](./utils/dead_lock_avoidance_agent.py)
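A rough usage sketch, assuming `env` and `obs` already exist and that `get_action_size` is importable from `utils/agent_action_config.py`. The constructor arguments follow the instantiation in `DeadLockAvoidanceWithDecisionAgent` (env, action size, flag); the meaning of the boolean flag is an assumption here.
```python
# Sketch only: drive the deadlock-avoidance agent like any other policy.
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
from utils.agent_action_config import get_action_size

dead_lock_agent = DeadLockAvoidanceAgent(env, get_action_size(), False)
dead_lock_agent.reset(env)

dead_lock_agent.start_step(train=False)
action_dict = {handle: dead_lock_agent.act(handle, obs[handle], -1.0)
               for handle in range(env.get_num_agents())}
dead_lock_agent.end_step(train=False)
```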
---
The policy interface has changed; please have a look at the file below. An illustrative summary of its methods follows the link.
* [policy.py](./reinforcement_learning/policy.py)
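For orientation only: the summary below lists the methods that the agents in this repository implement (act, step, the step/episode hooks, checkpointing, cloning). It is reconstructed from those agents, not copied from policy.py, so the actual abstract base class may differ in detail.
```python
# Illustrative summary of the policy interface, reconstructed from the agents
# in this repo (e.g. DDDQNPolicy, DeadLockAvoidanceWithDecisionAgent).
class Policy:
    def act(self, handle, state, eps=0.):
        raise NotImplementedError

    def step(self, handle, state, action, reward, next_state, done):
        raise NotImplementedError

    # per-step and per-episode lifecycle hooks
    def start_episode(self, train): pass
    def start_step(self, train): pass
    def end_step(self, train): pass
    def end_episode(self, train): pass

    # checkpointing and replay-buffer persistence
    def save(self, filename): pass
    def load(self, filename): pass
    def save_replay_buffer(self, filename): pass
    def load_replay_buffer(self, filename): pass

    def reset(self, env): pass
    def test(self): pass
    def clone(self): return self
```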
---
See the tensorboard training output to get some insights:
```
tensorboard --logdir ./runs_bench
```
---
```
python reinforcement_learning/multi_agent_training.py --use_fast_tree_observation --checkpoint_interval 1000 -n 5000 \
    --policy DDDQN -t 2 --action_size reduced --buffer_size 128000
```
[multi_agent_training.py](./reinforcement_learning/multi_agent_training.py)
has new and changed parameters. The most important ones for training are listed below; an alternative example invocation follows the list.
* policy : [DDDQN, PPO, DeadLockAvoidance, DeadLockAvoidanceWithDecision, MultiDecision] : Default value = DeadLockAvoidance
* use_fast_tree_observation : [false, true] : Default value = true
* action_size : [full, reduced] : Default value = full
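For example, to train the PPO policy on the full action space with otherwise similar settings, only the flags change (the values here are illustrative, not tuned):
```
python reinforcement_learning/multi_agent_training.py --use_fast_tree_observation --checkpoint_interval 1000 -n 5000 \
    --policy PPO -t 2 --action_size full
```
The full list of options: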
```
usage: multi_agent_training.py [-h] [-n N_EPISODES] [--n_agent_fixed]
                               [-t TRAINING_ENV_CONFIG]
                               [-e EVALUATION_ENV_CONFIG]
                               [--n_evaluation_episodes N_EVALUATION_EPISODES]
                               [--checkpoint_interval CHECKPOINT_INTERVAL]
                               [--eps_start EPS_START] [--eps_end EPS_END]
                               [--eps_decay EPS_DECAY]
                               [--buffer_size BUFFER_SIZE]
                               [--buffer_min_size BUFFER_MIN_SIZE]
                               [--restore_replay_buffer RESTORE_REPLAY_BUFFER]
                               [--save_replay_buffer SAVE_REPLAY_BUFFER]
                               [--batch_size BATCH_SIZE] [--gamma GAMMA]
                               [--tau TAU] [--learning_rate LEARNING_RATE]
                               [--hidden_size HIDDEN_SIZE]
                               [--update_every UPDATE_EVERY]
                               [--use_gpu USE_GPU] [--num_threads NUM_THREADS]
                               [--render] [--load_policy LOAD_POLICY]
                               [--use_fast_tree_observation]
                               [--max_depth MAX_DEPTH] [--policy POLICY]
                               [--action_size ACTION_SIZE]

optional arguments:
  -h, --help            show this help message and exit
  -n N_EPISODES, --n_episodes N_EPISODES
                        number of episodes to run
  --n_agent_fixed       hold the number of agent fixed
  -t TRAINING_ENV_CONFIG, --training_env_config TRAINING_ENV_CONFIG
                        training config id (eg 0 for Test_0)
  -e EVALUATION_ENV_CONFIG, --evaluation_env_config EVALUATION_ENV_CONFIG
                        evaluation config id (eg 0 for Test_0)
  --n_evaluation_episodes N_EVALUATION_EPISODES
                        number of evaluation episodes
  --checkpoint_interval CHECKPOINT_INTERVAL
                        checkpoint interval
  --eps_start EPS_START
                        max exploration
  --eps_end EPS_END     min exploration
  --eps_decay EPS_DECAY
                        exploration decay
  --buffer_size BUFFER_SIZE
                        replay buffer size
  --buffer_min_size BUFFER_MIN_SIZE
                        min buffer size to start training
  --restore_replay_buffer RESTORE_REPLAY_BUFFER
                        replay buffer to restore
  --save_replay_buffer SAVE_REPLAY_BUFFER
                        save replay buffer at each evaluation interval
  --batch_size BATCH_SIZE
                        minibatch size
  --gamma GAMMA         discount factor
  --tau TAU             soft update of target parameters
  --learning_rate LEARNING_RATE
                        learning rate
  --hidden_size HIDDEN_SIZE
                        hidden size (2 fc layers)
  --update_every UPDATE_EVERY
                        how often to update the network
  --use_gpu USE_GPU     use GPU if available
  --num_threads NUM_THREADS
                        number of threads PyTorch can use
  --render              render 1 episode in 100
  --load_policy LOAD_POLICY
                        policy filename (reference) to load
  --use_fast_tree_observation
                        use FastTreeObs instead of stock TreeObs
  --max_depth MAX_DEPTH
                        max depth
  --policy POLICY       policy name [DDDQN, PPO, DeadLockAvoidance,
                        DeadLockAvoidanceWithDecision, MultiDecision]
  --action_size ACTION_SIZE
                        define the action size [reduced,full]
```
---
If you have any questions, write to me on the official Discord channel **aiAdrian**
(Adrian Egli - adrian.egli@gmail.com)
Credits
---
* Florian Laurent <florian@aicrowd.com>
* Erik Nygren <erik.nygren@sbb.ch>
* Adrian Egli <adrian.egli@sbb.ch>
* Sharada Mohanty <mohanty@aicrowd.com>
* Christian Baumberger <christian.baumberger@sbb.ch>
* Guillaume Mollard <guillaume.mollard2@gmail.com>
Main links
---
* [Flatland documentation](https://flatland.aicrowd.com/)
* [Flatland Challenge](https://www.aicrowd.com/challenges/flatland)
Communication
---
* [Discord Channel](https://discord.com/invite/hCR3CZG)
* [Discussion Forum](https://discourse.aicrowd.com/c/neurips-2020-flatland-challenge)
* [Issue Tracker](https://gitlab.aicrowd.com/flatland/flatland/issues/)
{
"challenge_id": "neurips-2020-flatland-challenge",
"grader_id": "neurips-2020-flatland-challenge",
"debug": false,
"tags": ["RL"]
}
git
vim
ssh
gcc
build-essential
name: flatland-rl
channels:
- anaconda
- conda-forge
- defaults
dependencies:
- tk=8.6.8
- cairo=1.16.0
- cairocffi=1.1.0
- cairosvg=2.4.2
- cffi=1.12.3
- cssselect2=0.2.1
- defusedxml=0.6.0
- fontconfig=2.13.1
- freetype=2.10.0
- gettext=0.19.8.1
- glib=2.58.3
- icu=64.2
- jpeg=9c
- libiconv=1.15
- libpng=1.6.37
- libtiff=4.0.10
- libuuid=2.32.1
- libxcb=1.13
- libxml2=2.9.9
- lz4-c=1.8.3
- olefile=0.46
- pcre=8.41
- pillow=5.3.0
- pixman=0.38.0
- pthread-stubs=0.4
- pycairo=1.18.1
- pycparser=2.19
- tinycss2=1.0.2
- webencodings=0.5.1
- xorg-kbproto=1.0.7
- xorg-libice=1.0.10
- xorg-libsm=1.2.3
- xorg-libx11=1.6.8
- xorg-libxau=1.0.9
- xorg-libxdmcp=1.1.3
- xorg-libxext=1.3.4
- xorg-libxrender=0.9.10
- xorg-renderproto=0.11.1
- xorg-xextproto=7.3.0
- xorg-xproto=7.0.31
- zstd=1.4.0
- _libgcc_mutex=0.1
- ca-certificates=2019.5.15
- certifi=2019.6.16
- libedit=3.1.20181209
- libffi=3.2.1
- ncurses=6.1
- openssl=1.1.1c
- pip=19.1.1
- python=3.6.8
- readline=7.0
- setuptools=41.0.1
- sqlite=3.29.0
- wheel=0.33.4
- xz=5.2.4
- zlib=1.2.11
- pip:
    - atomicwrites==1.3.0
    - importlib-metadata==0.19
    - importlib-resources==1.0.2
    - attrs==19.1.0
    - chardet==3.0.4
    - click==7.0
    - cloudpickle==1.2.2
    - crowdai-api==0.1.21
    - cycler==0.10.0
    - filelock==3.0.12
    - flatland-rl==2.2.1
    - future==0.17.1
    - gym==0.14.0
    - idna==2.8
    - kiwisolver==1.1.0
    - lxml==4.4.0
    - matplotlib==3.1.1
    - more-itertools==7.2.0
    - msgpack==0.6.1
    - msgpack-numpy==0.4.4.3
    - numpy==1.17.0
    - packaging==19.0
    - pandas==0.25.0
    - pluggy==0.12.0
    - py==1.8.0
    - pyarrow==0.14.1
    - pyglet==1.3.2
    - pyparsing==2.4.1.1
    - pytest==5.0.1
    - pytest-runner==5.1
    - python-dateutil==2.8.0
    - python-gitlab==1.10.0
    - pytz==2019.1
    - torch==1.5.0
    - recordtype==1.3
    - redis==3.3.2
    - requests==2.22.0
    - scipy==1.3.1
    - six==1.12.0
    - svgutils==0.3.1
    - timeout-decorator==0.4.1
    - toml==0.10.0
    - tox==3.13.2
    - urllib3==1.25.3
    - ushlex==0.99.1
    - virtualenv==16.7.2
    - wcwidth==0.1.7
    - xarray==0.12.3
    - zipp==0.5.2
name: flatland-rl
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- psutil==5.7.2
- pytorch==1.6.0
- pip==20.2.3
- python==3.6.8
- pip:
    - tensorboard==2.3.0
    - tensorboardx==2.1
#!/usr/bin/env python
import collections
from typing import Optional, List, Dict, Tuple
import numpy as np
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.env_prediction_builder import PredictionBuilder
from flatland.envs.agent_utils import RailAgentStatus, EnvAgent
class CustomObservationBuilder(ObservationBuilder):
    """
    Template for building a custom observation builder for the RailEnv class

    The observation in this case is composed of the following elements:

        - transition map array with dimensions (env.height, env.width),\
          where the value at X,Y will represent the 16 bits encoding of transition-map at that point.

        - the individual agent object (with position, direction, target information available)
    """

    def __init__(self):
        super(CustomObservationBuilder, self).__init__()

    def set_env(self, env: Environment):
        super().set_env(env)
        # Note :
        # The instantiations which depend on parameters of the Env object should be
        # done here, as it is only here that the updated self.env instance is available
        self.rail_obs = np.zeros((self.env.height, self.env.width))

    def reset(self):
        """
        Called internally on every env.reset() call,
        to reset any observation specific variables that are being used
        """
        self.rail_obs[:] = 0
        for _x in range(self.env.width):
            for _y in range(self.env.height):
                # Get the transition map value at location _x, _y
                transition_value = self.env.rail.get_full_transitions(_y, _x)
                self.rail_obs[_y, _x] = transition_value

    def get(self, handle: int = 0):
        """
        Returns the built observation for a single agent with handle : handle

        In this particular case, we return
        - the global transition_map of the RailEnv,
        - a tuple containing, the current agent's:
            - state
            - position
            - direction
            - initial_position
            - target
        """
        agent = self.env.agents[handle]
        """
        Available information for each agent object :

        - agent.status : [RailAgentStatus.READY_TO_DEPART, RailAgentStatus.ACTIVE, RailAgentStatus.DONE]
        - agent.position : Current position of the agent
        - agent.direction : Current direction of the agent
        - agent.initial_position : Initial Position of the agent
        - agent.target : Target position of the agent
        """
        status = agent.status
        position = agent.position
        direction = agent.direction
        initial_position = agent.initial_position
        target = agent.target

        """
        You can also optionally access the states of the rest of the agents by
        using something similar to

            for i in range(len(self.env.agents)):
                other_agent: EnvAgent = self.env.agents[i]

                # ignore other agents not in the grid any more
                if other_agent.status == RailAgentStatus.DONE_REMOVED:
                    continue

                ## Gather other agent specific params
                other_agent_status = other_agent.status
                other_agent_position = other_agent.position
                other_agent_direction = other_agent.direction
                other_agent_initial_position = other_agent.initial_position
                other_agent_target = other_agent.target

                ## Do something nice here if you wish
        """
        return self.rail_obs, (status, position, direction, initial_position, target)
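For context, a builder like this is plugged into the environment via RailEnv's `obs_builder_object` argument. A minimal sketch, assuming the standard Flatland 2.2 generators; the sizes and generator settings are placeholders, not values from this repo.
```python
# Sketch: wiring the custom observation builder into a RailEnv instance.
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator

env = RailEnv(width=30, height=30,
              rail_generator=sparse_rail_generator(max_num_cities=2),
              schedule_generator=sparse_schedule_generator(),
              number_of_agents=2,
              obs_builder_object=CustomObservationBuilder())

obs, info = env.reset()
```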
import copy
import os
import pickle
import random
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from reinforcement_learning.model import DuelingQNetwork
from reinforcement_learning.policy import Policy, LearningPolicy
from reinforcement_learning.replay_buffer import ReplayBuffer
class DDDQNPolicy(LearningPolicy):
    """Dueling Double DQN policy"""

    def __init__(self, state_size, action_size, in_parameters, evaluation_mode=False):
        print(">> DDDQNPolicy")
        super(Policy, self).__init__()

        self.ddqn_parameters = in_parameters
        self.evaluation_mode = evaluation_mode
        self.state_size = state_size
        self.action_size = action_size
        self.double_dqn = True
        self.hidsize = 128

        if not evaluation_mode:
            self.hidsize = self.ddqn_parameters.hidden_size
            self.buffer_size = self.ddqn_parameters.buffer_size
            self.batch_size = self.ddqn_parameters.batch_size
            self.update_every = self.ddqn_parameters.update_every
            self.learning_rate = self.ddqn_parameters.learning_rate
            self.tau = self.ddqn_parameters.tau
            self.gamma = self.ddqn_parameters.gamma
            self.buffer_min_size = self.ddqn_parameters.buffer_min_size

        # Device
        if self.ddqn_parameters.use_gpu and torch.cuda.is_available():
            self.device = torch.device("cuda:0")
            # print("🐇 Using GPU")
        else:
            self.device = torch.device("cpu")
            # print("🐢 Using CPU")

        # Q-Network
        self.qnetwork_local = DuelingQNetwork(state_size,
                                              action_size,
                                              hidsize1=self.hidsize,
                                              hidsize2=self.hidsize).to(self.device)

        if not evaluation_mode:
            self.qnetwork_target = copy.deepcopy(self.qnetwork_local)
            self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.learning_rate)
            self.memory = ReplayBuffer(action_size, self.buffer_size, self.batch_size, self.device)
            self.t_step = 0
            self.loss = 0.0
        else:
            self.memory = ReplayBuffer(action_size, 1, 1, self.device)
            self.loss = 0.0

    def act(self, handle, state, eps=0.):
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection
        if random.random() >= eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def step(self, handle, state, action, reward, next_state, done):
        assert not self.evaluation_mode, "Policy has been initialized for evaluation only."

        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every UPDATE_EVERY time steps.
        self.t_step = (self.t_step + 1) % self.update_every
        if self.t_step == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > self.buffer_min_size and len(self.memory) > self.batch_size:
                self._learn()

    def _learn(self):
        experiences = self.memory.sample()
        states, actions, rewards, next_states, dones, _ = experiences

        # Get expected Q values from local model
        q_expected = self.qnetwork_local(states).gather(1, actions)

        if self.double_dqn:
            # Double DQN
            q_best_action = self.qnetwork_local(next_states).max(1)[1]
            q_targets_next = self.qnetwork_target(next_states).gather(1, q_best_action.unsqueeze(-1))
        else:
            # DQN
            q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(-1)

        # Compute Q targets for current states
        q_targets = rewards + (self.gamma * q_targets_next * (1 - dones))

        # Compute loss
        self.loss = F.mse_loss(q_expected, q_targets)

        # Minimize the loss
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()

        # Update target network
        self._soft_update(self.qnetwork_local, self.qnetwork_target, self.tau)

    def _soft_update(self, local_model, target_model, tau):
        # Soft update model parameters.
        # θ_target = τ*θ_local + (1 - τ)*θ_target
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)

    def save(self, filename):
        torch.save(self.qnetwork_local.state_dict(), filename + ".local")
        torch.save(self.qnetwork_target.state_dict(), filename + ".target")

    def load(self, filename):
        try:
            if os.path.exists(filename + ".local") and os.path.exists(filename + ".target"):
                self.qnetwork_local.load_state_dict(torch.load(filename + ".local", map_location=self.device))
                print("qnetwork_local loaded ('{}')".format(filename + ".local"))
                if not self.evaluation_mode:
                    self.qnetwork_target.load_state_dict(torch.load(filename + ".target", map_location=self.device))
                    print("qnetwork_target loaded ('{}')".format(filename + ".target"))
            else:
                print(">> Checkpoint not found, using untrained policy! ('{}', '{}')".format(filename + ".local",
                                                                                             filename + ".target"))
        except Exception as exc:
            print(exc)
            print("Couldn't load policy from file, using untrained policy! ('{}', '{}')".format(filename + ".local",
                                                                                                filename + ".target"))

    def save_replay_buffer(self, filename):
        memory = self.memory.memory
        with open(filename, 'wb') as f:
            pickle.dump(list(memory)[-500000:], f)

    def load_replay_buffer(self, filename):
        with open(filename, 'rb') as f:
            self.memory.memory = pickle.load(f)

    def test(self):
        self.act(0, np.array([[0] * self.state_size]))
        self._learn()
    def clone(self):
        me = DDDQNPolicy(self.state_size, self.action_size, self.ddqn_parameters, evaluation_mode=True)
        me.qnetwork_local = copy.deepcopy(self.qnetwork_local)
        me.qnetwork_target = copy.deepcopy(self.qnetwork_target)
        return me
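A quick usage sketch for DDDQNPolicy follows. The parameter object only needs the attributes that `__init__` reads; the numeric values, the state size, and the `obs`/`reward`/`next_obs`/`done` variables are placeholders, not the repository's defaults.
```python
# Sketch: minimal DDDQN usage. All numeric values are placeholders.
from argparse import Namespace

train_params = Namespace(hidden_size=256, buffer_size=32_000, batch_size=128,
                         update_every=8, learning_rate=0.5e-4, tau=1e-3,
                         gamma=0.99, buffer_min_size=0, use_gpu=False)

policy = DDDQNPolicy(state_size=56, action_size=4, in_parameters=train_params)

# one transition for agent handle 0 (obs, reward, next_obs, done come from the env loop)
action = policy.act(0, obs, eps=0.1)
policy.step(0, obs, action, reward, next_obs, done)
policy.save("./checkpoints/example")
```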
from flatland.envs.agent_utils import RailAgentStatus
from flatland.envs.rail_env import RailEnv, RailEnvActions
from reinforcement_learning.policy import HybridPolicy
from reinforcement_learning.ppo_agent import PPOPolicy
from utils.agent_action_config import map_rail_env_action
from utils.dead_lock_avoidance_agent import DeadLockAvoidanceAgent
class DeadLockAvoidanceWithDecisionAgent(HybridPolicy):

    def __init__(self, env: RailEnv, state_size, action_size, learning_agent):
        print(">> DeadLockAvoidanceWithDecisionAgent")
        super(DeadLockAvoidanceWithDecisionAgent, self).__init__()
        self.env = env
        self.state_size = state_size
        self.action_size = action_size
        self.learning_agent = learning_agent
        self.dead_lock_avoidance_agent = DeadLockAvoidanceAgent(self.env, action_size, False)
        self.policy_selector = PPOPolicy(state_size, 2)
        self.memory = self.learning_agent.memory
        self.loss = self.learning_agent.loss

    def step(self, handle, state, action, reward, next_state, done):
        select = self.policy_selector.act(handle, state, 0.0)
        self.policy_selector.step(handle, state, select, reward, next_state, done)
        self.dead_lock_avoidance_agent.step(handle, state, action, reward, next_state, done)
        self.learning_agent.step(handle, state, action, reward, next_state, done)
        self.loss = self.learning_agent.loss

    def act(self, handle, state, eps=0.):
        select = self.policy_selector.act(handle, state, eps)
        if select == 0:
            return self.learning_agent.act(handle, state, eps)
        return self.dead_lock_avoidance_agent.act(handle, state, -1.0)

    def save(self, filename):
        self.dead_lock_avoidance_agent.save(filename)
        self.learning_agent.save(filename)
        self.policy_selector.save(filename + '.selector')

    def load(self, filename):
        self.dead_lock_avoidance_agent.load(filename)
        self.learning_agent.load(filename)
        self.policy_selector.load(filename + '.selector')

    def start_step(self, train):
        self.dead_lock_avoidance_agent.start_step(train)
        self.learning_agent.start_step(train)
        self.policy_selector.start_step(train)

    def end_step(self, train):
        self.dead_lock_avoidance_agent.end_step(train)
        self.learning_agent.end_step(train)
        self.policy_selector.end_step(train)

    def start_episode(self, train):
        self.dead_lock_avoidance_agent.start_episode(train)
        self.learning_agent.start_episode(train)
        self.policy_selector.start_episode(train)

    def end_episode(self, train):
        self.dead_lock_avoidance_agent.end_episode(train)
        self.learning_agent.end_episode(train)
        self.policy_selector.end_episode(train)

    def load_replay_buffer(self, filename):
        self.dead_lock_avoidance_agent.load_replay_buffer(filename)
        self.learning_agent.load_replay_buffer(filename)
        self.policy_selector.load_replay_buffer(filename + ".selector")

    def test(self):
        self.dead_lock_avoidance_agent.test()
        self.learning_agent.test()
        self.policy_selector.test()

    def reset(self, env: RailEnv):
        self.env = env
        self.dead_lock_avoidance_agent.reset(env)
        self.learning_agent.reset(env)
        self.policy_selector.reset(env)

    def clone(self):
        return self
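Finally, a rough sketch of how this hybrid agent wraps a learning policy; `env`, `state_size`, `action_size`, `train_params`, and `obs` are assumed to exist, and the epsilon value is illustrative.
```python
# Sketch: wrap a DDDQN learner with the deadlock-avoidance decision agent.
learning_agent = DDDQNPolicy(state_size, action_size, train_params)
hybrid = DeadLockAvoidanceWithDecisionAgent(env, state_size, action_size, learning_agent)

hybrid.start_episode(train=True)
hybrid.start_step(train=True)
action = hybrid.act(handle=0, state=obs[0], eps=0.1)
hybrid.end_step(train=True)
hybrid.end_episode(train=True)
```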