Commit 636faf72 authored by nilabha

Merge branch 'flatland3-pettingzoo' into 'flatland-3-updates'

Flatland3 pettingzoo

See merge request !321
parents a578d2a1 7ba0f38c
Pipeline #8493 failed with stages in 7 minutes and 42 seconds
Authors
=======
.. toctree::
:maxdepth: 2
.. include:: ../AUTHORS.rst
Multi-Agent Interface
=====================
.. include:: interface/pettingzoo.rst
.. include:: interface/wrappers.rst
Multi-Agent PettingZoo Usage
============================
We can use the PettingZoo interface by providing the rail env to the PettingZoo wrapper, as shown in the example below.
.. literalinclude:: ../tests/test_pettingzoo_interface.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
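A minimal sketch of this pattern is shown below. It is illustrative only: the ``flatland_env`` import path and the random action stand-in are assumptions, and construction of the rail env is elided; the included test above remains the authoritative example.

.. code-block:: python

    import numpy as np
    from flatland.contrib.interface import flatland_env  # assumed import path

    # rail_env: any configured flatland RailEnv; it must be reset before wrapping
    rail_env.reset(random_seed=1)
    env = flatland_env.env(environment=rail_env)
    env.reset(random_seed=11)

    # Standard PettingZoo AEC loop; a random valid action stands in for a policy
    for agent in env.agent_iter():
        obs, reward, done, info = env.last()
        action = None if done else np.random.randint(0, 5)
        env.step(action)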
Multi-Agent Interface Stable Baselines 3 Training
=================================================
.. literalinclude:: ../flatland/contrib/training/flatland_pettingzoo_stable_baselines.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
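For orientation, here is a minimal, hypothetical sketch of the Stable Baselines 3 setup. It assumes a ``parallel_env`` constructor on ``flatland_env`` (mirroring the PistonBall example this interface is adapted from) and elides rail env construction; defer to the training script included above for the exact code.

.. code-block:: python

    import supersuit as ss
    from stable_baselines3 import PPO
    from flatland.contrib.interface import flatland_env  # assumed import path

    # rail_env: a configured and reset flatland RailEnv
    env = flatland_env.parallel_env(environment=rail_env)  # assumed constructor

    # Vectorise the PettingZoo env and stack 8 copies for parallel rollouts
    env = ss.pettingzoo_env_to_vec_env_v0(env)
    env = ss.concat_vec_envs_v0(env, 8, num_cpus=4, base_class='stable_baselines3')

    model = PPO('MlpPolicy', env, verbose=1)
    model.learn(total_timesteps=100_000)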
Multi-Agent Interface RLlib Training
====================================
.. literalinclude:: ../flatland/contrib/training/flatland_pettingzoo_rllib.py
:language: python
:start-after: __sphinx_doc_begin__
:end-before: __sphinx_doc_end__
\ No newline at end of file
Multi-Agent Interfaces
======================
.. toctree::
:maxdepth: 2
10_interface
......@@ -15,6 +15,7 @@ Welcome to flatland's documentation!
07_changes
08_authors
09_faq_toc
10_interface
Indices and tables
==================
......
# PettingZoo
> PettingZoo (https://www.pettingzoo.ml/) is a collection of multi-agent environments for reinforcement learning. We provide a PettingZoo interface for Flatland.
## Background
PettingZoo is a popular multi-agent environment library (https://arxiv.org/abs/2009.14471) that aims to be the gym standard for multi-agent reinforcement learning. The following advantages make it suitable for use with Flatland:
- Works with both RLlib (https://docs.ray.io/en/latest/rllib.html) and Stable Baselines 3 (https://stable-baselines3.readthedocs.io/) using wrappers from SuperSuit.
- Clean API (https://www.pettingzoo.ml/api) with additional facilities for parallel execution, saving observations, recording via the gym Monitor, and processing/normalising observations.
- Scikit-learn inspired API (see also the short loop sketch after this list), e.g.
```python
act = model.predict(obs, deterministic=True)[0]
```
- Parallel learning with Stable Baselines 3 using just two lines of code:
```python
env = ss.pettingzoo_env_to_vec_env_v0(env)
env = ss.concat_vec_envs_v0(env, 8, num_cpus=4, base_class=stable_baselines3)
```
- Tested on various multi-agent environments with agent counts comparable to Flatland, e.g. https://www.pettingzoo.ml/magent
- The clean interface makes it easy to plug in experiment-tracking tools such as wandb, with full flexibility over which information is saved.
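Putting these pieces together, the sketch below runs PettingZoo's agent-iteration loop with a trained model via the scikit-learn style `predict` call. Construction of `env` and `model` is elided, and the loop only mirrors the pattern used in `tests/test_pettingzoo_interface.py`; it is not a prescribed API.

```python
# env: a flatland rail env wrapped in the PettingZoo interface (already reset)
# model: a trained Stable Baselines 3 model
for agent in env.agent_iter():
    obs, reward, done, info = env.last()
    # scikit-learn style predict; agents that are done are stepped with None
    act = None if done else model.predict(obs, deterministic=True)[0]
    env.step(act)
```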
PettingZoo
==========
..
PettingZoo (https://www.pettingzoo.ml/) is a collection of multi-agent environments for reinforcement learning. We provide a PettingZoo interface for Flatland.
Background
----------
PettingZoo is a popular multi-agent environment library (https://arxiv.org/abs/2009.14471) that aims to be the gym standard for multi-agent reinforcement learning. The following advantages make it suitable for use with Flatland:
* Works with both RLlib (https://docs.ray.io/en/latest/rllib.html) and Stable Baselines 3 (https://stable-baselines3.readthedocs.io/) using wrappers from SuperSuit.
* Clean API (https://www.pettingzoo.ml/api) with additional facilities for parallel execution, saving observations, recording via the gym Monitor, and processing/normalising observations.
* Scikit-learn inspired API, e.g.
.. code-block:: python
act = model.predict(obs, deterministic=True)[0]
* Parallel learning with Stable Baselines 3 using just two lines of code:
.. code-block:: python
env = ss.pettingzoo_env_to_vec_env_v0(env)
env = ss.concat_vec_envs_v0(env, 8, num_cpus=4, base_class='stable_baselines3')
* Tested on various multi-agent environments with agent counts comparable to Flatland, e.g. https://www.pettingzoo.ml/magent
* The clean interface makes it easy to plug in experiment-tracking tools such as wandb, with full flexibility over which information is saved.
# Environment Wrappers
> We provide various environment wrappers that work with both the rail env and the PettingZoo interface.
## Background
These wrappers change certain environment behaviours, which can help improve reinforcement learning training.
## Supported Inbuilt Wrappers
We provide two sample wrappers: the ShortestPathAction wrapper and the SkipNoChoice wrapper. The wrappers require several env properties that are only created on environment reset, so the rail env must be reset before it is wrapped. To use a wrapper, simply pass it the reset rail env. Code samples are shown below for each wrapper.
### ShortestPathAction Wrapper
To use the ShortestPathAction Wrapper, simply wrap the rail env as follows:
```python
rail_env.reset(random_seed=1)
rail_env = ShortestPathActionWrapper(rail_env)
```
The shortest path action wrapper maps the existing action space to 3 actions - Shortest Path (`0`), Next Shortest Path (`1`) and Stop (`2`). Hence, the predicted action must always be one of these actions (0, 1 or 2). To route all agents along their shortest paths, pass `0` as the action.
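As a minimal sketch (not part of the wrapper's documented API), the loop below routes every agent along its shortest path by always passing action `0`; it assumes `rail_env` was reset and wrapped as above and that the wrapper forwards the usual rail env step and agent accessors.

```python
# Only actions 0 (shortest path), 1 (next shortest path) and 2 (stop) are valid here
dones = {"__all__": False}
while not dones["__all__"]:
    # send every agent along its shortest path
    action_dict = {handle: 0 for handle in range(rail_env.get_num_agents())}
    observations, rewards, dones, information = rail_env.step(action_dict)
```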
### SkipNoChoice Wrapper
To use the SkipNoChoiceWrapper, simply wrap the rail env as follows:
```python
rail_env.reset(random_seed=1)
rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
```
Environment Wrappers
====================
..
We provide various environment wrappers that work with both the rail env and the PettingZoo interface.
Background
----------
These wrappers change certain environment behaviours, which can help improve reinforcement learning training.
Supported Inbuilt Wrappers
--------------------------
We provide two sample wrappers: the ShortestPathAction wrapper and the SkipNoChoice wrapper. The wrappers require several env properties that are only created on environment reset, so the rail env must be reset before it is wrapped. To use a wrapper, simply pass it the reset rail env. Code samples are shown below for each wrapper.
ShortestPathAction Wrapper
^^^^^^^^^^^^^^^^^^^^^^^^^^
To use the ShortestPathAction Wrapper, simply wrap the rail env as follows:
.. code-block:: python
rail_env.reset(random_seed=1)
rail_env = ShortestPathActionWrapper(rail_env)
The shortest path action wrapper maps the existing action space to 3 actions - Shortest Path (\ ``0``\ ), Next Shortest Path (\ ``1``\ ) and Stop (\ ``2``\ ). Hence, the predicted action must always be one of these actions (0, 1 or 2). To route all agents along their shortest paths, pass ``0`` as the action.
SkipNoChoice Wrapper
^^^^^^^^^^^^^^^^^^^^
To use the SkipNoChoiceWrapper, simply wrap the rail env as follows:
.. code-block:: python
rail_env.reset(random_seed=1)
rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
......@@ -13,6 +13,10 @@ from mava.wrappers.flatland import infer_observation_space, normalize_observatio
from functools import partial
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
"""Adapted from
- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
......@@ -67,13 +71,9 @@ class raw_env(AECEnv, gym.Env):
'video.frames_per_second': 10,
'semantics.autoreset': False }
def __init__(self, environment = False, preprocessor = False, agent_info = False, use_renderer=False, *args, **kwargs):
def __init__(self, environment = False, preprocessor = False, agent_info = False, *args, **kwargs):
# EzPickle.__init__(self, *args, **kwargs)
self._environment = environment
self.use_renderer = use_renderer
self.renderer = None
if self.use_renderer:
self.initialize_renderer()
n_agents = self.num_agents
self._agents = [get_agent_keys(i) for i in range(n_agents)]
......@@ -187,9 +187,6 @@ class raw_env(AECEnv, gym.Env):
def reset(self, *args, **kwargs):
self._reset_next_step = False
self._agents = self.possible_agents[:]
if self.use_renderer:
if self.renderer: #TODO: Errors with RLLib with renderer as None.
self.renderer.reset()
obs, info = self._environment.reset(*args, **kwargs)
observations = self._collate_obs_and_info(obs, info)
self._agent_selector.reinit(self.agents)
......@@ -268,53 +265,40 @@ class raw_env(AECEnv, gym.Env):
self.obs = observations
return observations
def set_probs(self, probs):
self.probs = probs
def render(self, mode='human'):
def render(self, mode='rgb_array'):
"""
This method provides the option to render the
environment's behavior to a window which should be
readable to the human eye if mode is set to 'human'.
environment's behavior as an image or to a window.
"""
if not self.use_renderer:
return
if not self.renderer:
self.initialize_renderer(mode=mode)
return self.update_renderer(mode=mode)
def initialize_renderer(self, mode="human"):
# Initiate the renderer
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
self.renderer = RenderTool(self.environment, gl="PGL", # gl="TKPILSVG",
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
self.renderer.show = False
def update_renderer(self, mode='human'):
image = self.renderer.render_env(show=False, show_observations=False, show_predictions=False,
return_image=True)
return image[:,:,:3]
def set_renderer(self, renderer):
self.use_renderer = renderer
if self.use_renderer:
self.initialize_renderer(mode=self.use_renderer)
if mode == "rgb_array":
env_rgb_array = self._environment.render(mode)
if not hasattr(self, "image_shape"):
self.image_shape = env_rgb_array.shape
if not hasattr(self, "probs"):
self.probs = [[0., 0., 0., 0.]]
fig, ax = plt.subplots(figsize=(self.image_shape[1]/100, self.image_shape[0]/100),
constrained_layout=True, dpi=100)
df = pd.DataFrame(np.array(self.probs).T)
sns.barplot(x=df.index, y=0, data=df, ax=ax)
ax.set(xlabel='actions', ylabel='probs')
fig.canvas.draw()
X = np.array(fig.canvas.renderer.buffer_rgba())
Image.fromarray(X)
# Image.fromarray(X)
rgb_image = np.array(Image.fromarray(X).convert('RGB'))
plt.close(fig)
q_value_rgb_array = rgb_image
return np.append(env_rgb_array, q_value_rgb_array, axis=1)
else:
return self._environment.render(mode)
def close(self):
# self._environment.close()
if self.renderer:
try:
if self.renderer.show:
self.renderer.close_window()
except Exception as e:
print("Could Not close window due to:",e)
self.renderer = None
self._environment.close()
def _obtain_preprocessor(
self, preprocessor):
def _obtain_preprocessor(self, preprocessor):
"""Obtains the actual preprocessor to be used based on the supplied
preprocessor and the env's obs_builder object"""
if not isinstance(self.obs_builder, GlobalObsForRailEnv):
......
......@@ -3,4 +3,7 @@ id-mava
id-mava[tf]
supersuit
stable-baselines3
ray==1.5.2
\ No newline at end of file
ray==1.5.2
seaborn
matplotlib
pandas
\ No newline at end of file
......@@ -8,7 +8,7 @@ from typing import List, NamedTuple, Optional, Dict, Tuple
import numpy as np
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
......@@ -427,6 +427,8 @@ class RailEnv(Environment):
}
# Return the new observation vectors for each agent
observation_dict: Dict = self._get_observations()
if hasattr(self, "renderer") and self.renderer is not None:
self.renderer = None
return observation_dict, info_dict
def _fix_agent_after_malfunction(self, agent: EnvAgent):
......@@ -1155,3 +1157,78 @@ class RailEnv(Environment):
def save(self, filename):
print("deprecated call to env.save() - pls call RailEnvPersister.save()")
persistence.RailEnvPersister.save(self, filename)
def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False, clear_debug_text=True, show=False,
screen_height=600, screen_width=800,
show_observations=False, show_predictions=False,
show_rowcols=False, return_image=True):
"""
This method provides the option to render the
environment's behavior as an image or to a window.
Parameters
----------
mode
Returns
-------
Image if mode is rgb_array, opens a window otherwise
"""
if not hasattr(self, "renderer") or self.renderer is None:
self.initialize_renderer(mode=mode, gl=gl, # gl="TKPILSVG",
agent_render_variant=agent_render_variant,
show_debug=show_debug,
clear_debug_text=clear_debug_text,
show=show,
screen_height=screen_height, # Adjust these parameters to fit your resolution
screen_width=screen_width)
return self.update_renderer(mode=mode, show=show, show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols, return_image=return_image)
def initialize_renderer(self, mode, gl,
agent_render_variant,
show_debug,
clear_debug_text,
show,
screen_height,
screen_width):
# Initiate the renderer
self.renderer = RenderTool(self, gl=gl, # gl="TKPILSVG",
agent_render_variant=agent_render_variant,
show_debug=show_debug,
clear_debug_text=clear_debug_text,
screen_height=screen_height, # Adjust these parameters to fit your resolution
screen_width=screen_width) # Adjust these parameters to fit your resolution
self.renderer.show = show
self.renderer.reset()
def update_renderer(self, mode, show, show_observations, show_predictions,
show_rowcols, return_image):
"""
This method updates the renderer.
Parameters
----------
mode
Returns
-------
Image if mode is rgb_array, None otherwise
"""
image = self.renderer.render_env(show=show, show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols, return_image=return_image)
if mode == 'rgb_array':
return image[:, :, :3]
def close(self):
"""
This method closes any renderer window.
"""
if hasattr(self, "renderer") and self.renderer is not None:
try:
if self.renderer.show:
self.renderer.close_window()
except Exception as e:
print("Could Not close window due to:",e)
self.renderer = None
......@@ -24,7 +24,7 @@ for image_file in glob.glob(r'./docs/flatland*.rst'):
remove_exists(image_file)
remove_exists('docs/modules.rst')
for md_file in glob.glob(r'./*.md') + glob.glob(r'./docs/specifications/*.md') + glob.glob(r'./docs/tutorials/*.md'):
for md_file in glob.glob(r'./*.md') + glob.glob(r'./docs/specifications/*.md') + glob.glob(r'./docs/tutorials/*.md') + glob.glob(r'./docs/interface/*.md'):
from m2r import parse_from_file
rst_content = parse_from_file(md_file)
......
......@@ -30,7 +30,7 @@ def test_petting_zoo_interface_env():
save = True
np.random.seed(seed)
experiment_name = "flatland_pettingzoo"
total_episodes = 1
total_episodes = 2
if save:
try:
......@@ -48,12 +48,6 @@ def test_petting_zoo_interface_env():
# For Shortest Path Action Wrapper, change action to 1
# rail_env = ShortestPathActionWrapper(rail_env)
rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
env_renderer = RenderTool(rail_env,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
dones = {}
dones['__all__'] = False
......@@ -76,9 +70,7 @@ def test_petting_zoo_interface_env():
# Do the environment step
observations, rewards, dones, information = rail_env.step(action_dict)
image = env_renderer.render_env(show=False, show_observations=False, show_predictions=False,
return_image=True)
frame_list.append(PIL.Image.fromarray(image[:, :, :3]))
frame_list.append(PIL.Image.fromarray(rail_env.render(mode="rgb_array")))
if dones['__all__']:
completion = env_generators.perc_completion(rail_env)
......@@ -88,16 +80,11 @@ def test_petting_zoo_interface_env():
frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True,
append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env_renderer = RenderTool(rail_env,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
rail_env.reset(random_seed=seed+ep_no)
# __sphinx_doc_begin__
env = flatland_env.env(environment=rail_env, use_renderer=True)
env = flatland_env.env(environment=rail_env)
seed = 11
env.reset(random_seed=seed)
step = 0
......