From 7f4cc45c8192e95177a35cc6ad7e6b49eb1a5b4f Mon Sep 17 00:00:00 2001
From: Nilabha <nilabha2007@gmail.com>
Date: Sun, 18 Jul 2021 15:23:22 +0530
Subject: [PATCH] pettingzoo interface and training

---
 examples/env_generators.py      | 137 +++++++++++++++
 examples/flatland_env.py        | 297 ++++++++++++++++++++++++++++++++
 examples/flatland_pettingzoo.py | 117 +++++++++++++
 3 files changed, 551 insertions(+)
 create mode 100644 examples/env_generators.py
 create mode 100644 examples/flatland_env.py
 create mode 100644 examples/flatland_pettingzoo.py

diff --git a/examples/env_generators.py b/examples/env_generators.py
new file mode 100644
index 00000000..71e9d1ec
--- /dev/null
+++ b/examples/env_generators.py
@@ -0,0 +1,137 @@
+import logging
+import random
+import numpy as np
+from typing import NamedTuple
+
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.envs.agent_utils import RailAgentStatus
+
+
+def random_sparse_env_small(random_seed, observation_builder, max_width=45, max_height=45):
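+    """Generate a small sparse RailEnv whose size, city count, train count and
+    malfunction parameters are sampled from random_seed. If rail generation fails,
+    the grid is grown by 5 per side and generation is retried; returns None once
+    max_width/max_height is exceeded.
+    """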
+    random.seed(random_seed)
+    size = random.randint(0, 5)
+    width = 20 + size * 5
+    height = 20 + size * 5
+    nr_cities = 2 + size // 2 + random.randint(0, 2)
+    nr_trains = min(nr_cities * 5, 5 + random.randint(0, 5))  # , 10 + random.randint(0, 10))
+    max_rails_between_cities = 2
+    max_rails_in_cities = 3 + random.randint(0, size)
+    malfunction_rate = 30 + random.randint(0, 100)
+    malfunction_min_duration = 3 + random.randint(0, 7)
+    malfunction_max_duration = 20 + random.randint(0, 80)
+
+    rail_generator = sparse_rail_generator(max_num_cities=nr_cities, seed=random_seed, grid_mode=False,
+                                           max_rails_between_cities=max_rails_between_cities,
+                                           max_rails_in_city=max_rails_in_cities)
+
+    # new version:
+    # stochastic_data = MalfunctionParameters(malfunction_rate, malfunction_min_duration, malfunction_max_duration)
+
+    stochastic_data = {'malfunction_rate': malfunction_rate, 'min_duration': malfunction_min_duration,
+                       'max_duration': malfunction_max_duration}
+
+    schedule_generator = sparse_schedule_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})
+
+    while width <= max_width and height <= max_height:
+        try:
+            env = RailEnv(width=width, height=height, rail_generator=rail_generator,
+                          schedule_generator=schedule_generator, number_of_agents=nr_trains,
+                          malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                          obs_builder_object=observation_builder, remove_agents_at_target=False)
+
+            print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
+                random_seed, width, height, nr_cities, nr_trains, max_rails_between_cities,
+                max_rails_in_cities, malfunction_rate, malfunction_min_duration, malfunction_max_duration
+            ))
+
+            return env
+        except ValueError as e:
+            logging.error(f"Error: {e}")
+            width += 5
+            height += 5
+            logging.info(f"Trying again with a larger env: (w,h) = ({width},{height})")
+    logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_width}, max_height={max_height}")
+    return None
+
+
+def sparse_env_small(random_seed, observation_builder):
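+    """Create a fixed 30x30 sparse RailEnv with 2 trains, 3 cities, mixed speed profiles and stochastic malfunctions."""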
+    width = 30  # Width of map
+    height = 30  # Height of map
+    nr_trains = 2  # Number of trains that have an assigned task in the env
+    cities_in_map = 3  # Number of cities where agents can start or end
+    seed = random_seed  # Random seed
+    grid_distribution_of_cities = False  # Type of city distribution, if False cities are randomly placed
+    max_rails_between_cities = 2  # Max number of tracks allowed between cities. This is the number of entry points to a city
+    max_rail_in_cities = 6  # Max number of parallel tracks within a city, representing a realistic train station
+
+    rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
+                                        seed=seed,
+                                        grid_mode=grid_distribution_of_cities,
+                                        max_rails_between_cities=max_rails_between_cities,
+                                        max_rails_in_city=max_rail_in_cities,
+                                        )
+
+    # Different agent types (trains) with different speeds.
+    speed_ration_map = {1.: 0.25,  # Fast passenger train
+                        1. / 2.: 0.25,  # Fast freight train
+                        1. / 3.: 0.25,  # Slow commuter train
+                        1. / 4.: 0.25}  # Slow freight train
+
+    # We can now initiate the schedule generator with the given speed profiles
+
+    schedule_generator = sparse_schedule_generator(speed_ration_map)
+
+    # We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
+    # during an episode.
+
+    stochastic_data = MalfunctionParameters(malfunction_rate=1/10000,  # Rate of malfunction occurrence
+                                            min_duration=15,  # Minimal duration of malfunction
+                                            max_duration=50  # Max duration of malfunction
+                                            )
+
+    rail_env = RailEnv(width=width,
+                height=height,
+                rail_generator=rail_generator,
+                schedule_generator=schedule_generator,
+                number_of_agents=nr_trains,
+                obs_builder_object=observation_builder,
+                # malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                malfunction_generator=ParamMalfunctionGen(stochastic_data),
+                remove_agents_at_target=True)
+
+    return rail_env
+
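+# Replacement for gym.wrappers.Monitor._after_step that also accepts the dict of dones
+# returned by multi-agent envs (it checks done['__all__']). flatland_pettingzoo.py
+# monkey-patches this onto gym's Monitor so episode videos can be recorded.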
+def _after_step(self, observation, reward, done, info):
+    if not self.enabled:
+        return done
+
+    if isinstance(done, dict):
+        _done_check = done['__all__']
+    else:
+        _done_check = done
+    if _done_check and self.env_semantics_autoreset:
+        # For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
+        self.reset_video_recorder()
+        self.episode_id += 1
+        self._flush()
+
+    # Record stats - Disabled as it causes error in multi-agent set up
+    # self.stats_recorder.after_step(observation, reward, done, info)
+    # Record video
+    self.video_recorder.capture_frame()
+
+    return done
+
+
+def perc_completion(env):
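+    """Return the percentage of agents that reached their target (status DONE_REMOVED)."""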
+    tasks_finished = 0
+    for current_agent in env.agents_data:
+        if current_agent.status == RailAgentStatus.DONE_REMOVED:
+            tasks_finished += 1
+
+    return 100 * np.mean(tasks_finished / max(1, env.num_agents))
\ No newline at end of file
diff --git a/examples/flatland_env.py b/examples/flatland_env.py
new file mode 100644
index 00000000..b7f593f9
--- /dev/null
+++ b/examples/flatland_env.py
@@ -0,0 +1,297 @@
+import os
+import math
+import numpy as np
+import gym
+from gym.utils import seeding
+from pettingzoo import AECEnv
+from pettingzoo.utils import agent_selector
+from pettingzoo.utils import wrappers
+from gym.utils import EzPickle
+from pettingzoo.utils.conversions import parallel_wrapper_fn
+from flatland.envs.rail_env import RailEnv
+from mava.wrappers.flatland import infer_observation_space, normalize_observation
+from functools import partial
+from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
+
+
+def env(**kwargs):
+    env = raw_env(**kwargs)
+    # env = wrappers.AssertOutOfBoundsWrapper(env)
+    # env = wrappers.OrderEnforcingWrapper(env)
+    return env
+
+
+parallel_env = parallel_wrapper_fn(env)
+
+
+class raw_env(AECEnv, gym.Env):
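+    """PettingZoo AEC environment wrapping a Flatland RailEnv.
+
+    Agents are keyed by string handles ("0", "1", ...). Observations come from the
+    wrapped env's observation builder and are optionally passed through a
+    preprocessor (tree observations are normalized by default).
+    """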
+
+    metadata = {'render.modes': ['human', "rgb_array"], 'name': "flatland_pettingzoo",
+            'video.frames_per_second': 10,
+            'semantics.autoreset': False }
+
+    def __init__(self, environment=False, preprocessor=False, agent_info=False, use_renderer=False, *args, **kwargs):
+        # EzPickle.__init__(self, *args, **kwargs)
+        self._environment = environment
+        self.use_renderer = use_renderer
+        self.renderer = None
+        if self.use_renderer:
+            self.initialize_renderer()
+
+        n_agents = self.num_agents
+        self._agents = [get_agent_keys(i) for i in range(n_agents)]
+        self._possible_agents = self.agents[:]
+        self._reset_next_step = True
+
+        # self.agent_name_mapping = dict(zip(self.agents, list(range(self.n_pistons))))
+        self._agent_selector = agent_selector(self.agents)
+        # self.agent_selection = self.agents
+        # self.observation_spaces = dict(
+        #     zip(self.agents, [gym.spaces.Box(low=0, high=1,shape = (1,1) , dtype=np.float32)] * n_agents))
+        self.num_actions = 5
+
+        self.action_spaces = {
+            agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
+        }              
+        # self.state_space = gym.spaces.Box(low=0, high=255, shape=(self.screen_height, self.screen_width, 3), dtype=np.uint8)
+        # self.closed = False
+        self.seed()
+        # preprocessor must be for observation builders other than global obs
+        # treeobs builders would use the default preprocessor if none is
+        # supplied
+        self.preprocessor = self._obtain_preprocessor(preprocessor)
+
+        self._include_agent_info = agent_info
+
+        # observation space:
+        # flatland defines no observation space for an agent. Here we try
+        # to define the observation space. All agents are identical and would
+        # have the same observation space.
+        # Infer observation space based on returned observation
+        obs, _ = self._environment.reset(regenerate_rail = False, regenerate_schedule = False)
+        obs = self.preprocessor(obs)
+        self.observation_spaces = {
+            i: infer_observation_space(ob) for i, ob in obs.items()
+        }
+
+    
+    @property
+    def environment(self) -> RailEnv:
+        """Returns the wrapped environment."""
+        return self._environment
+
+    @property
+    def dones(self):
+        dones = self._environment.dones
+        # remove_all = dones.pop("__all__", None)
+        return {get_agent_keys(key): value for key, value in dones.items()}    
+    
+    @property
+    def obs_builder(self):    
+        return self._environment.obs_builder    
+
+    @property
+    def width(self):    
+        return self._environment.width  
+
+    @property
+    def height(self):    
+        return self._environment.height  
+
+    @property
+    def agents_data(self):
+        """Rail Env Agents data."""
+        return self._environment.agents
+
+    @property
+    def num_agents(self) -> int:
+        """Returns the number of trains/agents in the flatland environment"""
+        return int(self._environment.number_of_agents)
+
+    # def __getattr__(self, name):
+    #     """Expose any other attributes of the underlying environment."""
+    #     return getattr(self._environment, name)
+   
+    @property
+    def agents(self):
+        return self._agents
+
+    @property
+    def possible_agents(self):
+        return self._possible_agents
+
+    def env_done(self):
+        return self._environment.dones["__all__"] or not self.agents
+    
+    def observe(self, agent):
+        return self.obs.get(agent)
+    
+    def seed(self, seed: int = None) -> None:
+        self._environment._seed(seed)
+
+    def state(self):
+        '''
+        Returns an observation of the global environment
+        '''
+        return None
+
+   
+    def reset(self, *args, **kwargs):
+        self._reset_next_step = False
+        self._agents = self.possible_agents[:]
+        if self.use_renderer:
+            if self.renderer: #TODO: Errors with RLLib with renderer as None.
+                self.renderer.reset()
+        obs, info = self._environment.reset(*args, **kwargs)
+        observations = self._collate_obs_and_info(obs, info)
+        self._agent_selector.reinit(self.agents)
+        self.agent_selection = self._agent_selector.next()
+        self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
+        self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents]))
+        self.action_dict = {i:0 for i in self.possible_agents}
+
+        return observations
+
+    def step(self, action):
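+        """Record the acting agent's action; once the last agent in the cycle has
+        acted, step the wrapped RailEnv with the accumulated action dict and
+        refresh the per-agent observations, rewards and infos."""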
+
+        if self.env_done():
+            self._agents = []
+            return self.last()
+        
+        agent = self.agent_selection
+        self.action_dict[get_agent_handle(agent)] = action
+        if self._reset_next_step:
+            return self.reset()
+
+        if self.dones[agent]:
+            self.agents.remove(agent)
+            if not self.env_done():
+                self.agent_selection = self._agent_selector.next()
+            return self.last()
+
+        if self._agent_selector.is_last():
+            observations, rewards, dones, infos = self._environment.step(self.action_dict)
+            self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
+            if observations:
+                observations = self._collate_obs_and_info(observations, infos)
+           
+            # self._agents = [agent
+            #                 for agent in self.agents
+            #                 if not self._environment.dones[get_agent_handle(agent)]
+            #                 ]
+    
+        else:
+            self._clear_rewards()
+        
+        # self._cumulative_rewards[agent] = 0
+        self._accumulate_rewards()
+        
+        self.agent_selection = self._agent_selector.next()
+
+        return self.last()
+
+        # if self._agent_selector.is_last():
+        #     self._agent_selector.reinit(self.agents)
+
+    # Collate agent info and observation into a tuple, making each agent's observation
+    # a tuple of the env observation and the agent info.
+    def _collate_obs_and_info(self, observes, info):
+        observations = {}
+        infos = {}
+        observes = self.preprocessor(observes)
+        for agent, obs in observes.items():
+            all_infos = {k: info[k][get_agent_handle(agent)] for k in info.keys()}
+            agent_info = np.array(
+                list(all_infos.values()), dtype=np.float32
+            )
+            infos[agent] = all_infos
+            obs = (obs, agent_info) if self._include_agent_info else obs
+            observations[agent] = obs
+
+        self.infos = infos
+        self.obs = observations
+        return observations   
+
+
+    def render(self, mode='human'):
+        """
+        This methods provides the option to render the
+        environment's behavior to a window which should be
+        readable to the human eye if mode is set to 'human'.
+        """
+        if not self.use_renderer:
+            return
+
+        if not self.renderer:
+            self.initialize_renderer(mode=mode)
+
+        return self.update_renderer(mode=mode)
+
+    def initialize_renderer(self, mode="human"):
+        # Initiate the renderer
+        from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+        self.renderer = RenderTool(self.environment, gl="PGL",  # gl="TKPILSVG",
+                                       agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
+                                       show_debug=False,
+                                       screen_height=600,  # Adjust these parameters to fit your resolution
+                                       screen_width=800)  # Adjust these parameters to fit your resolution
+        self.renderer.show = False
+
+    def update_renderer(self, mode='human'):
+        image = self.renderer.render_env(show=False, show_observations=False, show_predictions=False,
+                                             return_image=True)
+        return image[:,:,:3]
+
+    def set_renderer(self, renderer):
+        self.use_renderer = renderer
+        if self.use_renderer:
+            self.initialize_renderer(mode=self.use_renderer)
+
+    def close(self):
+        # self._environment.close()
+        if self.renderer:
+            try:
+                if self.renderer.show:
+                    self.renderer.close_window()
+            except Exception as e:
+                print("Could Not close window due to:",e)
+            self.renderer = None
+
+    def _obtain_preprocessor(
+        self, preprocessor):
+        """Obtains the actual preprocessor to be used based on the supplied
+        preprocessor and the env's obs_builder object"""
+        if not isinstance(self.obs_builder, GlobalObsForRailEnv):
+            _preprocessor = preprocessor if preprocessor else lambda x: x
+            if isinstance(self.obs_builder, TreeObsForRailEnv):
+                _preprocessor = (
+                    partial(
+                        normalize_observation, tree_depth=self.obs_builder.max_depth
+                    )
+                    if not preprocessor
+                    else preprocessor
+                )
+            assert _preprocessor is not None
+        else:
+            def _preprocessor(x):
+                return x
+
+        def returned_preprocessor(obs):
+            temp_obs = {}
+            for agent_id, ob in obs.items():
+                temp_obs[get_agent_keys(agent_id)] = _preprocessor(ob)
+            return temp_obs
+
+        return returned_preprocessor
+
+# Utility functions   
+def convert_np_type(dtype, value):
+    return np.dtype(dtype).type(value) 
+
+def get_agent_handle(id):
+    """Obtain an agent's integer handle given its id."""
+    return int(id)
+
+def get_agent_keys(id):
+    """Obtain an agent's string key given its id."""
+    return str(id)
\ No newline at end of file
diff --git a/examples/flatland_pettingzoo.py b/examples/flatland_pettingzoo.py
new file mode 100644
index 00000000..b3e108cd
--- /dev/null
+++ b/examples/flatland_pettingzoo.py
@@ -0,0 +1,117 @@
+
+import numpy as np
+import os
+
+from stable_baselines3.ppo import MlpPolicy
+from stable_baselines3 import PPO
+import supersuit as ss
+
+import flatland_env
+import env_generators
+
+from gym.wrappers import monitor
+from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+
+# First of all we import the Flatland rail environment
+from flatland.envs.rail_env import RailEnv
+import wandb
+
+# Custom observation builder without predictor
+# observation_builder = GlobalObsForRailEnv()
+
+# Custom observation builder with predictor, uncomment line below if you want to try this one
+observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv())
+seed = 10
+experiment_name= "flatland_pettingzoo"
+# rail_env = env_generators.sparse_env_small(seed, observation_builder)
+rail_env = env_generators.random_sparse_env_small(seed, observation_builder)
+env = flatland_env.parallel_env(environment = rail_env, use_renderer = False)
+run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True, config={}, name=experiment_name, save_code=True)
+
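+# Vectorise the parallel PettingZoo env for stable-baselines3: black_death lets agents
+# finish early (their slots receive dummy observations) so the number of agent slots
+# stays fixed, and 3 copies of the env are run across 3 CPUs.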
+env = ss.pettingzoo_env_to_vec_env_v0(env)
+env.black_death = True
+env = ss.concat_vec_envs_v0(env, 3, num_cpus=3, base_class='stable_baselines3')
+model = PPO(MlpPolicy, env, tensorboard_log=f"/tmp/{experiment_name}", verbose=3, gamma=0.95,
+            n_steps=100, ent_coef=0.09, learning_rate=0.005, vf_coef=0.04, max_grad_norm=0.9,
+            gae_lambda=0.99, n_epochs=50, clip_range=0.3, batch_size=200)
+# wandb.watch(model.policy.action_net,log='all', log_freq = 1)
+# wandb.watch(model.policy.value_net, log='all', log_freq = 1)
+train_timesteps = 1000000
+model.learn(total_timesteps=train_timesteps)
+model.save(f"policy_flatland_{train_timesteps}")
+
+env = flatland_env.env(environment=rail_env, use_renderer=True)
+env_name = "flatland"
+monitor.FILE_PREFIX = env_name
+monitor.Monitor._after_step = env_generators._after_step
+env = monitor.Monitor(env, experiment_name, force=True)
+model = PPO.load(f"policy_flatland_{train_timesteps}")
+
+
+artifact = wandb.Artifact('model', type='model')
+artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
+run.log_artifact(artifact)
+
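+# Evaluate the trained policy in the turn-based (AEC) env; Monitor records a video.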
+env.reset(random_seed=seed)
+step = 0
+for agent in env.agent_iter():
+    obs, reward, done, info = env.last()
+    act = model.predict(obs, deterministic=True)[0] if not done else None
+    # act = 2
+    env.step(act)
+    step += 1
+    if step % 100 == 0:
+        print(act)
+        completion = env_generators.perc_completion(env)
+        print("Agents completed:", completion)
+env.close()
+completion = env_generators.perc_completion(env)
+print("Agents Completed:",completion)
+
+
+import fnmatch
+def find(pattern, path):
+    result = []
+    for root, dirs, files in os.walk(path):
+        for name in files:
+            if fnmatch.fnmatch(name, pattern):
+                result.append(os.path.join(root, name))
+    return result
+
+
+_video_file = '*0.mp4'
+_found_videos = find(_video_file, experiment_name)
+print(_found_videos)
+for _found_video in _found_videos:
+    wandb.log({_found_video: wandb.Video(_found_video, format="mp4")})
+run.join()
+
+
+
+
+
+# from pettingzoo.test.api_test import api_test
+# api_test(env)
+
+# env.reset(random_seed=seed)
+
+# action_dict = dict()
+# step = 0
+# for agent in env.agent_iter(max_iter=2500):
+#     if step == 433:
+#         print(step)
+#     obs, reward, done, info = env.last()
+#     action = 2 # controller.act(0)
+#     action_dict.update({agent: action})
+#     env.step(action)
+#     step += 1
+#     if step % 50 == 0:
+#         print(step)
+#     if step > 400:
+#         print(step)
+#     # env.render()
\ No newline at end of file
-- 
GitLab