From 62e0236648dbd59b8c03e629955294af4ca90ebb Mon Sep 17 00:00:00 2001
From: Nilabha <nilabha2007@gmail.com>
Date: Tue, 14 Sep 2021 16:14:59 +0530
Subject: [PATCH] update render and close logic in rail env

---
 flatland/contrib/interface/flatland_env.py | 80 +++++++++-------------
 flatland/contrib/requirements_training.txt |  5 +-
 flatland/envs/rail_env.py                  | 79 ++++++++++++++++++++-
 tests/test_pettingzoo_interface.py         | 19 +----
 4 files changed, 117 insertions(+), 66 deletions(-)

diff --git a/flatland/contrib/interface/flatland_env.py b/flatland/contrib/interface/flatland_env.py
index 584621a6..31208cfd 100644
--- a/flatland/contrib/interface/flatland_env.py
+++ b/flatland/contrib/interface/flatland_env.py
@@ -13,6 +13,10 @@ from mava.wrappers.flatland import infer_observation_space, normalize_observatio
 from functools import partial
 from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
 
+import seaborn as sns
+import matplotlib.pyplot as plt
+import pandas as pd
+from PIL import Image
 
 """Adapted from 
 - https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
@@ -67,13 +71,9 @@ class raw_env(AECEnv, gym.Env):
             'video.frames_per_second': 10,
             'semantics.autoreset': False }
 
-    def __init__(self, environment = False, preprocessor = False, agent_info = False, use_renderer=False, *args, **kwargs):
+    def __init__(self, environment = False, preprocessor = False, agent_info = False, *args, **kwargs):
         # EzPickle.__init__(self, *args, **kwargs)
         self._environment = environment
-        self.use_renderer = use_renderer
-        self.renderer = None
-        if self.use_renderer:
-            self.initialize_renderer()
 
         n_agents = self.num_agents
         self._agents = [get_agent_keys(i) for i in range(n_agents)]
@@ -187,9 +187,6 @@ class raw_env(AECEnv, gym.Env):
     def reset(self, *args, **kwargs):
         self._reset_next_step = False
         self._agents = self.possible_agents[:]
-        if self.use_renderer:
-            if self.renderer: #TODO: Errors with RLLib with renderer as None.
-                self.renderer.reset()
         obs, info = self._environment.reset(*args, **kwargs)
         observations = self._collate_obs_and_info(obs, info)
         self._agent_selector.reinit(self.agents)
@@ -268,53 +265,40 @@ class raw_env(AECEnv, gym.Env):
         self.obs = observations
         return observations   
 
+    def set_probs(self, probs):
+        self.probs = probs  # NOTE(review): plotted by render(); presumably one row of per-action values — confirm
 
-    def render(self, mode='human'):
+    def render(self, mode='rgb_array'):
         """
         This methods provides the option to render the
-        environment's behavior to a window which should be
-        readable to the human eye if mode is set to 'human'.
+        environment's behavior as an image or to a window.
         """
-        if not self.use_renderer:
-            return
-
-        if not self.renderer:
-            self.initialize_renderer(mode=mode)
-
-        return self.update_renderer(mode=mode)
-
-    def initialize_renderer(self, mode="human"):
-        # Initiate the renderer
-        from flatland.utils.rendertools import RenderTool, AgentRenderVariant
-        self.renderer = RenderTool(self.environment, gl="PGL",  # gl="TKPILSVG",
-                                       agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
-                                       show_debug=False,
-                                       screen_height=600,  # Adjust these parameters to fit your resolution
-                                       screen_width=800)  # Adjust these parameters to fit your resolution
-        self.renderer.show = False
-
-    def update_renderer(self, mode='human'):
-        image = self.renderer.render_env(show=False, show_observations=False, show_predictions=False,
-                                             return_image=True)
-        return image[:,:,:3]
-
-    def set_renderer(self, renderer):
-        self.use_renderer = renderer
-        if self.use_renderer:
-            self.initialize_renderer(mode=self.use_renderer)
+        if mode == "rgb_array":
+            env_rgb_array = self._environment.render(mode)
+            if not hasattr(self, "image_shape"):
+                self.image_shape = env_rgb_array.shape
+            if not hasattr(self, "probs"):
+                self.probs = [[0., 0., 0., 0.]]
+            fig, ax = plt.subplots(figsize=(self.image_shape[1]/100, self.image_shape[0]/100),
+                                constrained_layout=True, dpi=100)
+            df = pd.DataFrame(np.array(self.probs).T)
+            sns.barplot(x=df.index, y=0, data=df, ax=ax)
+            ax.set(xlabel='actions', ylabel='probs')
+            fig.canvas.draw()
+            X = np.array(fig.canvas.renderer.buffer_rgba())
+            # buffer_rgba() yields RGBA; drop the alpha channel before the
+            # plot is appended next to the environment frame below.
+            rgb_image = np.array(Image.fromarray(X).convert('RGB'))
+            plt.close(fig)
+            q_value_rgb_array = rgb_image
+            return np.append(env_rgb_array, q_value_rgb_array, axis=1)
+        else:
+            return self._environment.render(mode)
 
     def close(self):
-        # self._environment.close()
-        if self.renderer:
-            try:
-                if self.renderer.show:
-                    self.renderer.close_window()
-            except Exception as e:
-                print("Could Not close window due to:",e)
-            self.renderer = None
+        self._environment.close()
 
-    def _obtain_preprocessor(
-        self, preprocessor):
+    def _obtain_preprocessor(self, preprocessor):
         """Obtains the actual preprocessor to be used based on the supplied
         preprocessor and the env's obs_builder object"""
         if not isinstance(self.obs_builder, GlobalObsForRailEnv):
diff --git a/flatland/contrib/requirements_training.txt b/flatland/contrib/requirements_training.txt
index d9cc58ce..fe36966d 100644
--- a/flatland/contrib/requirements_training.txt
+++ b/flatland/contrib/requirements_training.txt
@@ -3,4 +3,7 @@ id-mava
 id-mava[tf]
 supersuit
 stable-baselines3
-ray==1.5.2
\ No newline at end of file
+ray==1.5.2
+seaborn
+matplotlib
+pandas
\ No newline at end of file
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index ab0e1487..5021e435 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -8,7 +8,7 @@ from typing import List, NamedTuple, Optional, Dict, Tuple
 
 import numpy as np
 
-
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
 from flatland.core.env import Environment
 from flatland.core.env_observation_builder import ObservationBuilder
 from flatland.core.grid.grid4 import Grid4TransitionsEnum, Grid4Transitions
@@ -427,6 +427,8 @@ class RailEnv(Environment):
         }
         # Return the new observation vectors for each agent
         observation_dict: Dict = self._get_observations()
+        if hasattr(self, "renderer") and self.renderer is not None:
+            self.renderer = None
         return observation_dict, info_dict
 
     def _fix_agent_after_malfunction(self, agent: EnvAgent):
@@ -1146,3 +1148,78 @@ class RailEnv(Environment):
     def save(self, filename):
         print("deprecated call to env.save() - pls call RailEnvPersister.save()")
         persistence.RailEnvPersister.save(self, filename)
+
+    def render(self, mode="rgb_array", gl="PGL", agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
+            show_debug=False, clear_debug_text=True, show=False,
+            screen_height=600, screen_width=800,
+            show_observations=False, show_predictions=False,
+            show_rowcols=False, return_image=True):
+        """
+        Render the environment, creating the renderer lazily on first use.
+        gl, agent_render_variant, show_debug, clear_debug_text, screen_height
+        and screen_width only take effect when the renderer is created;
+        while self.renderer is set, later calls reuse it as-is.
+        The remaining arguments are forwarded to update_renderer every call.
+
+        Returns
+        -------
+        The result of update_renderer (an image array if mode is 'rgb_array')
+        """
+        if not hasattr(self, "renderer") or self.renderer is None:
+            self.initialize_renderer(mode=mode, gl=gl,  # gl="TKPILSVG",
+                                    agent_render_variant=agent_render_variant,
+                                    show_debug=show_debug,
+                                    clear_debug_text=clear_debug_text,
+                                    show=show,
+                                    screen_height=screen_height,  # Adjust these parameters to fit your resolution
+                                    screen_width=screen_width)
+        return self.update_renderer(mode=mode, show=show, show_observations=show_observations,
+                                    show_predictions=show_predictions,
+                                    show_rowcols=show_rowcols, return_image=return_image)
+
+    def initialize_renderer(self, mode, gl,
+                agent_render_variant,
+                show_debug,
+                clear_debug_text,
+                show,
+                screen_height,
+                screen_width):
+        # Build a new RenderTool for this env and store it in self.renderer (mode is accepted but not used here)
+        self.renderer = RenderTool(self, gl=gl,  # gl="TKPILSVG",
+                                agent_render_variant=agent_render_variant,
+                                show_debug=show_debug,
+                                clear_debug_text=clear_debug_text,
+                                screen_height=screen_height,  # forwarded from the caller
+                                screen_width=screen_width)  # forwarded from the caller
+        self.renderer.show = show  # read back by close() to decide whether a window has to be closed
+        self.renderer.reset()
+
+    def update_renderer(self, mode, show, show_observations, show_predictions,
+                    show_rowcols, return_image):
+        """
+        Redraw the environment with the renderer stored in self.renderer.
+        The show*/return_image flags are forwarded to renderer.render_env.
+
+        Returns
+        -------
+        First three channels of the rendered image if mode is 'rgb_array' and
+        render_env produced one; None otherwise (no indexing of a missing
+        image, e.g. presumably when return_image is False).
+        """
+        image = self.renderer.render_env(show=show, show_observations=show_observations,
+                                show_predictions=show_predictions,
+                                show_rowcols=show_rowcols, return_image=return_image)
+        if mode == 'rgb_array' and image is not None:
+            return image[:, :, :3]
+
+    def close(self):
+        """
+        Close the renderer window if one is shown, then drop the renderer.
+        """
+        if hasattr(self, "renderer") and self.renderer is not None:
+            try:
+                if self.renderer.show:
+                    self.renderer.close_window()
+            except Exception as e:  # best effort: a failed close must not prevent dropping the renderer below
+                print("Could Not close window due to:",e)
+            self.renderer = None
diff --git a/tests/test_pettingzoo_interface.py b/tests/test_pettingzoo_interface.py
index d2b2776f..d48cc9f8 100644
--- a/tests/test_pettingzoo_interface.py
+++ b/tests/test_pettingzoo_interface.py
@@ -30,7 +30,7 @@ def test_petting_zoo_interface_env():
     save = True
     np.random.seed(seed)
     experiment_name = "flatland_pettingzoo"
-    total_episodes = 1
+    total_episodes = 2
 
     if save:
         try:
@@ -48,12 +48,6 @@ def test_petting_zoo_interface_env():
     # For Shortest Path Action Wrapper, change action to 1
     # rail_env = ShortestPathActionWrapper(rail_env)  
     rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
-    
-    env_renderer = RenderTool(rail_env,
-                            agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
-                            show_debug=False,
-                            screen_height=600,  # Adjust these parameters to fit your resolution
-                            screen_width=800)  # Adjust these parameters to fit your resolution
 
     dones = {}
     dones['__all__'] = False
@@ -76,9 +70,7 @@ def test_petting_zoo_interface_env():
             # Do the environment step
 
         observations, rewards, dones, information = rail_env.step(action_dict)
-        image = env_renderer.render_env(show=False, show_observations=False, show_predictions=False,
-                                        return_image=True)
-        frame_list.append(PIL.Image.fromarray(image[:, :, :3]))
+        frame_list.append(PIL.Image.fromarray(rail_env.render(mode="rgb_array")))
 
         if dones['__all__']:
             completion = env_generators.perc_completion(rail_env)
@@ -88,16 +80,11 @@ def test_petting_zoo_interface_env():
                 frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True, 
                                    append_images=frame_list[1:], duration=3, loop=0)       
             frame_list = []
-            env_renderer = RenderTool(rail_env,
-                            agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
-                            show_debug=False,
-                            screen_height=600,  # Adjust these parameters to fit your resolution
-                            screen_width=800)  # Adjust these parameters to fit your resolution
             rail_env.reset(random_seed=seed+ep_no)
 
     
 #  __sphinx_doc_begin__
-    env = flatland_env.env(environment=rail_env, use_renderer=True)
+    env = flatland_env.env(environment=rail_env)
     seed = 11
     env.reset(random_seed=seed)
     step = 0
-- 
GitLab