From 62e0236648dbd59b8c03e629955294af4ca90ebb Mon Sep 17 00:00:00 2001 From: Nilabha <nilabha2007@gmail.com> Date: Tue, 14 Sep 2021 16:14:59 +0530 Subject: [PATCH] update render and close logic in rail env --- flatland/contrib/interface/flatland_env.py | 80 +++++++++------------- flatland/contrib/requirements_training.txt | 5 +- flatland/envs/rail_env.py | 79 ++++++++++++++++++++- tests/test_pettingzoo_interface.py | 19 +---- 4 files changed, 117 insertions(+), 66 deletions(-) diff --git a/flatland/contrib/interface/flatland_env.py b/flatland/contrib/interface/flatland_env.py index 584621a6..31208cfd 100644 --- a/flatland/contrib/interface/flatland_env.py +++ b/flatland/contrib/interface/flatland_env.py @@ -13,6 +13,10 @@ from mava.wrappers.flatland import infer_observation_space, normalize_observatio from functools import partial from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv +import seaborn as sns +import matplotlib.pyplot as plt +import pandas as pd +from PIL import Image """Adapted from - https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py @@ -67,13 +71,9 @@ class raw_env(AECEnv, gym.Env): 'video.frames_per_second': 10, 'semantics.autoreset': False } - def __init__(self, environment = False, preprocessor = False, agent_info = False, use_renderer=False, *args, **kwargs): + def __init__(self, environment = False, preprocessor = False, agent_info = False, *args, **kwargs): # EzPickle.__init__(self, *args, **kwargs) self._environment = environment - self.use_renderer = use_renderer - self.renderer = None - if self.use_renderer: - self.initialize_renderer() n_agents = self.num_agents self._agents = [get_agent_keys(i) for i in range(n_agents)] @@ -187,9 +187,6 @@ class raw_env(AECEnv, gym.Env): def reset(self, *args, **kwargs): self._reset_next_step = False self._agents = self.possible_agents[:] - if self.use_renderer: - if self.renderer: #TODO: Errors with RLLib with renderer as None. - self.renderer.reset() obs, info = self._environment.reset(*args, **kwargs) observations = self._collate_obs_and_info(obs, info) self._agent_selector.reinit(self.agents) @@ -268,53 +265,40 @@ class raw_env(AECEnv, gym.Env): self.obs = observations return observations + def set_probs(self, probs): + self.probs = probs - def render(self, mode='human'): + def render(self, mode='rgb_array'): """ This methods provides the option to render the - environment's behavior to a window which should be - readable to the human eye if mode is set to 'human'. + environment's behavior as an image or to a window. """ - if not self.use_renderer: - return - - if not self.renderer: - self.initialize_renderer(mode=mode) - - return self.update_renderer(mode=mode) - - def initialize_renderer(self, mode="human"): - # Initiate the renderer - from flatland.utils.rendertools import RenderTool, AgentRenderVariant - self.renderer = RenderTool(self.environment, gl="PGL", # gl="TKPILSVG", - agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, - show_debug=False, - screen_height=600, # Adjust these parameters to fit your resolution - screen_width=800) # Adjust these parameters to fit your resolution - self.renderer.show = False - - def update_renderer(self, mode='human'): - image = self.renderer.render_env(show=False, show_observations=False, show_predictions=False, - return_image=True) - return image[:,:,:3] - - def set_renderer(self, renderer): - self.use_renderer = renderer - if self.use_renderer: - self.initialize_renderer(mode=self.use_renderer) + if mode == "rgb_array": + env_rgb_array = self._environment.render(mode) + if not hasattr(self, "image_shape "): + self.image_shape = env_rgb_array.shape + if not hasattr(self, "probs "): + self.probs = [[0., 0., 0., 0.]] + fig, ax = plt.subplots(figsize=(self.image_shape[1]/100, self.image_shape[0]/100), + constrained_layout=True, dpi=100) + df = pd.DataFrame(np.array(self.probs).T) + sns.barplot(x=df.index, y=0, data=df, ax=ax) + ax.set(xlabel='actions', ylabel='probs') + fig.canvas.draw() + X = np.array(fig.canvas.renderer.buffer_rgba()) + Image.fromarray(X) + # Image.fromarray(X) + rgb_image = np.array(Image.fromarray(X).convert('RGB')) + plt.close(fig) + q_value_rgb_array = rgb_image + return np.append(env_rgb_array, q_value_rgb_array, axis=1) + else: + return self._environment.render(mode) def close(self): - # self._environment.close() - if self.renderer: - try: - if self.renderer.show: - self.renderer.close_window() - except Exception as e: - print("Could Not close window due to:",e) - self.renderer = None + self._environment.close() - def _obtain_preprocessor( - self, preprocessor): + def _obtain_preprocessor(self, preprocessor): """Obtains the actual preprocessor to be used based on the supplied preprocessor and the env's obs_builder object""" if not isinstance(self.obs_builder, GlobalObsForRailEnv): diff --git a/flatland/contrib/requirements_training.txt b/flatland/contrib/requirements_training.txt index d9cc58ce..fe36966d 100644 --- a/flatland/contrib/requirements_training.txt +++ b/flatland/contrib/requirements_training.txt @@ -3,4 +3,7 @@ id-mava id-mava[tf] supersuit stable-baselines3 -ray==1.5.2 \ No newline at end of file +ray==1.5.2 +seaborn +matplotlib +pandas \ No newline at end of file diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index ab0e1487..5021e435 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -8,7 +8,7 @@ from typing import List, NamedTuple, Optional, Dict, Tuple import numpy as np - +from flatland.utils.rendertools import RenderTool, AgentRenderVariant from flatland.core.env import Environment from flatland.core.env_observation_builder import ObservationBuilder from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions @@ -427,6 +427,8 @@ class RailEnv(Environment): } # Return the new observation vectors for each agent observation_dict: Dict = self._get_observations() + if hasattr(self, "renderer") and self.renderer is not None: + self.renderer = None return observation_dict, info_dict def _fix_agent_after_malfunction(self, agent: EnvAgent): @@ -1146,3 +1148,78 @@ class RailEnv(Environment): def save(self, filename): print("deprecated call to env.save() - pls call RailEnvPersister.save()") persistence.RailEnvPersister.save(self, filename) + + def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, + show_debug=False, clear_debug_text=True, show=False, + screen_height=600, screen_width=800, + show_observations=False, show_predictions=False, + show_rowcols=False, return_image=True): + """ + This methods provides the option to render the + environment's behavior as an image or to a window. + Parameters + ---------- + mode + + Returns + ------- + Image if mode is rgb_array, opens a window otherwise + """ + if not hasattr(self, "renderer") or self.renderer is None: + self.initialize_renderer(mode=mode, gl=gl, # gl="TKPILSVG", + agent_render_variant=agent_render_variant, + show_debug=show_debug, + clear_debug_text=clear_debug_text, + show=show, + screen_height=screen_height, # Adjust these parameters to fit your resolution + screen_width=screen_width) + return self.update_renderer(mode=mode, show=show, show_observations=show_observations, + show_predictions=show_predictions, + show_rowcols=show_rowcols, return_image=return_image) + + def initialize_renderer(self, mode, gl, + agent_render_variant, + show_debug, + clear_debug_text, + show, + screen_height, + screen_width): + # Initiate the renderer + self.renderer = RenderTool(self, gl=gl, # gl="TKPILSVG", + agent_render_variant=agent_render_variant, + show_debug=show_debug, + clear_debug_text=clear_debug_text, + screen_height=screen_height, # Adjust these parameters to fit your resolution + screen_width=screen_width) # Adjust these parameters to fit your resolution + self.renderer.show = show + self.renderer.reset() + + def update_renderer(self, mode, show, show_observations, show_predictions, + show_rowcols, return_image): + """ + This method updates the render. + Parameters + ---------- + mode + + Returns + ------- + Image if mode is rgb_array, None otherwise + """ + image = self.renderer.render_env(show=show, show_observations=show_observations, + show_predictions=show_predictions, + show_rowcols=show_rowcols, return_image=return_image) + if mode == 'rgb_array': + return image[:, :, :3] + + def close(self): + """ + This methods closes any renderer window. + """ + if hasattr(self, "renderer") and self.renderer is not None: + try: + if self.renderer.show: + self.renderer.close_window() + except Exception as e: + print("Could Not close window due to:",e) + self.renderer = None diff --git a/tests/test_pettingzoo_interface.py b/tests/test_pettingzoo_interface.py index d2b2776f..d48cc9f8 100644 --- a/tests/test_pettingzoo_interface.py +++ b/tests/test_pettingzoo_interface.py @@ -30,7 +30,7 @@ def test_petting_zoo_interface_env(): save = True np.random.seed(seed) experiment_name = "flatland_pettingzoo" - total_episodes = 1 + total_episodes = 2 if save: try: @@ -48,12 +48,6 @@ def test_petting_zoo_interface_env(): # For Shortest Path Action Wrapper, change action to 1 # rail_env = ShortestPathActionWrapper(rail_env) rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0) - - env_renderer = RenderTool(rail_env, - agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, - show_debug=False, - screen_height=600, # Adjust these parameters to fit your resolution - screen_width=800) # Adjust these parameters to fit your resolution dones = {} dones['__all__'] = False @@ -76,9 +70,7 @@ def test_petting_zoo_interface_env(): # Do the environment step observations, rewards, dones, information = rail_env.step(action_dict) - image = env_renderer.render_env(show=False, show_observations=False, show_predictions=False, - return_image=True) - frame_list.append(PIL.Image.fromarray(image[:, :, :3])) + frame_list.append(PIL.Image.fromarray(rail_env.render(mode="rgb_array"))) if dones['__all__']: completion = env_generators.perc_completion(rail_env) @@ -88,16 +80,11 @@ def test_petting_zoo_interface_env(): frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True, append_images=frame_list[1:], duration=3, loop=0) frame_list = [] - env_renderer = RenderTool(rail_env, - agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, - show_debug=False, - screen_height=600, # Adjust these parameters to fit your resolution - screen_width=800) # Adjust these parameters to fit your resolution rail_env.reset(random_seed=seed+ep_no) # __sphinx_doc_begin__ - env = flatland_env.env(environment=rail_env, use_renderer=True) + env = flatland_env.env(environment=rail_env) seed = 11 env.reset(random_seed=seed) step = 0 -- GitLab