diff --git a/examples/env_generators.py b/examples/env_generators.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8dd2e3ece319e95bd3b397bd694724858ee9543
--- /dev/null
+++ b/examples/env_generators.py
@@ -0,0 +1,236 @@
+import logging
+import random
+import numpy as np
+from typing import NamedTuple
+
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen, no_malfunction_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.schedule_generators import sparse_schedule_generator
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.core.grid.grid4_utils import get_new_position
+
+MalfunctionParameters = NamedTuple('MalfunctionParameters', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
+
+
+def get_shortest_path_action(env, handle):
+    """Return the greedy action (1=left, 2=forward, 3=right) that moves the agent
+    with the given handle along its shortest path to its target.
+    Returns None if the agent has already been removed from the grid."""
+    distance_map = env.distance_map.get()
+
+    agent = env.agents[handle]
+
+    if agent.status == RailAgentStatus.READY_TO_DEPART:
+        agent_virtual_position = agent.initial_position
+    elif agent.status == RailAgentStatus.ACTIVE:
+        agent_virtual_position = agent.position
+    elif agent.status == RailAgentStatus.DONE:
+        agent_virtual_position = agent.target
+    else:
+        return None
+
+    if agent.position:
+        possible_transitions = env.rail.get_transitions(
+            *agent.position, agent.direction)
+    else:
+        possible_transitions = env.rail.get_transitions(
+            *agent.initial_position, agent.direction)
+
+    num_transitions = np.count_nonzero(possible_transitions)                    
+    
+    min_distances = []
+    for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
+        if possible_transitions[direction]:
+            new_position = get_new_position(
+                agent_virtual_position, direction)
+            min_distances.append(
+                distance_map[handle, new_position[0],
+                            new_position[1], direction])
+        else:
+            min_distances.append(np.inf)
+
+    if num_transitions == 1:
+        # Only one way to go: move forward.
+        observation = [0, 1, 0]
+    else:
+        # Several possible transitions: pick the direction with the smallest
+        # remaining distance to the target.
+        observation = [0, 0, 0]
+        observation[int(np.argmin(min_distances))] = 1
+    return np.argmax(observation) + 1
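+
+# Minimal usage sketch (assumes `env` is a RailEnv that has already been reset):
+#
+#   actions = {handle: get_shortest_path_action(env, handle)
+#              for handle in range(env.get_num_agents())}
+#   obs, rewards, dones, info = env.step(actions)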
+
+
+def small_v0(random_seed, observation_builder, max_width=35, max_height=35):
+    random.seed(random_seed)
+    width = 25
+    height = 25
+    nr_trains = 5
+    max_num_cities = 4
+    grid_mode = False
+    max_rails_between_cities = 2
+    max_rails_in_city = 3
+
+    malfunction_rate = 0
+    malfunction_min_duration = 0
+    malfunction_max_duration = 0
+
+    rail_generator = sparse_rail_generator(max_num_cities=max_num_cities, seed=random_seed, grid_mode=grid_mode,
+                                           max_rails_between_cities=max_rails_between_cities,
+                                           max_rails_in_city=max_rails_in_city)
+
+    stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate,  # Rate of malfunction occurrence
+                                        min_duration=malfunction_min_duration,  # Minimal duration of malfunction
+                                        max_duration=malfunction_max_duration  # Max duration of malfunction
+                                        )
+    speed_ratio_map = None
+    schedule_generator = sparse_schedule_generator(speed_ratio_map)
+
+    malfunction_generator = no_malfunction_generator()
+    
+    while width <= max_width and height <= max_height:
+        try:
+            env = RailEnv(width=width, height=height, rail_generator=rail_generator,
+                          schedule_generator=schedule_generator, number_of_agents=nr_trains,
+                        #   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                          malfunction_generator_and_process_data=malfunction_generator,
+                          obs_builder_object=observation_builder, remove_agents_at_target=False)
+
+            print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
+                random_seed, width, height, max_num_cities, nr_trains, max_rails_between_cities,
+                max_rails_in_city, malfunction_rate, malfunction_min_duration, malfunction_max_duration
+            ))
+
+            return env
+        except ValueError as e:
+            logging.error(f"Error: {e}")
+            width += 5
+            height += 5
+            logging.info("Trying again with a larger env: (w, h) = (%d, %d)", width, height)
+    logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_width}, max_height={max_height}")
+    return None
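+
+# Minimal usage sketch for small_v0 (assumes a tree observation builder):
+#
+#   from flatland.envs.observations import TreeObsForRailEnv
+#   env = small_v0(random_seed=1, observation_builder=TreeObsForRailEnv(max_depth=2))
+#   obs, info = env.reset()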
+
+
+def random_sparse_env_small(random_seed, observation_builder, max_width=45, max_height=45):
+    random.seed(random_seed)
+    size = random.randint(0, 5)
+    width = 20 + size * 5
+    height = 20 + size * 5
+    nr_cities = 2 + size // 2 + random.randint(0, 2)
+    nr_trains = min(nr_cities * 5, 5 + random.randint(0, 5))  # , 10 + random.randint(0, 10))
+    max_rails_between_cities = 2
+    max_rails_in_cities = 3 + random.randint(0, size)
+    malfunction_rate = 30 + random.randint(0, 100)
+    malfunction_min_duration = 3 + random.randint(0, 7)
+    malfunction_max_duration = 20 + random.randint(0, 80)
+
+    rail_generator = sparse_rail_generator(max_num_cities=nr_cities, seed=random_seed, grid_mode=False,
+                                           max_rails_between_cities=max_rails_between_cities,
+                                           max_rails_in_city=max_rails_in_cities)
+
+    stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate,  # Rate of malfunction occurrence
+                                        min_duration=malfunction_min_duration,  # Minimal duration of malfunction
+                                        max_duration=malfunction_max_duration  # Max duration of malfunction
+                                        )
+
+    schedule_generator = sparse_schedule_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})
+
+    while width <= max_width and height <= max_height:
+        try:
+            env = RailEnv(width=width, height=height, rail_generator=rail_generator,
+                          schedule_generator=schedule_generator, number_of_agents=nr_trains,
+                        #   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                          malfunction_generator=ParamMalfunctionGen(stochastic_data),
+                          obs_builder_object=observation_builder, remove_agents_at_target=False)
+
+            print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
+                random_seed, width, height, nr_cities, nr_trains, max_rails_between_cities,
+                max_rails_in_cities, malfunction_rate, malfunction_min_duration, malfunction_max_duration
+            ))
+
+            return env
+        except ValueError as e:
+            logging.error(f"Error: {e}")
+            width += 5
+            height += 5
+            logging.info("Trying again with a larger env: (w, h) = (%d, %d)", width, height)
+    logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_width}, max_height={max_height}")
+    return None
+
+
+def sparse_env_small(random_seed, observation_builder):
+    width = 30  # Width of map
+    height = 30  # Height of map
+    nr_trains = 2  # Number of trains that have an assigned task in the env
+    cities_in_map = 3  # Number of cities where agents can start or end
+    seed = random_seed  # Random seed (taken from the function argument)
+    grid_distribution_of_cities = False  # Type of city distribution, if False cities are randomly placed
+    max_rails_between_cities = 2  # Max number of tracks allowed between cities. This is the number of entry points to a city
+    max_rail_in_cities = 6  # Max number of parallel tracks within a city, representing a realistic train station
+
+    rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
+                                        seed=seed,
+                                        grid_mode=grid_distribution_of_cities,
+                                        max_rails_between_cities=max_rails_between_cities,
+                                        max_rails_in_city=max_rail_in_cities,
+                                        )
+
+    # Different agent types (trains) with different speeds.
+    speed_ratio_map = {1.: 0.25,  # Fast passenger train
+                       1. / 2.: 0.25,  # Fast freight train
+                       1. / 3.: 0.25,  # Slow commuter train
+                       1. / 4.: 0.25}  # Slow freight train
+
+    # We can now initialize the schedule generator with the given speed profiles
+
+    schedule_generator = sparse_schedule_generator(speed_ratio_map)
+
+    # We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
+    # during an episode.
+
+    stochastic_data = MalfunctionParameters(malfunction_rate=1/10000,  # Rate of malfunction occurrence
+                                            min_duration=15,  # Minimal duration of malfunction
+                                            max_duration=50  # Max duration of malfunction
+                                            )
+
+    rail_env = RailEnv(width=width,
+                height=height,
+                rail_generator=rail_generator,
+                schedule_generator=schedule_generator,
+                number_of_agents=nr_trains,
+                obs_builder_object=observation_builder,
+                # malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                malfunction_generator=ParamMalfunctionGen(stochastic_data),
+                remove_agents_at_target=True)
+
+    return rail_env
+
+def _after_step(self, observation, reward, done, info):
+    """Replacement for gym Monitor._after_step that copes with flatland's
+    multi-agent `done` dict (the stock implementation expects a scalar)."""
+    if not self.enabled:
+        return done
+
+    if isinstance(done, dict):
+        _done_check = done['__all__']
+    else:
+        _done_check = done
+    if _done_check and self.env_semantics_autoreset:
+        # For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
+        self.reset_video_recorder()
+        self.episode_id += 1
+        self._flush()
+
+    # Record stats - Disabled as it causes error in multi-agent set up
+    # self.stats_recorder.after_step(observation, reward, done, info)
+    # Record video
+    self.video_recorder.capture_frame()
+
+    return done
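+
+# Sketch of how this replacement could be wired in (assumes the classic
+# gym.wrappers.monitor.Monitor wrapper that the example scripts import):
+#
+#   from gym.wrappers import monitor
+#   monitor.Monitor._after_step = _after_step
+#   env = monitor.Monitor(env, "recordings", force=True)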
+
+
+def perc_completion(env):
+    """Percentage of agents that have reached their target."""
+    tasks_finished = 0
+    if isinstance(env, RailEnv):
+        agent_data = env.agents
+    else:
+        agent_data = env.agents_data
+    for current_agent in agent_data:
+        if current_agent.status == RailAgentStatus.DONE:
+            tasks_finished += 1
+
+    return 100 * tasks_finished / max(1, len(agent_data))
\ No newline at end of file
diff --git a/examples/flatland_env.py b/examples/flatland_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..28ed712ab0fe943bad466445a1c3965f9a23e966
--- /dev/null
+++ b/examples/flatland_env.py
@@ -0,0 +1,353 @@
+import os
+import math
+import numpy as np
+import gym
+from gym.utils import seeding
+from pettingzoo import AECEnv
+from pettingzoo.utils import agent_selector
+from pettingzoo.utils import wrappers
+from gym.utils import EzPickle
+from pettingzoo.utils.conversions import to_parallel_wrapper
+from flatland.envs.rail_env import RailEnv
+from mava.wrappers.flatland import infer_observation_space, normalize_observation
+from functools import partial
+from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
+
+
+"""Adapted from 
+- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
+- https://github.com/instadeepai/Mava/blob/HEAD/mava/wrappers/flatland.py
+"""
+
+def parallel_wrapper_fn(env_fn):
+    def par_fn(**kwargs):
+        env = env_fn(**kwargs)
+        env = custom_parallel_wrapper(env)
+        return env
+    return par_fn
+
+def env(**kwargs):
+    env = raw_env(**kwargs)
+    # env = wrappers.AssertOutOfBoundsWrapper(env)
+    # env = wrappers.OrderEnforcingWrapper(env)
+    return env
+
+
+parallel_env = parallel_wrapper_fn(env)
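+
+# Usage sketch: both factories wrap an existing flatland RailEnv (here assumed to
+# come from examples/env_generators.py):
+#
+#   aec_env = env(environment=rail_env, use_renderer=False)           # AEC API
+#   par_env = parallel_env(environment=rail_env, use_renderer=False)  # parallel API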
+
+class custom_parallel_wrapper(to_parallel_wrapper):
+    
+    def step(self, actions):
+        rewards = {a: 0 for a in self.aec_env.agents}
+        dones = {}
+        infos = {}
+        observations = {}
+
+        for agent in self.aec_env.agents:
+            try:
+                assert agent == self.aec_env.agent_selection, f"expected agent {agent} got agent {self.aec_env.agent_selection}, agent order is nontrivial"
+            except Exception as e:
+                # print(e)
+                print(self.aec_env.dones.values())
+                raise e
+            obs, rew, done, info = self.aec_env.last()
+            self.aec_env.step(actions.get(agent, 0))
+            for other_agent in self.aec_env.agents:
+                rewards[other_agent] += self.aec_env.rewards[other_agent]
+
+        dones = dict(**self.aec_env.dones)
+        infos = dict(**self.aec_env.infos)
+        self.agents = self.aec_env.agents
+        observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}
+        return observations, rewards, dones, infos
+
+class raw_env(AECEnv, gym.Env):
+
+    metadata = {'render.modes': ['human', "rgb_array"], 'name': "flatland_pettingzoo",
+            'video.frames_per_second': 10,
+            'semantics.autoreset': False }
+
+    def __init__(self, environment=False, preprocessor=False, agent_info=False, use_renderer=False, *args, **kwargs):
+        # EzPickle.__init__(self, *args, **kwargs)
+        self._environment = environment
+        self.use_renderer = use_renderer
+        self.renderer = None
+        if self.use_renderer:
+            self.initialize_renderer()
+
+        n_agents = self.num_agents
+        self._agents = [get_agent_keys(i) for i in range(n_agents)]
+        self._possible_agents = self.agents[:]
+        self._reset_next_step = True
+
+        self._agent_selector = agent_selector(self.agents)
+
+        self.num_actions = 5
+
+        self.action_spaces = {
+            agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
+        }              
+
+        self.seed()
+        # A preprocessor is only used for observation builders other than the
+        # global obs builder; tree obs builders fall back to the default
+        # normalization preprocessor if none is supplied.
+        self.preprocessor = self._obtain_preprocessor(preprocessor)
+
+        self._include_agent_info = agent_info
+
+        # observation space:
+        # flatland does not define an observation space per agent, so we infer
+        # one from an example observation after preprocessing. All agents are
+        # identical and share the same observation space.
+        obs, _ = self._environment.reset(regenerate_rail = False, regenerate_schedule = False)
+        obs = self.preprocessor(obs)
+        self.observation_spaces = {
+            i: infer_observation_space(ob) for i, ob in obs.items()
+        }
+
+    
+    @property
+    def environment(self) -> RailEnv:
+        """Returns the wrapped environment."""
+        return self._environment
+
+    @property
+    def dones(self):
+        dones = self._environment.dones
+        # remove_all = dones.pop("__all__", None)
+        return {get_agent_keys(key): value for key, value in dones.items()}    
+    
+    @property
+    def obs_builder(self):    
+        return self._environment.obs_builder    
+
+    @property
+    def width(self):    
+        return self._environment.width  
+
+    @property
+    def height(self):    
+        return self._environment.height  
+
+    @property
+    def agents_data(self):
+        """Rail Env Agents data."""
+        return self._environment.agents
+
+    @property
+    def num_agents(self) -> int:
+        """Returns the number of trains/agents in the flatland environment"""
+        return int(self._environment.number_of_agents)
+
+    # def __getattr__(self, name):
+    #     """Expose any other attributes of the underlying environment."""
+    #     return getattr(self._environment, name)
+   
+    @property
+    def agents(self):
+        return self._agents
+
+    @property
+    def possible_agents(self):
+        return self._possible_agents
+
+    def env_done(self):
+        return self._environment.dones["__all__"] or not self.agents
+    
+    def observe(self, agent):
+        return self.obs.get(agent)
+    
+    def last(self, observe=True):
+        '''
+        Returns observation, reward, done, info for the current agent (specified by self.agent_selection).
+        '''
+        agent = self.agent_selection
+        observation = self.observe(agent) if observe else None
+        return observation, self.rewards[agent], self.dones[agent], self.infos[agent]
+    
+    def seed(self, seed: int = None) -> None:
+        self._environment._seed(seed)
+
+    def state(self):
+        '''
+        Returns an observation of the global environment
+        '''
+        return None
+
+    def _clear_rewards(self):
+        '''
+        clears all items in .rewards
+        '''
+        # pass
+        for agent in self.rewards:
+            self.rewards[agent] = 0
+    
+    def reset(self, *args, **kwargs):
+        self._reset_next_step = False
+        self._agents = self.possible_agents[:]
+        if self.use_renderer:
+            if self.renderer: #TODO: Errors with RLLib with renderer as None.
+                self.renderer.reset()
+        obs, info = self._environment.reset(*args, **kwargs)
+        observations = self._collate_obs_and_info(obs, info)
+        self._agent_selector.reinit(self.agents)
+        self.agent_selection = self._agent_selector.next()
+        self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
+        self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents]))
+        self.action_dict = {get_agent_handle(i):0 for i in self.possible_agents}
+
+        return observations
+
+    def step(self, action):
+
+        if self.env_done():
+            self._agents = []
+            self._reset_next_step = True
+            return self.last()
+        
+        agent = self.agent_selection
+        self.action_dict[get_agent_handle(agent)] = action
+
+        if self.dones[agent]:
+            # Disabled.. In case we want to remove agents once done
+            # if self.remove_agents:
+            #     self.agents.remove(agent)
+            if self._agent_selector.is_last():
+                observations, rewards, dones, infos = self._environment.step(self.action_dict)
+                self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
+                if observations:
+                    observations = self._collate_obs_and_info(observations, infos)
+                self._accumulate_rewards()
+                obs, cumulative_reward, done, info = self.last()
+                self.agent_selection = self._agent_selector.next()
+
+            else:
+                self._clear_rewards()
+                obs, cumulative_reward, done, info = self.last()
+                self.agent_selection = self._agent_selector.next()
+
+            return obs, cumulative_reward, done, info
+
+        if self._agent_selector.is_last():
+            observations, rewards, dones, infos = self._environment.step(self.action_dict)
+            self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
+            if observations:
+                observations = self._collate_obs_and_info(observations, infos)
+    
+        else:
+            self._clear_rewards()
+        
+        # self._cumulative_rewards[agent] = 0
+        self._accumulate_rewards()
+
+        obs, cumulative_reward, done, info = self.last()
+        
+        self.agent_selection = self._agent_selector.next()
+
+        return obs, cumulative_reward, done, info
+
+
+    # Collate agent info and observation into a tuple, making each agent's
+    # observation a tuple of the env observation and the agent info.
+    def _collate_obs_and_info(self, observes, info):
+        observations = {}
+        infos = {}
+        observes = self.preprocessor(observes)
+        for agent, obs in observes.items():
+            all_infos = {k: info[k][get_agent_handle(agent)] for k in info.keys()}
+            agent_info = np.array(
+                list(all_infos.values()), dtype=np.float32
+            )
+            infos[agent] = all_infos
+            obs = (obs, agent_info) if self._include_agent_info else obs
+            observations[agent] = obs
+
+        self.infos = infos
+        self.obs = observations
+        return observations   
+
+
+    def render(self, mode='human'):
+        """
+        This method provides the option to render the
+        environment's behavior to a window which should be
+        readable to the human eye if mode is set to 'human'.
+        """
+        if not self.use_renderer:
+            return
+
+        if not self.renderer:
+            self.initialize_renderer(mode=mode)
+
+        return self.update_renderer(mode=mode)
+
+    def initialize_renderer(self, mode="human"):
+        # Initiate the renderer
+        from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+        self.renderer = RenderTool(self.environment, gl="PGL",  # gl="TKPILSVG",
+                                       agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
+                                       show_debug=False,
+                                       screen_height=600,  # Adjust these parameters to fit your resolution
+                                       screen_width=800)  # Adjust these parameters to fit your resolution
+        self.renderer.show = False
+
+    def update_renderer(self, mode='human'):
+        image = self.renderer.render_env(show=False, show_observations=False, show_predictions=False,
+                                             return_image=True)
+        return image[:,:,:3]
+
+    def set_renderer(self, renderer):
+        self.use_renderer = renderer
+        if self.use_renderer:
+            self.initialize_renderer(mode=self.use_renderer)
+
+    def close(self):
+        # self._environment.close()
+        if self.renderer:
+            try:
+                if self.renderer.show:
+                    self.renderer.close_window()
+            except Exception as e:
+                print("Could Not close window due to:",e)
+            self.renderer = None
+
+    def _obtain_preprocessor(
+        self, preprocessor):
+        """Obtains the actual preprocessor to be used based on the supplied
+        preprocessor and the env's obs_builder object"""
+        if not isinstance(self.obs_builder, GlobalObsForRailEnv):
+            _preprocessor = preprocessor if preprocessor else lambda x: x
+            if isinstance(self.obs_builder, TreeObsForRailEnv):
+                _preprocessor = (
+                    partial(
+                        normalize_observation, tree_depth=self.obs_builder.max_depth
+                    )
+                    if not preprocessor
+                    else preprocessor
+                )
+            assert _preprocessor is not None
+        else:
+            def _preprocessor(x):
+                return x
+
+        def returned_preprocessor(obs):
+            temp_obs = {}
+            for agent_id, ob in obs.items():
+                temp_obs[get_agent_keys(agent_id)] = _preprocessor(ob)
+            return temp_obs
+
+        return returned_preprocessor
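+
+    # A custom `preprocessor` is any callable mapping one agent's raw observation
+    # to the array fed to the learner; it is only honoured for non-global
+    # observation builders. Hypothetical sketch for a tree observation:
+    #
+    #   def clip_tree_obs(ob):
+    #       return np.clip(normalize_observation(ob, tree_depth=2), -1.0, 1.0)
+    #
+    #   raw_env(environment=rail_env, preprocessor=clip_tree_obs)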
+
+# Utility functions   
+def convert_np_type(dtype, value):
+    return np.dtype(dtype).type(value) 
+
+def get_agent_handle(id):
+    """Obtain an agent's integer handle given its id."""
+    return int(id)
+
+def get_agent_keys(id):
+    """Obtain an agent's string key given its id."""
+    return str(id)
\ No newline at end of file
diff --git a/examples/flatland_pettingzoo_rllib.py b/examples/flatland_pettingzoo_rllib.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dd6f733b23771df4f6ee68e3fcc03f20bc8bd02
--- /dev/null
+++ b/examples/flatland_pettingzoo_rllib.py
@@ -0,0 +1,84 @@
+from ray import tune
+from ray.rllib.models import ModelCatalog
+from ray.tune.registry import register_env
+# from ray.rllib.utils import try_import_tf
+from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
+import supersuit as ss
+import numpy as np
+
+import flatland_env
+import env_generators
+
+from gym.wrappers import monitor
+from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+
+# First of all we import the Flatland rail environment
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+import wandb
+
+# Tree observation builder with a shortest-path predictor (swap in GlobalObsForRailEnv to try global observations)
+observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
+seed = 10
+np.random.seed(seed)
+wandb_log = False
+experiment_name = "flatland_pettingzoo"
+rail_env = env_generators.small_v0(seed, observation_builder)
+
+def env_creator(args):
+    env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
+    # env = ss.dtype_v0(env, 'float32')
+    # env = ss.flatten_v0(env)
+    return env
+
+
+if __name__ == "__main__":
+    env_name = "flatland_pettyzoo"
+
+    register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
+
+    test_env = ParallelPettingZooEnv(env_creator({}))
+    obs_space = test_env.observation_space
+    act_space = test_env.action_space
+
+
+    def gen_policy(i):
+        config = {
+            "gamma": 0.99,
+        }
+        return (None, obs_space, act_space, config)
+
+    policies = {"policy_0": gen_policy(0)}
+
+    policy_ids = list(policies.keys())
+
+    tune.run(
+        "PPO",
+        name="PPO",
+        stop={"timesteps_total": 5000000},
+        checkpoint_freq=10,
+        local_dir="~/ray_results/"+env_name,
+        config={
+            # Environment specific
+            "env": env_name,
+            # https://github.com/ray-project/ray/issues/10761
+            "no_done_at_end": True,
+            # "soft_horizon" : True,
+            "num_gpus": 0,
+            "num_workers": 2,
+            "num_envs_per_worker": 1,
+            "compress_observations": False,
+            "batch_mode": 'truncate_episodes',
+            "clip_rewards": False,
+            "vf_clip_param": 500.0,
+            "entropy_coeff": 0.01,
+            # effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
+            # see https://github.com/ray-project/ray/issues/4628
+            "train_batch_size": 1000,  # 5000
+            "rollout_fragment_length": 50,  # 100
+            "sgd_minibatch_size": 100,  # 500
+            "vf_share_layers": False
+
+       },
+    )
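+
+    # Evaluation sketch (assumes ray 1.x, where PPOTrainer and compute_action
+    # exist; the checkpoint path and `obs` below are placeholders):
+    #
+    #   from ray.rllib.agents.ppo import PPOTrainer
+    #   trainer = PPOTrainer(env=env_name, config={
+    #       "multiagent": {"policies": policies,
+    #                      "policy_mapping_fn": lambda agent_id: policy_ids[0]}})
+    #   trainer.restore("<path/to/checkpoint>")
+    #   action = trainer.compute_action(obs, policy_id="policy_0")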
diff --git a/examples/flatland_pettingzoo_stable_baselines.py b/examples/flatland_pettingzoo_stable_baselines.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5f5ad29b894fedcdb0cf0ff34b343d02af938e3
--- /dev/null
+++ b/examples/flatland_pettingzoo_stable_baselines.py
@@ -0,0 +1,154 @@
+
+from mava.wrappers.flatland import get_agent_handle, get_agent_id
+import numpy as np
+import os
+import PIL
+import shutil
+
+from stable_baselines3.ppo import MlpPolicy
+from stable_baselines3 import PPO
+from stable_baselines3.dqn.dqn import DQN
+import supersuit as ss
+
+import flatland_env
+import env_generators
+
+from gym.wrappers import monitor
+from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+
+# First of all we import the Flatland rail environment
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+import wandb
+
+"""
+https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/tutorials/13_lines.py
+"""
+
+# Custom observation builder without predictor
+# observation_builder = GlobalObsForRailEnv()
+
+# Custom observation builder with predictor
+observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
+seed = 10
+np.random.seed(seed)
+wandb_log = False
+experiment_name = "flatland_pettingzoo"
+
+try:
+    if os.path.isdir(experiment_name):
+        shutil.rmtree(experiment_name)
+    os.mkdir(experiment_name)
+except OSError as e:
+    print ("Error: %s - %s." % (e.filename, e.strerror))
+
+# rail_env = env_generators.sparse_env_small(seed, observation_builder)
+rail_env = env_generators.small_v0(seed, observation_builder)
+
+env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
+# env = flatland_env.env(environment=rail_env, use_renderer=False)
+
+if wandb_log:
+    run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True, config={}, name=experiment_name, save_code=True)
+
+env_steps = 1000 # 2 * env.width * env.height  # Code uses 1.5 to calculate max_steps
+rollout_fragment_length = 50
+env = ss.pettingzoo_env_to_vec_env_v0(env)
+# env.black_death = True
+env = ss.concat_vec_envs_v0(env, 1, num_cpus=1, base_class='stable_baselines3')
+
+model = PPO(MlpPolicy, env, tensorboard_log=f"/tmp/{experiment_name}", verbose=3, gamma=0.95,
+            n_steps=rollout_fragment_length, ent_coef=0.01, learning_rate=5e-5, vf_coef=1,
+            max_grad_norm=0.9, gae_lambda=1.0, n_epochs=30, clip_range=0.3, batch_size=150, seed=seed)
+# wandb.watch(model.policy.action_net,log='all', log_freq = 1)
+# wandb.watch(model.policy.value_net, log='all', log_freq = 1)
+train_timesteps = 100000
+model.learn(total_timesteps=train_timesteps)
+model.save(f"policy_flatland_{train_timesteps}")
+
+model = PPO.load(f"policy_flatland_{train_timesteps}")
+
+env = flatland_env.env(environment=rail_env, use_renderer=True)
+
+if wandb_log:
+    artifact = wandb.Artifact('model', type='model')
+    artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
+    run.log_artifact(artifact)
+
+
+# Model Inference
+
+seed = 100
+env.reset(random_seed=seed)
+step = 0
+ep_no = 0
+frame_list = []
+while ep_no < 1:
+    for agent in env.agent_iter():
+        obs, reward, done, info = env.last()
+        act = model.predict(obs, deterministic=True)[0] if not done else None
+        env.step(act)
+        frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
+        step += 1
+        if step % 100 == 0:
+            print(f"env step:{step} and action taken:{act}")
+            completion = env_generators.perc_completion(env)
+            print("Agents Completed:", completion)
+
+    completion = env_generators.perc_completion(env)
+    print("Final Agents Completed:", completion)
+    ep_no += 1
+    frame_list[0].save(f"{experiment_name}{os.sep}pettingzoo_out_{ep_no}.gif", save_all=True, append_images=frame_list[1:], duration=3, loop=0)
+    frame_list = []
+    env.close()
+    env.reset(random_seed=seed+ep_no)
+
+
+import fnmatch
+def find(pattern, path):
+    result = []
+    for root, dirs, files in os.walk(path):
+        for name in files:
+            if fnmatch.fnmatch(name, pattern):
+                result.append(os.path.join(root, name))
+    return result
+
+if wandb_log:
+    extn = "gif"
+    _video_file = f'*.{extn}'
+    _found_videos = find(_video_file, experiment_name)
+    print(_found_videos)
+    for _found_video in _found_videos:
+        wandb.log({_found_video:wandb.Video(_found_video, format=extn)})
+    run.join()
+
+
+# from pettingzoo.test.api_test import api_test
+# api_test(env)
+
+# env.reset(random_seed=seed)
+
+# action_dict = dict()
+# step = 0
+# for agent in env.agent_iter(max_iter=2500):
+#     if step == 433:
+#         print(step)
+#     obs, reward, done, info = env.last()
+#     action = 2 # controller.act(0)
+#     action_dict.update({agent: action})
+#     env.step(action)
+#     step += 1
+#     if step % 50 == 0:
+#         print(step)
+#     if step > 400:
+#         print(step)
+#     # env.render()
\ No newline at end of file
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 93414562b79e3c0d5e1a77e42b967dc0ea4028fe..c1f6840e97295dd06b330b0c9ba8163ba47b2d45 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -23,3 +23,4 @@ networkx
 ipycanvas
 graphviz
 imageio
+id-mava[flatland]
diff --git a/tests/test_pettingzoo_interface.py b/tests/test_pettingzoo_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b61ce7e93a6564b172aaa0980adf10fe6f8198
--- /dev/null
+++ b/tests/test_pettingzoo_interface.py
@@ -0,0 +1,119 @@
+from mava.wrappers.flatland import get_agent_handle, get_agent_id
+import numpy as np
+import os
+import PIL
+import shutil
+
+from examples import flatland_env
+from examples import env_generators
+
+from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+
+# First of all we import the Flatland rail environment
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+
+
+def test_petting_zoo_interface_env():
+
+    # Custom observation builder without predictor
+    # observation_builder = GlobalObsForRailEnv()
+
+    # Custom observation builder with predictor
+    observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
+    seed = 11
+    save = False
+    np.random.seed(seed)
+    experiment_name = "flatland_pettingzoo"
+    total_episodes = 1
+
+    if save:
+        try:
+            if os.path.isdir(experiment_name):
+                shutil.rmtree(experiment_name)
+            os.mkdir(experiment_name)
+        except OSError as e:
+            print ("Error: %s - %s." % (e.filename, e.strerror))
+
+    # rail_env = env_generators.sparse_env_small(seed, observation_builder)
+    rail_env = env_generators.small_v0(seed, observation_builder)
+
+    rail_env.reset(random_seed=seed)
+
+    env_renderer = RenderTool(rail_env,
+                            agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
+                            show_debug=False,
+                            screen_height=600,  # Adjust these parameters to fit your resolution
+                            screen_width=800)  # Adjust these parameters to fit your resolution
+
+    dones = {}
+    dones['__all__'] = False
+
+    step = 0
+    ep_no = 0
+    frame_list = []
+    all_actions_env = []
+    all_actions_pettingzoo_env = []
+    # while not dones['__all__']:
+    while ep_no < total_episodes:
+        action_dict = {}
+        # Choose an action for each agent
+        for a in range(rail_env.get_num_agents()):
+            action = env_generators.get_shortest_path_action(rail_env, a)
+            all_actions_env.append(action)
+            action_dict.update({a: action})
+            step += 1
+
+        # Do the environment step
+        observations, rewards, dones, information = rail_env.step(action_dict)
+        image = env_renderer.render_env(show=False, show_observations=False, show_predictions=False,
+                                            return_image=True)
+        frame_list.append(PIL.Image.fromarray(image[:,:,:3]))
+
+        if dones['__all__']:
+            completion = env_generators.perc_completion(rail_env)
+            print("Final Agents Completed:",completion)
+            ep_no += 1
+            if save:
+                frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True, append_images=frame_list[1:], duration=3, loop=0)       
+            frame_list = []
+            env_renderer = RenderTool(rail_env,
+                            agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
+                            show_debug=False,
+                            screen_height=600,  # Adjust these parameters to fit your resolution
+                            screen_width=800)  # Adjust these parameters to fit your resolution
+            rail_env.reset(random_seed=seed+ep_no)
+
+    env = flatland_env.env(environment=rail_env, use_renderer=True)
+    seed = 11
+    env.reset(random_seed=seed)
+    step = 0
+    ep_no = 0
+    frame_list = []
+    while ep_no < total_episodes:
+        for agent in env.agent_iter():
+            obs, reward, done, info = env.last()
+            act = env_generators.get_shortest_path_action(env.environment, get_agent_handle(agent))
+            all_actions_pettingzoo_env.append(act)
+            env.step(act)
+            frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
+            step+=1
+
+        completion = env_generators.perc_completion(env)
+        print("Final Agents Completed:",completion)
+        ep_no += 1
+        if save:
+            frame_list[0].save(f"{experiment_name}{os.sep}pettingzoo_out_{ep_no}.gif", save_all=True, append_images=frame_list[1:], duration=3, loop=0)
+        frame_list = []
+        env.close()
+        env.reset(random_seed=seed+ep_no)
+
+        assert sorted(all_actions_pettingzoo_env) == sorted(all_actions_env), "actions do not match for shortest path"
+
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+    sys.exit(pytest.main(["-sv", __file__]))
\ No newline at end of file