diff --git a/flatland/contrib/interface/flatland_env.py b/flatland/contrib/interface/flatland_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..584621a6313e7ecda1281e06ad44a6669164ef85
--- /dev/null
+++ b/flatland/contrib/interface/flatland_env.py
@@ -0,0 +1,353 @@
+import os
+import math
+import numpy as np
+import gym
+from gym.utils import seeding
+from pettingzoo import AECEnv
+from pettingzoo.utils import agent_selector
+from pettingzoo.utils import wrappers
+from gym.utils import EzPickle
+from pettingzoo.utils.conversions import to_parallel_wrapper
+from flatland.envs.rail_env import RailEnv
+from mava.wrappers.flatland import infer_observation_space, normalize_observation
+from functools import partial
+from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
+
+
+"""Adapted from 
+- https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/pettingzoo/butterfly/pistonball/pistonball.py
+- https://github.com/instadeepai/Mava/blob/HEAD/mava/wrappers/flatland.py
+"""
+
+def parallel_wrapper_fn(env_fn):
+    def par_fn(**kwargs):
+        env = env_fn(**kwargs)
+        env = custom_parallel_wrapper(env)
+        return env
+    return par_fn
+
+def env(**kwargs):
+    env = raw_env(**kwargs)
+    # env = wrappers.AssertOutOfBoundsWrapper(env)
+    # env = wrappers.OrderEnforcingWrapper(env)
+    return env
+
+
+parallel_env = parallel_wrapper_fn(env)
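+
+# Illustrative usage sketch (an assumption-laden example, not part of the module API):
+# `rail_env` below stands for any flatland RailEnv, here built via the contrib
+# env_generators helper, and TreeObsForRailEnv is just one possible observation builder.
+#
+#     from flatland.contrib.interface import flatland_env
+#     from flatland.contrib.utils import env_generators
+#     from flatland.envs.observations import TreeObsForRailEnv
+#
+#     rail_env = env_generators.small_v0(10, TreeObsForRailEnv(max_depth=2))
+#     aec_env = flatland_env.env(environment=rail_env, use_renderer=False)           # turn-based AEC API
+#     par_env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)  # parallel API
+#     observations = par_env.reset()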
+
+class custom_parallel_wrapper(to_parallel_wrapper):
+    
+    def step(self, actions):
+        rewards = {a: 0 for a in self.aec_env.agents}
+        dones = {}
+        infos = {}
+        observations = {}
+
+        for agent in self.aec_env.agents:
+            try:
+                assert agent == self.aec_env.agent_selection, f"expected agent {agent} got agent {self.aec_env.agent_selection}, agent order is nontrivial"
+            except Exception as e:
+                # print(e)
+                print(self.aec_env.dones.values())
+                raise e
+            obs, rew, done, info = self.aec_env.last()
+            self.aec_env.step(actions.get(agent, 0))
+            # accumulate rewards for all agents; use a separate loop variable to
+            # avoid shadowing the outer `agent`
+            for agent_ in self.aec_env.agents:
+                rewards[agent_] += self.aec_env.rewards[agent_]
+
+        dones = dict(**self.aec_env.dones)
+        infos = dict(**self.aec_env.infos)
+        self.agents = self.aec_env.agents
+        observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}
+        return observations, rewards, dones, infos
+
+class raw_env(AECEnv, gym.Env):
+
+    metadata = {'render.modes': ['human', "rgb_array"], 'name': "flatland_pettingzoo",
+            'video.frames_per_second': 10,
+            'semantics.autoreset': False }
+
+    def __init__(self, environment=False, preprocessor=False, agent_info=False, use_renderer=False, *args, **kwargs):
+        # EzPickle.__init__(self, *args, **kwargs)
+        self._environment = environment
+        self.use_renderer = use_renderer
+        self.renderer = None
+        if self.use_renderer:
+            self.initialize_renderer()
+
+        n_agents = self.num_agents
+        self._agents = [get_agent_keys(i) for i in range(n_agents)]
+        self._possible_agents = self.agents[:]
+        self._reset_next_step = True
+
+        self._agent_selector = agent_selector(self.agents)
+
+        self.num_actions = 5
+
+        self.action_spaces = {
+            agent: gym.spaces.Discrete(self.num_actions) for agent in self.possible_agents
+        }              
+
+        self.seed()
+        # A preprocessor is only needed for observation builders other than
+        # GlobalObsForRailEnv; TreeObsForRailEnv falls back to the default
+        # normalizing preprocessor if none is supplied.
+        self.preprocessor = self._obtain_preprocessor(preprocessor)
+
+        self._include_agent_info = agent_info
+
+        # observation space:
+        # Flatland does not define an observation space per agent, so we infer
+        # it from a sample observation returned by reset(). All agents are
+        # identical and therefore share the same observation space.
+        obs, _ = self._environment.reset(regenerate_rail=False, regenerate_schedule=False)
+        obs = self.preprocessor(obs)
+        self.observation_spaces = {
+            i: infer_observation_space(ob) for i, ob in obs.items()
+        }
+
+    
+    @property
+    def environment(self) -> RailEnv:
+        """Returns the wrapped environment."""
+        return self._environment
+
+    @property
+    def dones(self):
+        dones = self._environment.dones
+        # remove_all = dones.pop("__all__", None)
+        return {get_agent_keys(key): value for key, value in dones.items()}    
+    
+    @property
+    def obs_builder(self):    
+        return self._environment.obs_builder    
+
+    @property
+    def width(self):    
+        return self._environment.width  
+
+    @property
+    def height(self):    
+        return self._environment.height  
+
+    @property
+    def agents_data(self):
+        """Rail Env Agents data."""
+        return self._environment.agents
+
+    @property
+    def num_agents(self) -> int:
+        """Returns the number of trains/agents in the flatland environment"""
+        return int(self._environment.number_of_agents)
+
+    # def __getattr__(self, name):
+    #     """Expose any other attributes of the underlying environment."""
+    #     return getattr(self._environment, name)
+   
+    @property
+    def agents(self):
+        return self._agents
+
+    @property
+    def possible_agents(self):
+        return self._possible_agents
+
+    def env_done(self):
+        return self._environment.dones["__all__"] or not self.agents
+    
+    def observe(self,agent):
+        return self.obs.get(agent)
+    
+    def last(self, observe=True):
+        '''
+        Returns observation, reward, done, info for the current agent (specified by self.agent_selection).
+        '''
+        agent = self.agent_selection
+        observation = self.observe(agent) if observe else None
+        return observation, self.rewards.get(agent), self.dones.get(agent), self.infos.get(agent)
+    
+    def seed(self, seed: int = None) -> None:
+        self._environment._seed(seed)
+
+    def state(self):
+        '''
+        A global state observation is not implemented for this environment; always returns None.
+        '''
+        return None
+
+    def _clear_rewards(self):
+        '''
+        Resets every entry in self.rewards to 0.
+        '''
+        for agent in self.rewards:
+            self.rewards[agent] = 0
+    
+    def reset(self, *args, **kwargs):
+        self._reset_next_step = False
+        self._agents = self.possible_agents[:]
+        if self.use_renderer:
+            if self.renderer:  # TODO: RLlib errors out if the renderer is None.
+                self.renderer.reset()
+        obs, info = self._environment.reset(*args, **kwargs)
+        observations = self._collate_obs_and_info(obs, info)
+        self._agent_selector.reinit(self.agents)
+        self.agent_selection = self._agent_selector.next()
+        self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
+        self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents]))
+        self.action_dict = {get_agent_handle(i):0 for i in self.possible_agents}
+
+        return observations
+
+    def step(self, action):
+
+        if self.env_done():
+            self._agents = []
+            self._reset_next_step = True
+            return self.last()
+        
+        agent = self.agent_selection
+        self.action_dict[get_agent_handle(agent)] = action
+
+        if self.dones[agent]:
+            # Disabled.. In case we want to remove agents once done
+            # if self.remove_agents:
+            #     self.agents.remove(agent)
+            if self._agent_selector.is_last():
+                observations, rewards, dones, infos = self._environment.step(self.action_dict)
+                self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
+                if observations:
+                    observations = self._collate_obs_and_info(observations, infos)
+                self._accumulate_rewards()
+                obs, cumulative_reward, done, info = self.last()
+                self.agent_selection = self._agent_selector.next()
+
+            else:
+                self._clear_rewards()
+                obs, cumulative_reward, done, info = self.last()
+                self.agent_selection = self._agent_selector.next()
+
+            return obs, cumulative_reward, done, info
+
+        if self._agent_selector.is_last():
+            observations, rewards, dones, infos = self._environment.step(self.action_dict)
+            self.rewards = {get_agent_keys(key): value for key, value in rewards.items()}
+            if observations:
+                observations = self._collate_obs_and_info(observations, infos)
+    
+        else:
+            self._clear_rewards()
+        
+        # self._cumulative_rewards[agent] = 0
+        self._accumulate_rewards()
+
+        obs, cumulative_reward, done, info = self.last()
+        
+        self.agent_selection = self._agent_selector.next()
+
+        return obs, cumulative_reward, done, info
+
+
+    # Collate agent info and observation into a tuple, so that each agent's
+    # observation becomes (env_observation, agent_info) when agent_info is enabled.
+    def _collate_obs_and_info(self, observes, info):
+        observations = {}
+        infos = {}
+        observes = self.preprocessor(observes)
+        for agent, obs in observes.items():
+            all_infos = {k: info[k][get_agent_handle(agent)] for k in info.keys()}
+            agent_info = np.array(
+                list(all_infos.values()), dtype=np.float32
+            )
+            infos[agent] = all_infos
+            obs = (obs, agent_info) if self._include_agent_info else obs
+            observations[agent] = obs
+
+        self.infos = infos
+        self.obs = observations
+        return observations   
+
+
+    def render(self, mode='human'):
+        """
+        This method provides the option to render the environment's behavior
+        as an RGB image, which should be readable to the human eye if mode is
+        set to 'human'.
+        """
+        if not self.use_renderer:
+            return
+
+        if not self.renderer:
+            self.initialize_renderer(mode=mode)
+
+        return self.update_renderer(mode=mode)
+
+    def initialize_renderer(self, mode="human"):
+        # Initiate the renderer
+        from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+        self.renderer = RenderTool(self.environment, gl="PGL",  # gl="TKPILSVG",
+                                       agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
+                                       show_debug=False,
+                                       screen_height=600,  # Adjust these parameters to fit your resolution
+                                       screen_width=800)  # Adjust these parameters to fit your resolution
+        self.renderer.show = False
+
+    def update_renderer(self, mode='human'):
+        image = self.renderer.render_env(show=False, show_observations=False, show_predictions=False,
+                                             return_image=True)
+        return image[:,:,:3]
+
+    def set_renderer(self, renderer):
+        self.use_renderer = renderer
+        if self.use_renderer:
+            self.initialize_renderer(mode=self.use_renderer)
+
+    def close(self):
+        # self._environment.close()
+        if self.renderer:
+            try:
+                if self.renderer.show:
+                    self.renderer.close_window()
+            except Exception as e:
+                print("Could not close window due to:", e)
+            self.renderer = None
+
+    def _obtain_preprocessor(
+        self, preprocessor):
+        """Obtains the actual preprocessor to be used based on the supplied
+        preprocessor and the env's obs_builder object"""
+        if not isinstance(self.obs_builder, GlobalObsForRailEnv):
+            _preprocessor = preprocessor if preprocessor else lambda x: x
+            if isinstance(self.obs_builder, TreeObsForRailEnv):
+                _preprocessor = (
+                    partial(
+                        normalize_observation, tree_depth=self.obs_builder.max_depth
+                    )
+                    if not preprocessor
+                    else preprocessor
+                )
+            assert _preprocessor is not None
+        else:
+            def _preprocessor(x):
+                return x
+
+        def returned_preprocessor(obs):
+            temp_obs = {}
+            for agent_id, ob in obs.items():
+                temp_obs[get_agent_keys(agent_id)] = _preprocessor(ob)
+            return temp_obs
+
+        return returned_preprocessor
+
+# Utility functions   
+def convert_np_type(dtype, value):
+    return np.dtype(dtype).type(value) 
+
+def get_agent_handle(id):
+    """Obtain an agent's integer handle, given its id."""
+    return int(id)
+
+def get_agent_keys(id):
+    """Obtain an agent's string key, given its id/handle."""
+    return str(id)
\ No newline at end of file
diff --git a/flatland/contrib/requirements_training.txt b/flatland/contrib/requirements_training.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d9cc58ceea0db826685dc20907a36b7eaa3aabfd
--- /dev/null
+++ b/flatland/contrib/requirements_training.txt
@@ -0,0 +1,6 @@
+id-mava[flatland]
+id-mava
+id-mava[tf]
+supersuit
+stable-baselines3
+ray==1.5.2
\ No newline at end of file
diff --git a/flatland/contrib/training/flatland_pettingzoo_rllib.py b/flatland/contrib/training/flatland_pettingzoo_rllib.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb2a07681a973a79abfed4affbd4f3fb9dd256c
--- /dev/null
+++ b/flatland/contrib/training/flatland_pettingzoo_rllib.py
@@ -0,0 +1,78 @@
+from ray import tune
+from ray.tune.registry import register_env
+# from ray.rllib.utils import try_import_tf
+from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
+import numpy as np
+
+from flatland.contrib.interface import flatland_env
+from flatland.contrib.utils import env_generators
+
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+
+
+# Custom observation builder with predictor
+observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
+seed = 10
+np.random.seed(seed)
+wandb_log = False
+experiment_name = "flatland_pettingzoo"
+rail_env = env_generators.small_v0(seed, observation_builder)
+
+# __sphinx_doc_begin__
+
+
+def env_creator(args):
+    env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
+    return env
+
+
+if __name__ == "__main__":
+    env_name = "flatland_pettingzoo"
+
+    register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
+
+    test_env = ParallelPettingZooEnv(env_creator({}))
+    obs_space = test_env.observation_space
+    act_space = test_env.action_space
+
+    def gen_policy(i):
+        config = {
+            "gamma": 0.99,
+        }
+        return (None, obs_space, act_space, config)
+
+    policies = {"policy_0": gen_policy(0)}
+
+    policy_ids = list(policies.keys())
+
+    tune.run(
+        "PPO",
+        name="PPO",
+        stop={"timesteps_total": 5000000},
+        checkpoint_freq=10,
+        local_dir="~/ray_results/"+env_name,
+        config={
+            # Environment specific
+            "env": env_name,
+            # https://github.com/ray-project/ray/issues/10761
+            "no_done_at_end": True,
+            # "soft_horizon" : True,
+            "num_gpus": 0,
+            "num_workers": 2,
+            "num_envs_per_worker": 1,
+            "compress_observations": False,
+            "batch_mode": 'truncate_episodes',
+            "clip_rewards": False,
+            "vf_clip_param": 500.0,
+            "entropy_coeff": 0.01,
+            # effective batch_size: train_batch_size * num_agents_in_each_environment [5, 10]
+            # see https://github.com/ray-project/ray/issues/4628
+            "train_batch_size": 1000,  # 5000
+            "rollout_fragment_length": 50,  # 100
+            "sgd_minibatch_size": 100,  # 500
+            "vf_share_layers": False
+            },
+    )
+
+# __sphinx_doc_end__
diff --git a/flatland/contrib/training/flatland_pettingzoo_stable_baselines.py b/flatland/contrib/training/flatland_pettingzoo_stable_baselines.py
new file mode 100644
index 0000000000000000000000000000000000000000..f88a068f9d226e151ec8b524e8bb2643595b120f
--- /dev/null
+++ b/flatland/contrib/training/flatland_pettingzoo_stable_baselines.py
@@ -0,0 +1,127 @@
+
+import numpy as np
+import os
+import PIL
+import shutil
+
+from stable_baselines3.ppo import MlpPolicy
+from stable_baselines3 import PPO
+
+import supersuit as ss
+
+from flatland.contrib.interface import flatland_env
+from flatland.contrib.utils import env_generators
+
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+
+import fnmatch
+import wandb
+
+"""
+https://github.com/PettingZoo-Team/PettingZoo/blob/HEAD/tutorials/13_lines.py
+"""
+
+# Custom observation builder without predictor
+# observation_builder = GlobalObsForRailEnv()
+
+# Custom observation builder with predictor
+observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
+seed = 10
+np.random.seed(seed)
+wandb_log = False
+experiment_name = "flatland_pettingzoo"
+
+try:
+    if os.path.isdir(experiment_name):
+        shutil.rmtree(experiment_name)
+    os.mkdir(experiment_name)
+except OSError as e:
+    print("Error: %s - %s." % (e.filename, e.strerror))
+
+# rail_env = env_generators.sparse_env_small(seed, observation_builder)
+rail_env = env_generators.small_v0(seed, observation_builder)
+
+# __sphinx_doc_begin__
+
+env = flatland_env.parallel_env(environment=rail_env, use_renderer=False)
+# env = flatland_env.env(environment = rail_env, use_renderer = False)
+
+if wandb_log:
+    run = wandb.init(project="flatland2021", entity="nilabha2007", sync_tensorboard=True, 
+                     config={}, name=experiment_name, save_code=True)
+
+env_steps = 1000  # 2 * env.width * env.height  # Code uses 1.5 to calculate max_steps
+rollout_fragment_length = 50
+env = ss.pettingzoo_env_to_vec_env_v0(env)
+# env.black_death = True
+env = ss.concat_vec_envs_v0(env, 1, num_cpus=1, base_class='stable_baselines3')
+
+model = PPO(MlpPolicy, env, tensorboard_log=f"/tmp/{experiment_name}", verbose=3, gamma=0.95, 
+    n_steps=rollout_fragment_length, ent_coef=0.01, 
+    learning_rate=5e-5, vf_coef=1, max_grad_norm=0.9, gae_lambda=1.0, n_epochs=30, clip_range=0.3,
+    batch_size=150, seed=seed)
+# wandb.watch(model.policy.action_net,log='all', log_freq = 1)
+# wandb.watch(model.policy.value_net, log='all', log_freq = 1)
+train_timesteps = 100000
+model.learn(total_timesteps=train_timesteps)
+model.save(f"policy_flatland_{train_timesteps}")
+
+# __sphinx_doc_end__
+
+model = PPO.load(f"policy_flatland_{train_timesteps}")
+
+env = flatland_env.env(environment=rail_env, use_renderer=True)
+
+if wandb_log:
+    artifact = wandb.Artifact('model', type='model')
+    artifact.add_file(f'policy_flatland_{train_timesteps}.zip')
+    run.log_artifact(artifact)
+
+
+# Model Inference
+
+seed = 100
+env.reset(random_seed=seed)
+step = 0
+ep_no = 0
+frame_list = []
+while ep_no < 1:
+    for agent in env.agent_iter():
+        obs, reward, done, info = env.last()
+        act = model.predict(obs, deterministic=True)[0] if not done else None
+        env.step(act)
+        frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
+        step += 1
+        if step % 100 == 0:
+            print(f"env step:{step} and action taken:{act}")
+            completion = env_generators.perc_completion(env)
+            print("Agents Completed:", completion)
+
+    completion = env_generators.perc_completion(env)
+    print("Final Agents Completed:", completion)
+    ep_no += 1
+    frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True, 
+                       append_images=frame_list[1:], duration=3, loop=0)
+    frame_list = []
+    env.close()
+    env.reset(random_seed=seed+ep_no)
+
+
+def find(pattern, path):
+    result = []
+    for root, dirs, files in os.walk(path):
+        for name in files:
+            if fnmatch.fnmatch(name, pattern):
+                result.append(os.path.join(root, name))
+    return result
+
+
+if wandb_log:
+    extn = "gif"
+    _video_file = f'*.{extn}'
+    _found_videos = find(_video_file, experiment_name)
+    print(_found_videos)
+    for _found_video in _found_videos:
+        wandb.log({_found_video: wandb.Video(_found_video, format=extn)})
+    run.join()
diff --git a/flatland/contrib/utils/env_generators.py b/flatland/contrib/utils/env_generators.py
new file mode 100644
index 0000000000000000000000000000000000000000..38c6d987acbd8d8c5996d7e4130b7f4ead4bc502
--- /dev/null
+++ b/flatland/contrib/utils/env_generators.py
@@ -0,0 +1,236 @@
+import logging
+import random
+import numpy as np
+from typing import NamedTuple
+
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen, no_malfunction_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.envs.rail_generators import sparse_rail_generator
+from flatland.envs.line_generators import sparse_line_generator
+from flatland.envs.agent_utils import RailAgentStatus
+from flatland.core.grid.grid4_utils import get_new_position
+
+MalfunctionParameters = NamedTuple('MalfunctionParameters', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
+
+
+def get_shortest_path_action(env,handle):
+    distance_map = env.distance_map.get()
+
+    agent = env.agents[handle]
+
+    if agent.status == RailAgentStatus.READY_TO_DEPART:
+        agent_virtual_position = agent.initial_position
+    elif agent.status == RailAgentStatus.ACTIVE:
+        agent_virtual_position = agent.position
+    elif agent.status == RailAgentStatus.DONE:
+        agent_virtual_position = agent.target
+    else:
+        return None
+
+    if agent.position:
+        possible_transitions = env.rail.get_transitions(
+            *agent.position, agent.direction)
+    else:
+        possible_transitions = env.rail.get_transitions(
+            *agent.initial_position, agent.direction)
+
+    num_transitions = np.count_nonzero(possible_transitions)                    
+    
+    min_distances = []
+    for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
+        if possible_transitions[direction]:
+            new_position = get_new_position(
+                agent_virtual_position, direction)
+            min_distances.append(
+                distance_map[handle, new_position[0],
+                            new_position[1], direction])
+        else:
+            min_distances.append(np.inf)
+
+    if num_transitions == 1:
+        observation = [0, 1, 0]
+
+    elif num_transitions == 2:
+        idx = np.argpartition(np.array(min_distances), 2)
+        observation = [0, 0, 0]
+        observation[idx[0]] = 1
+
+    else:
+        # Fallback (e.g. at a cell with more than two outgoing transitions):
+        # head towards the direction with the smallest distance to the target,
+        # so that `observation` is always defined before the argmax below.
+        observation = [0, 0, 0]
+        observation[int(np.argmin(min_distances))] = 1
+    return np.argmax(observation) + 1
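+
+# Illustrative usage sketch (assumes an already reset RailEnv named `rail_env`):
+# query the shortest-path heuristic for every agent and step the environment.
+#
+#     action_dict = {handle: get_shortest_path_action(rail_env, handle)
+#                    for handle in rail_env.get_agent_handles()}
+#     obs, rewards, dones, info = rail_env.step(action_dict)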
+
+
+def small_v0(random_seed, observation_builder, max_width=35, max_height=35):
+    random.seed(random_seed)
+    width = 30
+    height = 30
+    nr_trains = 5
+    max_num_cities = 4
+    grid_mode = False
+    max_rails_between_cities = 2
+    max_rails_in_city = 3
+
+    malfunction_rate = 0
+    malfunction_min_duration = 0
+    malfunction_max_duration = 0
+
+    rail_generator = sparse_rail_generator(max_num_cities=max_num_cities, seed=random_seed, grid_mode=False,
+                                           max_rails_between_cities=max_rails_between_cities,
+                                           max_rail_pairs_in_city=max_rails_in_city)
+
+    stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate,  # Rate of malfunction occurence
+                                        min_duration=malfunction_min_duration,  # Minimal duration of malfunction
+                                        max_duration=malfunction_max_duration  # Max duration of malfunction
+                                        )
+    speed_ratio_map = None
+    line_generator = sparse_line_generator(speed_ratio_map)
+
+    malfunction_generator = no_malfunction_generator()
+    
+    while width <= max_width and height <= max_height:
+        try:
+            env = RailEnv(width=width, height=height, rail_generator=rail_generator,
+                          line_generator=line_generator, number_of_agents=nr_trains,
+                        #   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                          malfunction_generator_and_process_data=malfunction_generator,
+                          obs_builder_object=observation_builder, remove_agents_at_target=False)
+
+            print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
+                random_seed, width, height, max_num_cities, nr_trains, max_rails_between_cities,
+                max_rails_in_city, malfunction_rate, malfunction_min_duration, malfunction_max_duration
+            ))
+
+            return env
+        except ValueError as e:
+            logging.error(f"Error: {e}")
+            width += 5
+            height += 5
+            logging.info(f"Try again with larger env: (w,h): ({width},{height})")
+    logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_width}, max_height={max_height}")
+    return None    
+
+
+def random_sparse_env_small(random_seed, observation_builder, max_width = 45, max_height = 45):
+    random.seed(random_seed)
+    size = random.randint(0, 5)
+    width = 20 + size * 5
+    height = 20 + size * 5
+    nr_cities = 2 + size // 2 + random.randint(0, 2)
+    nr_trains = min(nr_cities * 5, 5 + random.randint(0, 5))  # , 10 + random.randint(0, 10))
+    max_rails_between_cities = 2
+    max_rails_in_cities = 3 + random.randint(0, size)
+    malfunction_rate = 30 + random.randint(0, 100)
+    malfunction_min_duration = 3 + random.randint(0, 7)
+    malfunction_max_duration = 20 + random.randint(0, 80)
+
+    rail_generator = sparse_rail_generator(max_num_cities=nr_cities, seed=random_seed, grid_mode=False,
+                                           max_rails_between_cities=max_rails_between_cities,
+                                           max_rail_pairs_in_city=max_rails_in_cities)
+
+    stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate,  # Rate of malfunction occurence
+                                        min_duration=malfunction_min_duration,  # Minimal duration of malfunction
+                                        max_duration=malfunction_max_duration  # Max duration of malfunction
+                                        )
+
+    line_generator = sparse_line_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})
+
+    while width <= max_width and height <= max_height:
+        try:
+            env = RailEnv(width=width, height=height, rail_generator=rail_generator,
+                          line_generator=line_generator, number_of_agents=nr_trains,
+                        #   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                          malfunction_generator=ParamMalfunctionGen(stochastic_data),
+                          obs_builder_object=observation_builder, remove_agents_at_target=False)
+
+            print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
+                random_seed, width, height, nr_cities, nr_trains, max_rails_between_cities,
+                max_rails_in_cities, malfunction_rate, malfunction_min_duration, malfunction_max_duration
+            ))
+
+            return env
+        except ValueError as e:
+            logging.error(f"Error: {e}")
+            width += 5
+            height += 5
+            logging.info(f"Try again with larger env: (w,h): ({width},{height})")
+    logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_width}, max_height={max_height}")
+    return None
+
+
+def sparse_env_small(random_seed, observation_builder):
+    width = 30  # With of map
+    height = 30  # Height of map
+    nr_trains = 2  # Number of trains that have an assigned task in the env
+    cities_in_map = 3  # Number of cities where agents can start or end
+    seed = random_seed  # Random seed
+    grid_distribution_of_cities = False  # Type of city distribution, if False cities are randomly placed
+    max_rails_between_cities = 2  # Max number of tracks allowed between cities; this is the number of entry points to a city
+    max_rail_in_cities = 6  # Max number of parallel tracks within a city, representing a realistic train station
+
+    rail_generator = sparse_rail_generator(max_num_cities=cities_in_map,
+                                        seed=seed,
+                                        grid_mode=grid_distribution_of_cities,
+                                        max_rails_between_cities=max_rails_between_cities,
+                                        max_rail_pairs_in_city=max_rail_in_cities,
+                                        )
+
+    # Different agent types (trains) with different speeds.
+    speed_ratio_map = {1.: 0.25,  # Fast passenger train
+                       1. / 2.: 0.25,  # Fast freight train
+                       1. / 3.: 0.25,  # Slow commuter train
+                       1. / 4.: 0.25}  # Slow freight train
+
+    # We can now initiate the line (schedule) generator with the given speed profiles
+
+    line_generator = sparse_line_generator(speed_ratio_map)
+
+    # We can furthermore pass stochastic data to the RailEnv constructor which will allow for stochastic malfunctions
+    # during an episode.
+
+    stochastic_data = MalfunctionParameters(malfunction_rate=1/10000,  # Rate of malfunction occurence
+                                            min_duration=15,  # Minimal duration of malfunction
+                                            max_duration=50  # Max duration of malfunction
+                                            )
+
+    rail_env = RailEnv(width=width,
+                height=height,
+                rail_generator=rail_generator,
+                line_generator=line_generator,
+                number_of_agents=nr_trains,
+                obs_builder_object=observation_builder,
+                # malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
+                malfunction_generator=ParamMalfunctionGen(stochastic_data),
+                remove_agents_at_target=True)
+
+    return rail_env
+
+# Intended to monkey-patch gym Monitor's ``_after_step`` so that video recording
+# keeps working in this multi-agent setting, where ``done`` is a dict keyed by agent.
+def _after_step(self, observation, reward, done, info):
+    if not self.enabled:
+        return done
+
+    if type(done) == dict:
+        _done_check = done['__all__']
+    else:
+        _done_check = done
+    if _done_check and self.env_semantics_autoreset:
+        # For envs with BlockingReset wrapping VNCEnv, this observation will be the first one of the new episode
+        self.reset_video_recorder()
+        self.episode_id += 1
+        self._flush()
+
+    # Record stats - Disabled as it causes error in multi-agent set up
+    # self.stats_recorder.after_step(observation, reward, done, info)
+    # Record video
+    self.video_recorder.capture_frame()
+
+    return done
+
+
+def perc_completion(env):
+    tasks_finished = 0
+    if hasattr(env, "agents_data"):
+        agent_data = env.agents_data
+    else:        
+        agent_data = env.agents
+    for current_agent in agent_data:
+        # Count agents that reached their target (DONE, or DONE_REMOVED when
+        # remove_agents_at_target is enabled).
+        if current_agent.status in (RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED):
+            tasks_finished += 1
+
+    return 100 * tasks_finished / max(1, len(agent_data))
\ No newline at end of file
diff --git a/flatland/contrib/wrappers/flatland_wrappers.py b/flatland/contrib/wrappers/flatland_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..972c7eaf073f66de15b4839a2eb3fa5bcd18a68a
--- /dev/null
+++ b/flatland/contrib/wrappers/flatland_wrappers.py
@@ -0,0 +1,412 @@
+import numpy as np
+import os
+import PIL
+import shutil
+# additional imports for the wrappers below
+import unittest
+import typing
+from collections import defaultdict
+from typing import Dict, Any, Optional, Set, List, Tuple
+
+
+from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+from flatland.core.grid.grid4_utils import get_new_position
+
+# First of all we import the Flatland rail environment
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+
+from flatland.envs.agent_utils import EnvAgent, RailAgentStatus
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+
+
+def possible_actions_sorted_by_distance(env: RailEnv, handle: int):
+    agent = env.agents[handle]
+    
+
+    if agent.status == RailAgentStatus.READY_TO_DEPART:
+        agent_virtual_position = agent.initial_position
+    elif agent.status == RailAgentStatus.ACTIVE:
+        agent_virtual_position = agent.position
+    elif agent.status == RailAgentStatus.DONE:
+        agent_virtual_position = agent.target
+    else:
+        print("no action possible!")
+        if agent.status == RailAgentStatus.DONE_REMOVED:
+          print(f"agent status: DONE_REMOVED for agent {agent.handle}")
+          print("to solve this problem, do not input actions for removed agents!")
+          return [(RailEnvActions.DO_NOTHING, 0)] * 2
+        print("agent status:")
+        print(RailAgentStatus(agent.status))
+        #return None
+        # NEW: if agent is at target, DO_NOTHING, and distance is zero.
+        # NEW: (needs to be tested...)
+        return [(RailEnvActions.DO_NOTHING, 0)] * 2
+
+    possible_transitions = env.rail.get_transitions(*agent_virtual_position, agent.direction)
+    print(f"possible transitions: {possible_transitions}")
+    distance_map = env.distance_map.get()[handle]
+    possible_steps = []
+    for movement in list(range(4)):
+      # TODO: the movement-to-action mapping below deserves clearer documentation;
+      # the final else branch is only reached for 180 degree turns (dead ends).
+        if possible_transitions[movement]:
+            if movement == agent.direction:
+                action = RailEnvActions.MOVE_FORWARD
+            elif movement == (agent.direction + 1) % 4:
+                action = RailEnvActions.MOVE_RIGHT
+            elif movement == (agent.direction - 1) % 4:
+                action = RailEnvActions.MOVE_LEFT
+            else:
+                # This branch is only reached when the allowed transition is a
+                # 180 degree turn, i.e. movement == (agent.direction + 2) % 4,
+                # which happens when the agent sits in a dead end. Flatland
+                # performs the turnaround when MOVE_FORWARD is issued there,
+                # so map this case to MOVE_FORWARD instead of raising an error.
+                action = RailEnvActions.MOVE_FORWARD
+               
+            distance = distance_map[get_new_position(agent_virtual_position, movement) + (movement,)]
+            possible_steps.append((action, distance))
+    possible_steps = sorted(possible_steps, key=lambda step: step[1])
+
+
+    # If there is only one possible step, it is both the shortest and the
+    # second-shortest path, so duplicate it.
+    if len(possible_steps) == 1:
+        return possible_steps * 2
+    else:
+        return possible_steps
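+
+# Illustrative usage sketch (assumes an already reset RailEnv `rail_env` and an
+# agent handle `handle`): entry 0 is the step along the shortest path, entry 1
+# the second-shortest alternative (possibly the same step, see above).
+#
+#     steps = possible_actions_sorted_by_distance(rail_env, handle)
+#     shortest_action, shortest_distance = steps[0]
+#     alternative_action, alternative_distance = steps[1]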
+
+
+class RailEnvWrapper:
+  def __init__(self, env:RailEnv):
+    self.env = env
+
+    assert self.env is not None
+    assert self.env.rail is not None, "Reset original environment first!"
+    assert self.env.agents is not None, "Reset original environment first!"
+    assert len(self.env.agents) > 0, "Reset original environment first!"
+
+    # rail can be seen as part of the interface to RailEnv.
+    # is used by several wrappers, to e.g. access rail.get_valid_transitions(...)
+    #self.rail = self.env.rail
+    # same for env.agents
+    # MICHEL: DOES THIS HERE CAUSE A PROBLEM with agent status not being updated?
+    #self.agents = self.env.agents
+    #assert self.env.agents == self.agents
+    #print(f"agents of RailEnvWrapper are: {self.agents}")
+    #self.width = self.rail.width
+    #self.height = self.rail.height
+
+
+  # Note: caching references such as ``self.agents = self.env.agents`` goes stale
+  # after each env.reset() call, because RailEnv assigns new objects to these
+  # attributes instead of mutating them in place (compare: a = [1, 2, 3]; b = a;
+  # a = [0] still leaves b == [1, 2, 3]). The properties below avoid this problem.
+
+  # @property
+  # def number_of_agents(self):
+  #   return self.env.number_of_agents
+  
+  # @property
+  # def agents(self):
+  #   return self.env.agents
+
+  # @property
+  # def _seed(self):
+  #   return self.env._seed
+
+  # @property
+  # def obs_builder(self):
+  #   return self.env.obs_builder
+
+  def __getattr__(self, name):
+    """Expose any other attributes of the underlying environment.
+
+    __getattr__ is only invoked when normal attribute lookup on the wrapper
+    (including its properties) fails, so we can simply delegate to self.env.
+    """
+    return getattr(self.env, name)
+
+
+  @property
+  def rail(self):
+    return self.env.rail
+  
+  @property
+  def width(self):
+    return self.env.width
+  
+  @property
+  def height(self):
+    return self.env.height
+
+  @property
+  def agent_positions(self):
+    return self.env.agent_positions
+
+  def get_num_agents(self):
+    return self.env.get_num_agents()
+
+  def get_agent_handles(self):
+    return self.env.get_agent_handles()
+
+  def step(self, action_dict: Dict[int, RailEnvActions]):
+    #self.agents = self.env.agents
+    # ERROR. something is wrong with the references for self.agents...
+    #assert self.env.agents == self.agents
+    return self.env.step(action_dict)
+
+  def reset(self, **kwargs):
+    # env.reset() assigns new objects (e.g. a new agents list) rather than mutating
+    # the existing ones, which is why this wrapper exposes such attributes as
+    # properties instead of caching references to them.
+    obs, info = self.env.reset(**kwargs)
+    return obs, info
+
+
+class ShortestPathActionWrapper(RailEnvWrapper):
+
+    def __init__(self, env:RailEnv):
+        super().__init__(env)
+        #self.action_space = gym.spaces.Discrete(n=3)  # 0:stop, 1:shortest path, 2:other direction
+
+    # Make sure that no agents with agent.status == DONE_REMOVED end up in the action
+    # dict; otherwise possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0]
+    # has to fall back to its DO_NOTHING default (see above).
+    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
+      ########## MICHEL: NEW (just for debugging) ########
+      for agent_id, action in action_dict.items():
+        agent = self.agents[agent_id]
+        # assert agent.status != RailAgentStatus.DONE_REMOVED # this comes with agent.position == None...
+        # assert agent.status != RailAgentStatus.DONE # not sure about this one...
+        print(f"agent: {agent} with status: {agent.status}")
+      ######################################################
+
+      # input: action dict with actions in [0, 1, 2].
+      transformed_action_dict = {}
+      for agent_id, action in action_dict.items():
+          if action == 0:
+              transformed_action_dict[agent_id] = action
+          else:
+              assert action in [1, 2]
+              # Action 1 picks the step along the shortest path, action 2 the second
+              # shortest alternative: possible_actions_sorted_by_distance returns
+              # (action, distance) tuples sorted by distance, hence index action - 1.
+              # Note: earlier versions crashed here with "NoneType is not subscriptable"
+              # when a removed agent slipped into the action dict (see comment above).
+              assert possible_actions_sorted_by_distance(self.env, agent_id) is not None
+              assert possible_actions_sorted_by_distance(self.env, agent_id)[action - 1] is not None
+              transformed_action_dict[agent_id] = possible_actions_sorted_by_distance(self.env, agent_id)[action - 1][0]
+      obs, rewards, dones, info = self.env.step(transformed_action_dict)
+      return obs, rewards, dones, info
+
+    #def reset(self, random_seed: Optional[int] = None) -> Dict[int, Any]:
+        #return self.rail_env.reset(random_seed)
+
+    # MICHEL: should not be needed, as we inherit that from RailEnvWrapper...
+    #def reset(self, **kwargs) -> Tuple[Dict, Dict]:
+    #  obs, info = self.env.reset(**kwargs)
+    #  return obs, info
+
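+# Illustrative usage sketch (assumes an already reset RailEnv named `rail_env`):
+# with this wrapper each agent chooses among 3 actions only
+# (0: stop, 1: follow the shortest path, 2: take the other direction).
+#
+#     env = ShortestPathActionWrapper(rail_env)
+#     actions = {handle: 1 for handle in env.get_agent_handles()}
+#     obs, rewards, dones, info = env.step(actions)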
+
+def find_all_cells_where_agent_can_choose(env: RailEnv):
+    """
+    Input: a RailEnv (or something which behaves similarly, e.g. a wrapped RailEnv)
+    that has already been reset. Otherwise env.rail is still None and the
+    transition lookups below will crash.
+    """
+    switches = []
+    switches_neighbors = []
+    directions = list(range(4))
+    for h in range(env.height):
+        for w in range(env.width):
+
+            # Note: an earlier version used pos = (w, h), which is the wrong order of
+            # coordinates (the bug stays hidden in quadratic environments); the
+            # correct order is (h, w).
+            pos = (h, w)
+
+            is_switch = False
+            # Check for switch: if there is more than one outgoing transition
+            for orientation in directions:
+                #print(f"env is: {env}")
+                #print(f"env.rail is: {env.rail}")
+                possible_transitions = env.rail.get_transitions(*pos, orientation)
+                num_transitions = np.count_nonzero(possible_transitions)
+                if num_transitions > 1:
+                    switches.append(pos)
+                    is_switch = True
+                    break
+            if is_switch:
+                # Add all neighbouring rails, if pos is a switch
+                for orientation in directions:
+                    possible_transitions = env.rail.get_transitions(*pos, orientation)
+                    for movement in directions:
+                        if possible_transitions[movement]:
+                            switches_neighbors.append(get_new_position(pos, movement))
+
+    decision_cells = switches + switches_neighbors
+    return tuple(map(set, (switches, switches_neighbors, decision_cells)))
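+
+# Illustrative usage sketch (assumes an already reset RailEnv named `rail_env`):
+# the function returns three sets of (row, column) cells.
+#
+#     switches, switches_neighbors, decision_cells = find_all_cells_where_agent_can_choose(rail_env)
+#     # decision_cells is the union of switches and their neighbouring rail cells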
+
+
+class NoChoiceCellsSkipper:
+    def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
+      self.env = env
+      self.switches = None
+      self.switches_neighbors = None
+      self.decision_cells = None
+      self.accumulate_skipped_rewards = accumulate_skipped_rewards
+      self.discounting = discounting
+      self.skipped_rewards = defaultdict(list)
+
+      # env.reset() can change the rail grid layout, so the switches, etc. will change! --> need to do this in reset() as well.
+      #self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
+
+      # compute and initialize value for switches, switches_neighbors, and decision_cells.
+      self.reset_cells()
+
+    # MICHEL: maybe these three methods should be part of RailEnv?
+    def on_decision_cell(self, agent: EnvAgent) -> bool:
+        """
+        An agent is on a decision cell if it has not been placed on the grid yet
+        (position is None), is still at its initial position (activated but not
+        departed), or its position is one of the pre-computed decision cells.
+        """
+        return agent.position is None or agent.position == agent.initial_position or agent.position in self.decision_cells
+
+    def on_switch(self, agent: EnvAgent) -> bool:
+        return agent.position in self.switches
+
+    def next_to_switch(self, agent: EnvAgent) -> bool:
+        return agent.position in self.switches_neighbors
+
+    # MICHEL: maybe just call this step()...
+    def no_choice_skip_step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
+        o, r, d, i = {}, {}, {}, {}
+      
+        # The info sub-dicts need to be initialized here,
+        # since we assign i["..."][agent_id] entries below.
+        i["action_required"] = dict()
+        i["malfunction"] = dict()
+        i["speed"] = dict()
+        i["status"] = dict()
+
+        while len(o) == 0:
+            #print(f"len(o)==0. stepping the rail environment...")
+            obs, reward, done, info = self.env.step(action_dict)
+
+            for agent_id, agent_obs in obs.items():
+
+                ######  MICHEL: prints for debugging  ###########
+                if not self.on_decision_cell(self.env.agents[agent_id]):
+                      print(f"agent {agent_id} is NOT on a decision cell.")
+                #################################################
+
+
+                if done[agent_id] or self.on_decision_cell(self.env.agents[agent_id]):
+                    ######  MICHEL: prints for debugging  ######################
+                    if done[agent_id]:
+                      print(f"agent {agent_id} is done.")
+                    #if self.on_decision_cell(self.env.agents[agent_id]):
+                      #print(f"agent {agent_id} is on decision cell.")
+                      #cell = self.env.agents[agent_id].position
+                      #print(f"cell is: {cell}")
+                      #print(f"the decision cells are: {self.decision_cells}")
+                    
+                    ############################################################
+
+                    o[agent_id] = agent_obs
+                    r[agent_id] = reward[agent_id]
+                    d[agent_id] = done[agent_id]
+
+                    # MICHEL: HAVE TO MODIFY THIS HERE
+                    # because we are not using StepOutputs, the return values of step() have a different structure.
+                    #i[agent_id] = info[agent_id]
+                    i["action_required"][agent_id] = info["action_required"][agent_id] 
+                    i["malfunction"][agent_id] = info["malfunction"][agent_id]
+                    i["speed"][agent_id] = info["speed"][agent_id]
+                    i["status"][agent_id] = info["status"][agent_id]
+                                                                  
+                    if self.accumulate_skipped_rewards:
+                        discounted_skipped_reward = r[agent_id]
+                        for skipped_reward in reversed(self.skipped_rewards[agent_id]):
+                            discounted_skipped_reward = self.discounting * discounted_skipped_reward + skipped_reward
+                        r[agent_id] = discounted_skipped_reward
+                        self.skipped_rewards[agent_id] = []
+
+                elif self.accumulate_skipped_rewards:
+                    self.skipped_rewards[agent_id].append(reward[agent_id])
+                # end of for-loop
+
+            d['__all__'] = done['__all__']
+            action_dict = {}
+            # end of while-loop
+
+        return o, r, d, i
+
+    # MICHEL: maybe just call this reset()...
+    def reset_cells(self) -> None:
+        self.switches, self.switches_neighbors, self.decision_cells = find_all_cells_where_agent_can_choose(self.env)
+
+
+# IMPORTANT: the rail env should be reset() / initialized before being wrapped in this class!
+# IDEA: maybe each RailEnv instance should automatically be reset() / initialized upon creation.
+class SkipNoChoiceCellsWrapper(RailEnvWrapper):
+  
+    # env can be a real RailEnv, or anything that shares the same interface
+    # e.g. obs, rewards, dones, info = env.step(action_dict) and obs, info = env.reset(), and so on.
+    def __init__(self, env:RailEnv, accumulate_skipped_rewards: bool, discounting: float) -> None:
+        super().__init__(env)
+        # save these so they can be inspected easier.
+        self.accumulate_skipped_rewards = accumulate_skipped_rewards
+        self.discounting = discounting
+        self.skipper = NoChoiceCellsSkipper(env=self.env, accumulate_skipped_rewards=self.accumulate_skipped_rewards, discounting=self.discounting)
+
+        self.skipper.reset_cells()
+
+        # TODO: this is clunky..
+        # for easier access / checking
+        self.switches = self.skipper.switches
+        self.switches_neighbors = self.skipper.switches_neighbors
+        self.decision_cells = self.skipper.decision_cells
+        self.skipped_rewards = self.skipper.skipped_rewards
+
+  
+    # MICHEL: trying to isolate the core part and put it into a separate method.
+    def step(self, action_dict: Dict[int, RailEnvActions]) -> Tuple[Dict, Dict, Dict, Dict]:
+        obs, rewards, dones, info = self.skipper.no_choice_skip_step(action_dict=action_dict)
+        return obs, rewards, dones, info
+        
+
+    # MICHEL: TODO: maybe add parameters like regenerate_rail, regenerate_schedule, etc.
+    # arguments from RailEnv.reset() are: self, regenerate_rail: bool = True, regenerate_schedule: bool = True, activate_agents: bool = False, random_seed: bool = None
+    # TODO: check the type of random_seed. Is it bool or int?
+    # MICHEL: changed return type from Dict[int, Any] to Tuple[Dict, Dict].
+    def reset(self, **kwargs) -> Tuple[Dict, Dict]:
+        obs, info = self.env.reset(**kwargs)
+        # resets decision cells, switches, etc. These can change with an env.reset(...)!
+        # needs to be done after env.reset().
+        self.skipper.reset_cells()
+        return obs, info
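+
+# Illustrative usage sketch (assumes a RailEnv named `rail_env` that has already been
+# reset, as required above): a single wrapped step keeps stepping the underlying env
+# until some agent reaches a decision cell (or the episode ends).
+#
+#     env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
+#     actions = {handle: 2 for handle in env.get_agent_handles()}   # 2 == MOVE_FORWARD
+#     obs, rewards, dones, info = env.step(actions)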
\ No newline at end of file
diff --git a/tests/test_pettingzoo_interface.py b/tests/test_pettingzoo_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2b2776f183dcb0ac2a04b4d2f4f0cc769070f68
--- /dev/null
+++ b/tests/test_pettingzoo_interface.py
@@ -0,0 +1,132 @@
+import numpy as np
+import os
+import PIL
+import shutil
+
+from flatland.contrib.interface import flatland_env
+from flatland.contrib.utils import env_generators
+
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import ShortestPathPredictorForRailEnv
+
+
+# First of all we import the Flatland rail environment
+from flatland.utils.rendertools import RenderTool, AgentRenderVariant
+
+from flatland.contrib.wrappers.flatland_wrappers import SkipNoChoiceCellsWrapper
+from flatland.contrib.wrappers.flatland_wrappers import ShortestPathActionWrapper  # noqa
+import pytest
+
+
+@pytest.mark.skip(reason="Only for testing pettingzoo interface and wrappers")
+def test_petting_zoo_interface_env():
+
+    # Custom observation builder without predictor
+    # observation_builder = GlobalObsForRailEnv()
+
+    # Custom observation builder with predictor
+    observation_builder = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
+    seed = 11
+    save = True
+    np.random.seed(seed)
+    experiment_name = "flatland_pettingzoo"
+    total_episodes = 1
+
+    if save:
+        try:
+            if os.path.isdir(experiment_name):
+                shutil.rmtree(experiment_name)
+            os.mkdir(experiment_name)
+        except OSError as e:
+            print("Error: %s - %s." % (e.filename, e.strerror))
+
+    # rail_env = env_generators.sparse_env_small(seed, observation_builder)
+    rail_env = env_generators.small_v0(seed, observation_builder)
+
+    rail_env.reset(random_seed=seed)
+
+    # For Shortest Path Action Wrapper, change action to 1
+    # rail_env = ShortestPathActionWrapper(rail_env)  
+    rail_env = SkipNoChoiceCellsWrapper(rail_env, accumulate_skipped_rewards=False, discounting=0.0)
+    
+    env_renderer = RenderTool(rail_env,
+                            agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
+                            show_debug=False,
+                            screen_height=600,  # Adjust these parameters to fit your resolution
+                            screen_width=800)  # Adjust these parameters to fit your resolution
+
+    dones = {}
+    dones['__all__'] = False
+
+    step = 0
+    ep_no = 0
+    frame_list = []
+    all_actions_env = []
+    all_actions_pettingzoo_env = []
+    # while not dones['__all__']:
+    while ep_no < total_episodes:
+        action_dict = {}
+        # Chose an action for each agent
+        for a in range(rail_env.get_num_agents()):
+            # action = env_generators.get_shortest_path_action(rail_env, a)
+            action = 2
+            all_actions_env.append(action)
+            action_dict.update({a: action})
+            step += 1
+
+        # Do the environment step
+        observations, rewards, dones, information = rail_env.step(action_dict)
+        image = env_renderer.render_env(show=False, show_observations=False, show_predictions=False,
+                                        return_image=True)
+        frame_list.append(PIL.Image.fromarray(image[:, :, :3]))
+
+        if dones['__all__']:
+            completion = env_generators.perc_completion(rail_env)
+            print("Final Agents Completed:", completion)
+            ep_no += 1
+            if save:
+                frame_list[0].save(f"{experiment_name}{os.sep}out_{ep_no}.gif", save_all=True, 
+                                   append_images=frame_list[1:], duration=3, loop=0)       
+            frame_list = []
+            env_renderer = RenderTool(rail_env,
+                            agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
+                            show_debug=False,
+                            screen_height=600,  # Adjust these parameters to fit your resolution
+                            screen_width=800)  # Adjust these parameters to fit your resolution
+            rail_env.reset(random_seed=seed+ep_no)
+
+    
+#  __sphinx_doc_begin__
+    env = flatland_env.env(environment=rail_env, use_renderer=True)
+    seed = 11
+    env.reset(random_seed=seed)
+    step = 0
+    ep_no = 0
+    frame_list = []
+    while ep_no < total_episodes:
+        for agent in env.agent_iter():
+            obs, reward, done, info = env.last()
+            # act = env_generators.get_shortest_path_action(env.environment, get_agent_handle(agent))
+            act = 2
+            all_actions_pettingzoo_env.append(act)
+            env.step(act)
+            frame_list.append(PIL.Image.fromarray(env.render(mode='rgb_array')))
+            step += 1
+# __sphinx_doc_end__
+        completion = env_generators.perc_completion(env)
+        print("Final Agents Completed:", completion)
+        ep_no += 1
+        if save:
+            frame_list[0].save(f"{experiment_name}{os.sep}pettyzoo_out_{ep_no}.gif", save_all=True, 
+                               append_images=frame_list[1:], duration=3, loop=0)
+        frame_list = []
+        env.close()
+        env.reset(random_seed=seed+ep_no)
+        min_len = min(len(all_actions_pettingzoo_env), len(all_actions_env))
+        assert all_actions_pettingzoo_env[:min_len] == all_actions_env[:min_len], "actions do not match"
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+    sys.exit(pytest.main(["-sv", __file__]))