diff --git a/changelog.md b/changelog.md
index 109f748dbd7360e86409f782ef2ede56ecfaa089..1099488b3c03d3045c0e5cfdc8f661a97a9ff365 100644
--- a/changelog.md
+++ b/changelog.md
@@ -12,7 +12,7 @@ Changes since Flatland 2.0.0
 - by default the reset method of RailEnv is not called in the constructor of RailEnv anymore. Therefore the reset method needs to be called after the creation of a RailEnv object
 
 ### Changes in schedule generation
-- return value of schedule generator has changed to the named tuple Schedule
+- return value of schedule generator has changed to the named tuple `Schedule`
 
 Changes since Flatland 1.0.0
 --------------------------
diff --git a/docs/specifications/railway.md b/docs/specifications/railway.md
index 0299badb5ea8743a92a8aecf7f5c6cedb060a98d..3623ea450f74c5075bf0b1b9500cbddb6e406cb3 100644
--- a/docs/specifications/railway.md
+++ b/docs/specifications/railway.md
@@ -697,3 +697,17 @@ RailEnv.step()
 self.get()
 ...
 ```
+
+
+### Maximum number of allowed time steps in an episode
+
+Whenever the schedule within RailEnv is generated, the maximum number of allowed time steps in an episode is calculated
+according to the following formula:
+
+```python
+
+RailEnv._max_episode_steps = timedelay_factor * alpha * (env.width + env.height + ratio_nr_agents_to_nr_cities)
+
+```
+
+where the following default values are used: `timedelay_factor=4`, `alpha=2` and `ratio_nr_agents_to_nr_cities=20`.
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 3be7940ef97542664cbc6d2e8850f604a64f38bf..58f807ea43309af37fe669d88624c9fd2235efa4 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -169,7 +169,7 @@ class RailEnv(Environment):
         self.obs_builder = obs_builder_object
         self.obs_builder.set_env(self)
 
-        self._max_episode_steps = None
+        self._max_episode_steps: Optional[int] = None
         self._elapsed_steps = 0
 
         self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
@@ -249,6 +249,35 @@ class
RailEnv(Environment):
         """
         self.agents = EnvAgent.list_from_static(self.agents_static)
 
+    @staticmethod
+    def compute_max_episode_steps(width: int, height: int, timedelay_factor: int = 4, alpha: int = 2,
+                                  ratio_nr_agents_to_nr_cities: float = 20.0) -> int:
+        """
+        compute_max_episode_steps(width, height, timedelay_factor, alpha, ratio_nr_agents_to_nr_cities)
+
+        The method computes the max number of episode steps allowed
+
+        Parameters
+        ----------
+        width : int
+            width of environment
+        height : int
+            height of environment
+        timedelay_factor : int, optional
+            time delay factor (default is 4)
+        alpha : int, optional
+            multiplicative scaling factor (default is 2)
+        ratio_nr_agents_to_nr_cities : float, optional
+            number_of_agents/number_of_cities (default is 20.0)
+
+        Returns
+        -------
+        max_episode_steps : int
+            maximum number of episode steps
+
+        """
+        return int(timedelay_factor * alpha * (width + height + ratio_nr_agents_to_nr_cities))
+
     def reset(self, regen_rail=True, replace_agents=True, activate_agents=False, random_seed=None):
         """ if regen_rail then regenerate the rails.
             if replace_agents then regenerate the agents static.
@@ -282,7 +311,14 @@ class RailEnv(Environment):
 
         # why do we need static agents? could we it more elegantly?
schedule = self.schedule_generator(self.rail, self.get_num_agents(), agents_hints, self.num_resets) self.agents_static = EnvAgentStatic.from_lists(schedule) - self._max_episode_steps = schedule.max_episode_steps + + if agents_hints and 'city_orientations' in agents_hints: + ratio_nr_agents_to_nr_cities = self.get_num_agents() / len(agents_hints['city_orientations']) + self._max_episode_steps = self.compute_max_episode_steps( + width=self.width, height=self.height, + ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities) + else: + self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height) self.restart_agents() diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index d4f8ed7ed8219dbb248b32d8ee266be0ccdb5e79..85beea071aebd1c39cff22a3895b6f036047718e 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -42,29 +42,6 @@ def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios))) -def compute_max_episode_steps(width: int, - height: int, - ratio_nr_agents_to_nr_cities: float = 20.0, - timedelay_factor: int = 4, - alpha: int = 2) -> int: - """ - - The method computes the max number of episode steps allowed - Parameters - ---------- - width: width of environment - height: height of environment - ratio_nr_agents_to_nr_cities: number_of_agents/number_of_cities (default is 20) - timedelay_factor - alpha - - Returns max number of episode steps - ------- - - """ - return int(timedelay_factor * alpha * (width + height + ratio_nr_agents_to_nr_cities)) - - def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator: def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0) -> Schedule: @@ -82,10 +59,8 @@ def complex_schedule_generator(speed_ratio_map: 
Mapping[float, float] = None, se else: speeds = [1.0] * len(agents_position) - max_episode_steps = compute_max_episode_steps(width=rail.width, height=rail.height) return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None, - max_episode_steps=max_episode_steps) + agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None) return generator @@ -151,11 +126,8 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see else: speeds = [1.0] * len(agents_position) - max_episode_steps = compute_max_episode_steps(width=rail.width, height=rail.height, - ratio_nr_agents_to_nr_cities=num_agents/len(city_orientations)) return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None, - max_episode_steps=max_episode_steps) + agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None) return generator @@ -182,8 +154,6 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = np.random.seed(_runtime_seed) - max_episode_steps = compute_max_episode_steps(width=rail.width, height=rail.height) - valid_positions = [] for r in range(rail.height): for c in range(rail.width): @@ -191,14 +161,12 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = valid_positions.append((r, c)) if len(valid_positions) == 0: return Schedule(agent_positions=[], agent_directions=[], - agent_targets=[], agent_speeds=[], agent_malfunction_rates=None, - max_episode_steps=max_episode_steps) + agent_targets=[], agent_speeds=[], agent_malfunction_rates=None) if len(valid_positions) < num_agents: warnings.warn("schedule_generators: len(valid_positions) < num_agents") return Schedule(agent_positions=[], agent_directions=[], - agent_targets=[], agent_speeds=[], agent_malfunction_rates=None, - 
max_episode_steps=max_episode_steps) + agent_targets=[], agent_speeds=[], agent_malfunction_rates=None) agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)] agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)] @@ -257,8 +225,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = agents_speed = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed) return Schedule(agent_positions=agents_position, agent_directions=agents_direction, - agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None, - max_episode_steps=max_episode_steps) + agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None) return generator @@ -303,10 +270,9 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: else: agents_speed = None agents_malfunction = None - max_episode_steps = compute_max_episode_steps(width=rail.width, height=rail.height) return Schedule(agent_positions=agents_position, agent_directions=agents_direction, agent_targets=agents_target, agent_speeds=agents_speed, - agent_malfunction_rates=agents_malfunction, max_episode_steps=max_episode_steps) + agent_malfunction_rates=agents_malfunction) return generator diff --git a/flatland/envs/schedule_utils.py b/flatland/envs/schedule_utils.py index b8f28a47785f9bdf30b06ce96d3012bb383835cb..e89f170dbb87388bcecbc6b2e176ba277162a4db 100644 --- a/flatland/envs/schedule_utils.py +++ b/flatland/envs/schedule_utils.py @@ -1,8 +1,10 @@ -import collections +from typing import List, NamedTuple -Schedule = collections.namedtuple('Schedule', 'agent_positions ' - 'agent_directions ' - 'agent_targets ' - 'agent_speeds ' - 'agent_malfunction_rates ' - 'max_episode_steps') +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid_utils import IntVector2DArray + +Schedule = NamedTuple('Schedule', 
[('agent_positions', IntVector2DArray), + ('agent_directions', List[Grid4TransitionsEnum]), + ('agent_targets', IntVector2DArray), + ('agent_speeds', List[float]), + ('agent_malfunction_rates', List[int])]) diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index e694275bc353e620f0357101bc5fd11a5c273212..22db31572fd6d0013fa799037ec6f021e90334aa 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -151,21 +151,22 @@ def test_malfunction_process_statistically(): stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=SingleAgentNavigationObs() ) - env.reset() + # reset to initialize agents_static env.reset(True, True, False, random_seed=10) env.agents[0].target = (0, 0) - agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2], - [0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1], + + agent_malfunction_list = [[0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0], [0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 6, 5], - [0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0], - [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4], [0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 6, 5, 4], - [0, 0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0], - [6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0], + [6, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3], + [0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5], + [0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 0, 6, 5, 4, 3, 2, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1], + [6, 6, 6, 6, 5, 4, 3, 2, 1, 0, 0, 0, 0, 6, 5, 4, 3, 2, 1, 0]] for step in range(20): 
action_dict: Dict[int, RailEnvActions] = {}