Commit 62e02366 authored by nilabha's avatar nilabha
Browse files

update render and close logic in rail env

parent 625734e6
Pipeline #8491 failed with stages
in 4 minutes and 27 seconds
......@@ -13,6 +13,10 @@ from mava.wrappers.flatland import infer_observation_space, normalize_observatio
from functools import partial
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
"""Adapted from
- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
......@@ -67,13 +71,9 @@ class raw_env(AECEnv, gym.Env):
'video.frames_per_second': 10,
'semantics.autoreset': False }
def __init__(self, environment = False, preprocessor = False, agent_info = False, use_renderer=False, *args, **kwargs):
def __init__(self, environment = False, preprocessor = False, agent_info = False, *args, **kwargs):
# EzPickle.__init__(self, *args, **kwargs)
self._environment = environment
self.use_renderer = use_renderer
self.renderer = None
if self.use_renderer:
self.initialize_renderer()
n_agents = self.num_agents
self._agents = [get_agent_keys(i) for i in range(n_agents)]
......@@ -187,9 +187,6 @@ class raw_env(AECEnv, gym.Env):
def reset(self, *args, **kwargs):
self._reset_next_step = False
self._agents = self.possible_agents[:]
if self.use_renderer:
if self.renderer: #TODO: Errors with RLLib with renderer as None.
self.renderer.reset()
obs, info = self._environment.reset(*args, **kwargs)
observations = self._collate_obs_and_info(obs, info)
self._agent_selector.reinit(self.agents)
......@@ -268,53 +265,40 @@ class raw_env(AECEnv, gym.Env):
self.obs = observations
return observations
def set_probs(self, probs):
self.probs = probs
def render(self, mode='human'):
def render(self, mode='rgb_array'):
"""
This methods provides the option to render the
environment's behavior to a window which should be
readable to the human eye if mode is set to 'human'.
environment's behavior as an image or to a window.
"""
if not self.use_renderer:
return
if not self.renderer:
self.initialize_renderer(mode=mode)
return self.update_renderer(mode=mode)
def initialize_renderer(self, mode="human"):
# Initiate the renderer
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
self.renderer = RenderTool(self.environment, gl="PGL", # gl="TKPILSVG",
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
self.renderer.show = False
def update_renderer(self, mode='human'):
image = self.renderer.render_env(show=False, show_observations=False, show_predictions=False,
return_image=True)
return image[:,:,:3]
def set_renderer(self, renderer):
self.use_renderer = renderer
if self.use_renderer:
self.initialize_renderer(mode=self.use_renderer)
if mode == "rgb_array":
env_rgb_array = self._environment.render(mode)
if not hasattr(self, "image_shape "):
self.image_shape = env_rgb_array.shape
if not hasattr(self, "probs "):
self.probs = [[0., 0., 0., 0.]]
fig, ax = plt.subplots(figsize=(self.image_shape[1]/100, self.image_shape[0]/100),
constrained_layout=True, dpi=100)
df = pd.DataFrame(np.array(self.probs).T)
sns.barplot(x=df.index, y=0, data=df, ax=ax)
ax.set(xlabel='actions', ylabel='probs')
fig.canvas.draw()
X = np.array(fig.canvas.renderer.buffer_rgba())
Image.fromarray(X)
# Image.fromarray(X)
rgb_image = np.array(Image.fromarray(X).convert('RGB'))
plt.close(fig)
q_value_rgb_array = rgb_image
return np.append(env_rgb_array, q_value_rgb_array, axis=1)
else:
return self._environment.render(mode)
def close(self):
# self._environment.close()
if self.renderer:
try:
if self.renderer.show:
self.renderer.close_window()
except Exception as e:
print("Could Not close window due to:",e)
self.renderer = None
self._environment.close()
def _obtain_preprocessor(
self, preprocessor):
def _obtain_preprocessor(self, preprocessor):
"""Obtains the actual preprocessor to be used based on the supplied
preprocessor and the env's obs_builder object"""
if not isinstance(self.obs_builder, GlobalObsForRailEnv):
......
......@@ -4,3 +4,6 @@ id-mava[tf]
supersuit
stable-baselines3
ray==1.5.2
seaborn
matplotlib
pandas
\ No newline at end of file
......@@ -8,7 +8,7 @@ from typing import List, NamedTuple, Optional, Dict, Tuple
import numpy as np
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
......@@ -427,6 +427,8 @@ class RailEnv(Environment):
}
# Return the new observation vectors for each agent
observation_dict: Dict = self._get_observations()
if hasattr(self, "renderer") and self.renderer is not None:
self.renderer = None
return observation_dict, info_dict
def _fix_agent_after_malfunction(self, agent: EnvAgent):
......@@ -1146,3 +1148,78 @@ class RailEnv(Environment):
def save(self, filename):
print("deprecated call to env.save() - pls call RailEnvPersister.save()")
persistence.RailEnvPersister.save(self, filename)
def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False, clear_debug_text=True, show=False,
screen_height=600, screen_width=800,
show_observations=False, show_predictions=False,
show_rowcols=False, return_image=True):
"""
This methods provides the option to render the
environment's behavior as an image or to a window.
Parameters
----------
mode
Returns
-------
Image if mode is rgb_array, opens a window otherwise
"""
if not hasattr(self, "renderer") or self.renderer is None:
self.initialize_renderer(mode=mode, gl=gl, # gl="TKPILSVG",
agent_render_variant=agent_render_variant,
show_debug=show_debug,
clear_debug_text=clear_debug_text,
show=show,
screen_height=screen_height, # Adjust these parameters to fit your resolution
screen_width=screen_width)
return self.update_renderer(mode=mode, show=show, show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols, return_image=return_image)
def initialize_renderer(self, mode, gl,
agent_render_variant,
show_debug,
clear_debug_text,
show,
screen_height,
screen_width):
# Initiate the renderer
self.renderer = RenderTool(self, gl=gl, # gl="TKPILSVG",
agent_render_variant=agent_render_variant,
show_debug=show_debug,
clear_debug_text=clear_debug_text,
screen_height=screen_height, # Adjust these parameters to fit your resolution
screen_width=screen_width) # Adjust these parameters to fit your resolution
self.renderer.show = show
self.renderer.reset()
def update_renderer(self, mode, show, show_observations, show_predictions,
show_rowcols, return_image):
"""
This method updates the render.
Parameters
----------
mode
Returns
-------
Image if mode is rgb_array, None otherwise
"""
image = self.renderer.render_env(show=show, show_observations=show_observations,
show_predictions=show_predictions,
show_rowcols=show_rowcols, return_image=return_image)
if mode == 'rgb_array':
return image[:, :, :3]
def close(self):
"""
This methods closes any renderer window.
"""
if hasattr(self, "renderer") and self.renderer is not None:
try:
if self.renderer.show:
self.renderer.close_window()
except Exception as e:
print("Could Not close window due to:",e)
self.renderer = None
......@@ -30,7 +30,7 @@ def test_petting_zoo_interface_env():
save = True
np.random.seed(seed)
experiment_name = "flatland_pettingzoo"
total_episodes = 1
total_episodes = 2
if save:
try:
......@@ -49,12 +49,6 @@ def test_petting_zoo_interface_env():
# rail_env = ShortestPathActionWrapper(rail_env)
rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
env_renderer = RenderTool(rail_env,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
dones = {}
dones['__all__'] = False
......@@ -76,9 +70,7 @@ def test_petting_zoo_interface_env():
# Do the environment step
observations, rewards, dones, information = rail_env.step(action_dict)
image = env_renderer.render_env(show=False, show_observations=False, show_predictions=False,
return_image=True)
frame_list.append(PIL.Image.fromarray(image[:, :, :3]))
frame_list.append(PIL.Image.fromarray(rail_env.render(mode="rgb_array")))
if dones['__all__']:
completion = env_generators.perc_completion(rail_env)
......@@ -88,16 +80,11 @@ def test_petting_zoo_interface_env():
frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True,
append_images=frame_list[1:], duration=3, loop=0)
frame_list = []
env_renderer = RenderTool(rail_env,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
rail_env.reset(random_seed=seed+ep_no)
# __sphinx_doc_begin__
env = flatland_env.env(environment=rail_env, use_renderer=True)
env = flatland_env.env(environment=rail_env)
seed = 11
env.reset(random_seed=seed)
step = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment