diff --git a/examples/env_generators.py b/examples/env_generators.py
index 71e9d1ece93a0e58ce66f06f0912d417499d228b..d8dd2e3ece319e95bd3b397bd694724858ee9543 100644
--- a/examples/env_generators.py
+++ b/examples/env_generators.py
@@ -3,15 +3,110 @@ import random
 import numpy as np
 from typing import NamedTuple
 
-from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen
+from flatland.envs.malfunction_generators import malfunction_from_params, MalfunctionParameters, ParamMalfunctionGen, no_malfunction_generator
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
 from flatland.envs.schedule_generators import sparse_schedule_generator
 from flatland.envs.agent_utils import RailAgentStatus
+from flatland.core.grid.grid4_utils import get_new_position
 
 MalfunctionParameters = NamedTuple('MalfunctionParameters', [('malfunction_rate', float), ('min_duration', int), ('max_duration', int)])
 
 
+def get_shortest_path_action(env, handle):
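+    """Greedy policy: return the action (1=left, 2=forward, 3=right) that
+    minimises the distance-map value from the agent's next cell to its
+    target, or None if the agent has no usable position."""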
+    distance_map = env.distance_map.get()
+
+    agent = env.agents[handle]
+
+    if agent.status == RailAgentStatus.READY_TO_DEPART:
+        agent_virtual_position = agent.initial_position
+    elif agent.status == RailAgentStatus.ACTIVE:
+        agent_virtual_position = agent.position
+    elif agent.status == RailAgentStatus.DONE:
+        agent_virtual_position = agent.target
+    else:
+        return None
+
+    # agent_virtual_position already reflects the agent's status, so use it
+    # directly (agent.position is None before departure).
+    possible_transitions = env.rail.get_transitions(
+        *agent_virtual_position, agent.direction)
+    num_transitions = np.count_nonzero(possible_transitions)
+
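+    # Remaining distance to the target if the agent turns left, goes forward,
+    # or turns right; np.inf where the transition is not allowed.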
+    min_distances = []
+    for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
+        if possible_transitions[direction]:
+            new_position = get_new_position(
+                agent_virtual_position, direction)
+            min_distances.append(
+                distance_map[handle, new_position[0], new_position[1], direction])
+        else:
+            min_distances.append(np.inf)
+
+    if num_transitions == 1:
+        # Only one transition available: MOVE_FORWARD follows it, even on curves.
+        observation = [0, 1, 0]
+    else:
+        # At a switch, head towards the smallest remaining distance.
+        observation = [0, 0, 0]
+        observation[int(np.argmin(min_distances))] = 1
+
+    return np.argmax(observation) + 1
+
+
+def small_v0(random_seed, observation_builder, max_width=35, max_height=35):
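+    """Generate a small sparse RailEnv (25x25, growing up to max_width x
+    max_height on failure) with 5 trains and no malfunctions."""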
+    random.seed(random_seed)
+    width = 25
+    height = 25
+    nr_trains = 5
+    max_num_cities = 4
+    grid_mode = False
+    max_rails_between_cities = 2
+    max_rails_in_city = 3
+
+    malfunction_rate = 0
+    malfunction_min_duration = 0
+    malfunction_max_duration = 0
+
+    rail_generator = sparse_rail_generator(max_num_cities=max_num_cities, seed=random_seed, grid_mode=grid_mode,
+                                           max_rails_between_cities=max_rails_between_cities,
+                                           max_rails_in_city=max_rails_in_city)
+
+    speed_ratio_map = None
+    schedule_generator = sparse_schedule_generator(speed_ratio_map)
+
+    malfunction_generator = no_malfunction_generator()
+    
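+    # sparse_rail_generator raises ValueError when the requested cities do not
+    # fit on the grid; enlarge the grid and retry up to max_width/max_height.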
+    while width <= max_width and height <= max_height:
+        try:
+            env = RailEnv(width=width, height=height, rail_generator=rail_generator,
+                          schedule_generator=schedule_generator, number_of_agents=nr_trains,
+                          malfunction_generator_and_process_data=malfunction_generator,
+                          obs_builder_object=observation_builder, remove_agents_at_target=False)
+
+            print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
+                random_seed, width, height, max_num_cities, nr_trains, max_rails_between_cities,
+                max_rails_in_city, malfunction_rate, malfunction_min_duration, malfunction_max_duration
+            ))
+
+            return env
+        except ValueError as e:
+            logging.error("Failed to generate env: %s", e)
+            width += 5
+            height += 5
+            logging.info("Retrying with a larger env: (w, h) = (%d, %d)", width, height)
+    logging.error(f"Unable to generate env with seed={random_seed}, max_width={max_height}, max_height={max_height}")
+    return None    
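+
+# Illustrative usage (TreeObsForRailEnv is one possible observation builder):
+#   from flatland.envs.observations import TreeObsForRailEnv
+#   env = small_v0(random_seed=1, observation_builder=TreeObsForRailEnv(max_depth=2))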
+
+
 def random_sparse_env_small(random_seed, observation_builder, max_width = 45, max_height = 45):
     random.seed(random_seed)
     size = random.randint(0, 5)
@@ -29,11 +124,10 @@ def random_sparse_env_small(random_seed, observation_builder, max_width = 45, ma
                                            max_rails_between_cities=max_rails_between_cities,
                                            max_rails_in_city=max_rails_in_cities)
 
-    # new version:
-    # stochastic_data = MalfunctionParameters(malfunction_rate, malfunction_min_duration, malfunction_max_duration)
-
-    stochastic_data = {'malfunction_rate': malfunction_rate, 'min_duration': malfunction_min_duration,
-                       'max_duration': malfunction_max_duration}
+    stochastic_data = MalfunctionParameters(malfunction_rate=malfunction_rate,  # Rate of malfunction occurrence
+                                            min_duration=malfunction_min_duration,  # Minimal duration of malfunction
+                                            max_duration=malfunction_max_duration)  # Max duration of malfunction
 
     schedule_generator = sparse_schedule_generator({1.: 0.25, 1. / 2.: 0.25, 1. / 3.: 0.25, 1. / 4.: 0.25})
 
@@ -41,7 +135,8 @@ def random_sparse_env_small(random_seed, observation_builder, max_width = 45, ma
         try:
             env = RailEnv(width=width, height=height, rail_generator=rail_generator,
                           schedule_generator=schedule_generator, number_of_agents=nr_trains,
-                          malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
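+                          # ParamMalfunctionGen is the newer flatland-rl malfunction API,
+                          # replacing malfunction_from_params.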
+                          malfunction_generator=ParamMalfunctionGen(stochastic_data),
                           obs_builder_object=observation_builder, remove_agents_at_target=False)
 
             print("[{}] {}x{} {} cities {} trains, max {} rails between cities, max {} rails in cities. Malfunction rate {}, {} to {} steps.".format(
@@ -129,9 +224,13 @@ def _after_step(self, observation, reward, done, info):
 
 def perc_completion(env):
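+    """Return the percentage of agents that have reached their target."""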
     tasks_finished = 0
-    for current_agent in env.agents_data:
-        if current_agent.status == RailAgentStatus.DONE_REMOVED:
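+    # A plain RailEnv keeps its agents in `agents`; the wrapper envs used in
+    # these examples expose them as `agents_data`.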
+    if isinstance(env, RailEnv):
+        agent_data = env.agents
+    else:
+        agent_data = env.agents_data
+    for current_agent in agent_data:
+        # Count finished agents whether or not they were removed at the target.
+        if current_agent.status in (RailAgentStatus.DONE, RailAgentStatus.DONE_REMOVED):
             tasks_finished += 1
 
     return 100 * np.mean(tasks_finished / max(
-                                1, env.num_agents)) 
\ No newline at end of file
+                                1, len(agent_data)))
\ No newline at end of file
diff --git a/examples/flatland_env.py b/examples/flatland_env.py
index b7f593f9e4128b061a220353d711100715bba870..05fcbd7d652477122396aad4d30aaa1c348fffe2 100644
--- a/examples/flatland_env.py
+++ b/examples/flatland_env.py
@@ -148,7 +148,7 @@ class raw_env(AECEnv, gym.Env):
         self.agent_selection = self._agent_selector.next()
         self.rewards = dict(zip(self.agents, [0 for _ in self.agents]))
         self._cumulative_rewards = dict(zip(self.agents, [0 for _ in self.agents]))
-        self.action_dict = {i:0 for i in self.possible_agents}
+        self.action_dict = {get_agent_handle(i): 0 for i in self.possible_agents}
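+        # Key by integer flatland handles, which RailEnv.step expects, rather
+        # than by PettingZoo agent names.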
 
         return observations
 
@@ -160,14 +160,12 @@ class raw_env(AECEnv, gym.Env):
         
         agent = self.agent_selection
         self.action_dict[get_agent_handle(agent)] = action
-        if self._reset_next_step:
-            return self.reset()
 
         if self.dones[agent]:
             self.agents.remove(agent)
-            if not self.env_done():
-                self.agent_selection = self._agent_selector.next()
-            return self.last()
 
         if self._agent_selector.is_last():
             observations, rewards, dones, infos = self._environment.step(self.action_dict)
@@ -185,10 +183,12 @@ class raw_env(AECEnv, gym.Env):
         
         # self._cumulative_rewards[agent] = 0
         self._accumulate_rewards()
+
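+        # Read last() before advancing the selector so the returned values
+        # belong to the agent that just acted.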
+        obs, cumulative_reward, done, info = self.last()
         
         self.agent_selection = self._agent_selector.next()
 
-        return self.last()
+        return obs, cumulative_reward, done, info
 
         # if self._agent_selector.is_last():
         #     self._agent_selector.reinit(self.agents)