Commit 62e02366 authored by nilabha's avatar nilabha
Browse files

update render and close logic in rail env

parent 625734e6
Pipeline #8491 failed with stages
in 4 minutes and 27 seconds
...@@ -13,6 +13,10 @@ from mava.wrappers.flatland import infer_observation_space, normalize_observatio ...@@ -13,6 +13,10 @@ from mava.wrappers.flatland import infer_observation_space, normalize_observatio
from functools import partial from functools import partial
from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
"""Adapted from """Adapted from
- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py - https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
...@@ -67,13 +71,9 @@ class raw_env(AECEnv, gym.Env): ...@@ -67,13 +71,9 @@ class raw_env(AECEnv, gym.Env):
'video.frames_per_second': 10, 'video.frames_per_second': 10,
'semantics.autoreset': False } 'semantics.autoreset': False }
def __init__(self, environment = False, preprocessor = False, agent_info = False, use_renderer=False, *args, **kwargs): def __init__(self, environment = False, preprocessor = False, agent_info = False, *args, **kwargs):
# EzPickle.__init__(self, *args, **kwargs) # EzPickle.__init__(self, *args, **kwargs)
self._environment = environment self._environment = environment
self.use_renderer = use_renderer
self.renderer = None
if self.use_renderer:
self.initialize_renderer()
n_agents = self.num_agents n_agents = self.num_agents
self._agents = [get_agent_keys(i) for i in range(n_agents)] self._agents = [get_agent_keys(i) for i in range(n_agents)]
...@@ -187,9 +187,6 @@ class raw_env(AECEnv, gym.Env): ...@@ -187,9 +187,6 @@ class raw_env(AECEnv, gym.Env):
def reset(self, *args, **kwargs): def reset(self, *args, **kwargs):
self._reset_next_step = False self._reset_next_step = False
self._agents = self.possible_agents[:] self._agents = self.possible_agents[:]
if self.use_renderer:
if self.renderer: #TODO: Errors with RLLib with renderer as None.
self.renderer.reset()
obs, info = self._environment.reset(*args, **kwargs) obs, info = self._environment.reset(*args, **kwargs)
observations = self._collate_obs_and_info(obs, info) observations = self._collate_obs_and_info(obs, info)
self._agent_selector.reinit(self.agents) self._agent_selector.reinit(self.agents)
...@@ -268,53 +265,40 @@ class raw_env(AECEnv, gym.Env): ...@@ -268,53 +265,40 @@ class raw_env(AECEnv, gym.Env):
self.obs = observations self.obs = observations
return observations return observations
def set_probs(self, probs):
    """
    Store the action probabilities that ``render`` plots next to the
    environment frame in ``"rgb_array"`` mode.

    Parameters
    ----------
    probs
        Kept as-is on ``self.probs``. Presumably a nested list shaped like
        the default ``[[0., 0., 0., 0.]]`` used by ``render`` (one value per
        action) -- TODO confirm against callers.
    """
    self.probs = probs
def render(self, mode='rgb_array'):
    """
    Render the wrapped environment as an image or to a window.

    In ``"rgb_array"`` mode the frame produced by the wrapped environment is
    extended to the right with a bar plot of ``self.probs`` so that both end
    up in a single image.

    Parameters
    ----------
    mode : str
        ``"rgb_array"`` returns the combined image; any other value is
        forwarded unchanged to the wrapped environment's ``render``.

    Returns
    -------
    The combined array for ``"rgb_array"``, otherwise whatever the wrapped
    environment's ``render`` returns for ``mode``.
    """
    if mode != "rgb_array":
        return self._environment.render(mode)

    env_rgb_array = self._environment.render(mode)
    # Track the size of the current frame so the plot is drawn with a
    # matching height and can be appended along axis 1.
    self.image_shape = env_rgb_array.shape
    # Fall back to a flat default only when nothing was supplied through
    # set_probs(). The attribute name must match exactly: a misspelled name
    # makes hasattr() always False and silently discards the caller's values.
    if not hasattr(self, "probs"):
        self.probs = [[0., 0., 0., 0.]]

    frame_height, frame_width = self.image_shape[0], self.image_shape[1]
    fig, ax = plt.subplots(figsize=(frame_width / 100, frame_height / 100),
                           constrained_layout=True, dpi=100)
    try:
        df = pd.DataFrame(np.array(self.probs).T)
        sns.barplot(x=df.index, y=0, data=df, ax=ax)
        ax.set(xlabel='actions', ylabel='probs')
        fig.canvas.draw()
        rgba_pixels = np.array(fig.canvas.renderer.buffer_rgba())
    finally:
        # Always release the figure, even when plotting fails, so repeated
        # renders do not accumulate open figures.
        plt.close(fig)

    # Drop the alpha channel so the plot has the same depth as the env frame.
    q_value_rgb_array = np.array(Image.fromarray(rgba_pixels).convert('RGB'))
    return np.append(env_rgb_array, q_value_rgb_array, axis=1)
def set_renderer(self, renderer):
self.use_renderer = renderer
if self.use_renderer:
self.initialize_renderer(mode=self.use_renderer)
def close(self): def close(self):
# self._environment.close() self._environment.close()
if self.renderer:
try:
if self.renderer.show:
self.renderer.close_window()
except Exception as e:
print("Could Not close window due to:",e)
self.renderer = None
def _obtain_preprocessor( def _obtain_preprocessor(self, preprocessor):
self, preprocessor):
"""Obtains the actual preprocessor to be used based on the supplied """Obtains the actual preprocessor to be used based on the supplied
preprocessor and the env's obs_builder object""" preprocessor and the env's obs_builder object"""
if not isinstance(self.obs_builder, GlobalObsForRailEnv): if not isinstance(self.obs_builder, GlobalObsForRailEnv):
......
...@@ -3,4 +3,7 @@ id-mava ...@@ -3,4 +3,7 @@ id-mava
id-mava[tf] id-mava[tf]
supersuit supersuit
stable-baselines3 stable-baselines3
ray==1.5.2 ray==1.5.2
\ No newline at end of file seaborn
matplotlib
pandas
\ No newline at end of file
...@@ -8,7 +8,7 @@ from typing import List, NamedTuple, Optional, Dict, Tuple ...@@ -8,7 +8,7 @@ from typing import List, NamedTuple, Optional, Dict, Tuple
import numpy as np import numpy as np
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from flatland.core.env import Environment from flatland.core.env import Environment
from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
...@@ -427,6 +427,8 @@ class RailEnv(Environment): ...@@ -427,6 +427,8 @@ class RailEnv(Environment):
} }
# Return the new observation vectors for each agent # Return the new observation vectors for each agent
observation_dict: Dict = self._get_observations() observation_dict: Dict = self._get_observations()
if hasattr(self, "renderer") and self.renderer is not None:
self.renderer = None
return observation_dict, info_dict return observation_dict, info_dict
def _fix_agent_after_malfunction(self, agent: EnvAgent): def _fix_agent_after_malfunction(self, agent: EnvAgent):
...@@ -1146,3 +1148,78 @@ class RailEnv(Environment): ...@@ -1146,3 +1148,78 @@ class RailEnv(Environment):
def save(self, filename): def save(self, filename):
print("deprecated call to env.save() - pls call RailEnvPersister.save()") print("deprecated call to env.save() - pls call RailEnvPersister.save()")
persistence.RailEnvPersister.save(self, filename) persistence.RailEnvPersister.save(self, filename)
def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
           show_debug=False, clear_debug_text=True, show=False,
           screen_height=600, screen_width=800,
           show_observations=False, show_predictions=False,
           show_rowcols=False, return_image=True):
    """
    Render the environment as an image or to a window.

    The renderer is created on the first call (or the first call after it
    was reset to ``None``); the construction arguments (``gl``,
    ``agent_render_variant``, ``show_debug``, ``clear_debug_text``, ``show``,
    ``screen_height``, ``screen_width``) are only used at that point.

    Parameters
    ----------
    mode
        ``"rgb_array"`` to get the frame back as an image.

    Returns
    -------
    Whatever ``update_renderer`` returns for ``mode``.
    """
    needs_renderer = getattr(self, "renderer", None) is None
    if needs_renderer:
        self.initialize_renderer(
            mode=mode,
            gl=gl,
            agent_render_variant=agent_render_variant,
            show_debug=show_debug,
            clear_debug_text=clear_debug_text,
            show=show,
            screen_height=screen_height,
            screen_width=screen_width,
        )

    frame = self.update_renderer(
        mode=mode,
        show=show,
        show_observations=show_observations,
        show_predictions=show_predictions,
        show_rowcols=show_rowcols,
        return_image=return_image,
    )
    return frame
def initialize_renderer(self, mode, gl,
                        agent_render_variant,
                        show_debug,
                        clear_debug_text,
                        show,
                        screen_height,
                        screen_width):
    """
    Create the ``RenderTool`` for this environment and store it on
    ``self.renderer``.

    ``mode`` is accepted but not used here. ``show`` is written to the
    renderer's ``show`` attribute, and the renderer is reset once after
    construction.
    """
    tool_options = {
        "gl": gl,
        "agent_render_variant": agent_render_variant,
        "show_debug": show_debug,
        "clear_debug_text": clear_debug_text,
        "screen_height": screen_height,
        "screen_width": screen_width,
    }
    self.renderer = RenderTool(self, **tool_options)
    self.renderer.show = show
    self.renderer.reset()
def update_renderer(self, mode, show, show_observations, show_predictions,
                    show_rowcols, return_image):
    """
    Redraw the environment with the existing renderer.

    Parameters
    ----------
    mode
        ``'rgb_array'`` to get the rendered frame back.
    show, show_observations, show_predictions, show_rowcols, return_image
        Forwarded unchanged to ``self.renderer.render_env``.

    Returns
    -------
    The first three channels of the rendered image when ``mode`` is
    ``'rgb_array'`` and the renderer produced an image, ``None`` otherwise.
    """
    image = self.renderer.render_env(show=show, show_observations=show_observations,
                                     show_predictions=show_predictions,
                                     show_rowcols=show_rowcols, return_image=return_image)
    # The renderer may hand back no image (e.g. when return_image is False);
    # slicing None would raise TypeError, so only strip the alpha channel
    # when there is an actual frame.
    if mode == 'rgb_array' and image is not None:
        return image[:, :, :3]
    return None
def close(self):
    """
    Drop the renderer, closing its window first when one is being shown.

    Failures while closing the window are reported on stdout and do not
    prevent the renderer from being released.
    """
    active = getattr(self, "renderer", None)
    if active is None:
        return
    try:
        if active.show:
            active.close_window()
    except Exception as e:
        print("Could Not close window due to:",e)
    self.renderer = None
...@@ -30,7 +30,7 @@ def test_petting_zoo_interface_env(): ...@@ -30,7 +30,7 @@ def test_petting_zoo_interface_env():
save = True save = True
np.random.seed(seed) np.random.seed(seed)
experiment_name = "flatland_pettingzoo" experiment_name = "flatland_pettingzoo"
total_episodes = 1 total_episodes = 2
if save: if save:
try: try:
...@@ -48,12 +48,6 @@ def test_petting_zoo_interface_env(): ...@@ -48,12 +48,6 @@ def test_petting_zoo_interface_env():
# For Shortest Path Action Wrapper, change action to 1 # For Shortest Path Action Wrapper, change action to 1
# rail_env = ShortestPathActionWrapper(rail_env) # rail_env = ShortestPathActionWrapper(rail_env)
rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0) rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
env_renderer = RenderTool(rail_env,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
dones = {} dones = {}
dones['__all__'] = False dones['__all__'] = False
...@@ -76,9 +70,7 @@ def test_petting_zoo_interface_env(): ...@@ -76,9 +70,7 @@ def test_petting_zoo_interface_env():
# Do the environment step # Do the environment step
observations, rewards, dones, information = rail_env.step(action_dict) observations, rewards, dones, information = rail_env.step(action_dict)
image = env_renderer.render_env(show=False, show_observations=False, show_predictions=False, frame_list.append(PIL.Image.fromarray(rail_env.render(mode="rgb_array")))
return_image=True)
frame_list.append(PIL.Image.fromarray(image[:, :, :3]))
if dones['__all__']: if dones['__all__']:
completion = env_generators.perc_completion(rail_env) completion = env_generators.perc_completion(rail_env)
...@@ -88,16 +80,11 @@ def test_petting_zoo_interface_env(): ...@@ -88,16 +80,11 @@ def test_petting_zoo_interface_env():
frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True, frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True,
append_images=frame_list[1:], duration=3, loop=0) append_images=frame_list[1:], duration=3, loop=0)
frame_list = [] frame_list = []
env_renderer = RenderTool(rail_env,
agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
show_debug=False,
screen_height=600, # Adjust these parameters to fit your resolution
screen_width=800) # Adjust these parameters to fit your resolution
rail_env.reset(random_seed=seed+ep_no) rail_env.reset(random_seed=seed+ep_no)
# __sphinx_doc_begin__ # __sphinx_doc_begin__
env = flatland_env.env(environment=rail_env, use_renderer=True) env = flatland_env.env(environment=rail_env)
seed = 11 seed = 11
env.reset(random_seed=seed) env.reset(random_seed=seed)
step = 0 step = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment