From 2e5b015c1c76ad083384515b6e65a2596265ae78 Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Fri, 18 Oct 2019 11:27:51 +0200
Subject: [PATCH] add named tuple for result of schedule generator and add
 max_episode_steps to the schedule generator

---
 changelog.md                         |  3 ++
 docs/specifications/railway.md       | 11 +++--
 docs/tutorials/05_multispeed.md      | 49 +++++++++++---------
 examples/custom_railmap_example.py   | 10 +++--
 flatland/envs/agent_utils.py         | 19 ++++----
 flatland/envs/rail_env.py            | 25 +++--------
 flatland/envs/rail_generators.py     | 14 +-----
 flatland/envs/schedule_generators.py | 67 ++++++++++++++++++++++------
 flatland/envs/schedule_utils.py      |  8 ++++
 flatland/evaluators/client.py        | 26 +++--------
 flatland/evaluators/service.py       | 54 +++++++++-------------
 11 files changed, 150 insertions(+), 136 deletions(-)
 create mode 100644 flatland/envs/schedule_utils.py

diff --git a/changelog.md b/changelog.md
index 3d256c1d..109f748d 100644
--- a/changelog.md
+++ b/changelog.md
@@ -11,6 +11,9 @@ Changes since Flatland 2.0.0
 - renaming of `distance_maps` into `distance_map`
 - by default the reset method of RailEnv is not called in the constructor of RailEnv anymore. Therefore the reset method needs to be called after the creation of a RailEnv object
 
+### Changes in schedule generation
+- the return value of the schedule generator has changed to the named tuple `Schedule`
+
 Changes since Flatland 1.0.0
 --------------------------
 ### Changes in stock predictors
diff --git a/docs/specifications/railway.md b/docs/specifications/railway.md
index 69abd15a..0299badb 100644
--- a/docs/specifications/railway.md
+++ b/docs/specifications/railway.md
@@ -373,8 +373,13 @@ RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]]
 RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
 
 AgentPosition = Tuple[int, int]
-ScheduleGeneratorProduct = Tuple[List[AgentPosition], List[AgentPosition], List[AgentPosition], List[float]]
-ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any]], ScheduleGeneratorProduct]
+Schedule = collections.namedtuple('Schedule',   'agent_positions '
+                                                'agent_directions '
+                                                'agent_targets '
+                                                'agent_speeds '
+                                                'agent_malfunction_rates '
+                                                'max_episode_steps')
+ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any], Optional[int]], Schedule]
 ```
 
 We can then produce `RailGenerator`s by currying:
@@ -435,7 +440,7 @@ The environment's `reset` takes care of applying the two generators:
             if optionals and 'agents_hints' in optionals:
                 agents_hints = optionals['agents_hints']
             self.agents_static = EnvAgentStatic.from_lists(
-                *self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints))
+                self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints))
 ```
 
 
diff --git a/docs/tutorials/05_multispeed.md b/docs/tutorials/05_multispeed.md
index 118d4c59..cc45c65e 100644
--- a/docs/tutorials/05_multispeed.md
+++ b/docs/tutorials/05_multispeed.md
@@ -1,19 +1,19 @@
 # Different speed profiles Tutorial
 
-One of the main contributions to the complexity of railway network operations stems from the fact that all trains travel at different speeds while sharing a very limited railway network. 
+One of the main contributions to the complexity of railway network operations stems from the fact that all trains travel at different speeds while sharing a very limited railway network.
 In **Flat**land 2.0 this feature will be enabled as well and will lead to much more complex configurations. Here we count on your support if you find bugs or improvements  :).
 
-The different speed profiles can be generated using the `schedule_generator`, where you can actually chose as many different speeds as you like. 
-Keep in mind that the *fastest speed* is 1 and all slower speeds must be between 1 and 0. 
+The different speed profiles can be generated using the `schedule_generator`, where you can actually choose as many different speeds as you like.
+Keep in mind that the *fastest speed* is 1 and all slower speeds must be between 1 and 0.
 For the submission scoring you can assume that there will be no more than 5 speed profiles.
 
 
- 
-Later versions of **Flat**land might have varying speeds during episodes. Therefore, we return the agent speeds. 
+
+Later versions of **Flat**land might have varying speeds during episodes. Therefore, we return the agent speeds.
 Notice that we do not guarantee that the speed will be computed at each step, but if not costly we will return it at each step.
-In your controller, you can get the agents' speed from the `info` returned by `step`: 
+In your controller, you can get the agents' speed from the `info` returned by `step`:
 ```python
-obs, rew, done, info = env.step(actions) 
+obs, rew, done, info = env.step(actions)
 ...
 for a in range(env.get_num_agents()):
     speed = info['speed'][a]
@@ -21,9 +21,9 @@ for a in range(env.get_num_agents()):
 
 ## Actions and observation with different speed levels
 
-Because the different speeds are implemented as fractions the agents ability to perform actions has been updated. 
-We **do not allow actions to change within the cell **. 
-This means that each agent can only chose an action to be taken when entering a cell. 
+Because the different speeds are implemented as fractions, the agents' ability to perform actions has been updated.
+We **do not allow actions to change within the cell **.
+This means that each agent can only choose an action to be taken when entering a cell.
 This action is then executed when a step to the next cell is valid. For example
 
 - Agent enters switch and choses to deviate left. Agent fractional speed is 1/4 and thus the agent will take 4 time steps to complete its journey through the cell. On the 4th time step the agent will leave the cell deviating left as chosen at the entry of the cell.
@@ -31,9 +31,9 @@ This action is then executed when a step to the next cell is valid. For example
     - Agents can make observations at any time step. Make sure to discard observations without any information. See this [example](https://gitlab.aicrowd.com/flatland/baselines/blob/master/torch_training/training_navigation.py) for a simple implementation.
 - The environment checks if agent is allowed to move to next cell only at the time of the switch to the next cell
 
-In your controller, you can check whether an agent requires an action by checking `info`: 
+In your controller, you can check whether an agent requires an action by checking `info`:
 ```python
-obs, rew, done, info = env.step(actions) 
+obs, rew, done, info = env.step(actions)
 ...
 action_dict = dict()
 for a in range(env.get_num_agents()):
@@ -41,8 +41,8 @@ for a in range(env.get_num_agents()):
         action_dict.update({a: ...})
 
 ```
-Notice that `info['action_required'][a]` does not mean that the action will have an effect: 
-if the next cell is blocked or the agent breaks down, the action cannot be performed and an action will be required again in the next step. 
+Notice that `info['action_required'][a]` does not mean that the action will have an effect:
+if the next cell is blocked or the agent breaks down, the action cannot be performed and an action will be required again in the next step.
 
 ## Rail Generators and Schedule Generators
 The separation between rail generator and schedule generator reflects the organisational separation in the railway domain
@@ -51,14 +51,19 @@ The separation between rail generator and schedule generator reflects the organi
 Usually, there is a third organisation, which ensures discrimination-free access to the infrastructure for concurrent requests for the infrastructure in a **schedule planning phase**.
 However, in the **Flat**land challenge, we focus on the re-scheduling problem during live operations.
 
-Technically, 
+Technically,
 ```python
 RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Any]]
 RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
 
 AgentPosition = Tuple[int, int]
-ScheduleGeneratorProduct = Tuple[List[AgentPosition], List[AgentPosition], List[AgentPosition], List[float]]
-ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any]], ScheduleGeneratorProduct]
+Schedule = collections.namedtuple('Schedule',   'agent_positions '
+                                                'agent_directions '
+                                                'agent_targets '
+                                                'agent_speeds '
+                                                'agent_malfunction_rates '
+                                                'max_episode_steps')
+ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any], Optional[int]], Schedule]
 ```
 
 We can then produce `RailGenerator`s by currying:
@@ -67,10 +72,10 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
                           num_neighb=3, grid_mode=False, enhance_intersection=False, seed=1):
 
     def generator(width, height, num_agents, num_resets=0):
-    
+
         # generate the grid and (optionally) some hints for the schedule_generator
         ...
-         
+
         return grid_map, {'agents_hints': {
             'num_agents': num_agents,
             'agent_start_targets_nodes': agent_start_targets_nodes,
@@ -89,7 +94,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) ->
         # - (initial) speed
         # - malfunction
         ...
-                
+
         return agents_position, agents_direction, agents_target, speeds, agents_malfunction
 
     return generator
@@ -108,7 +113,7 @@ The environment's `reset` takes care of applying the two generators:
              ):
         self.rail_generator: RailGenerator = rail_generator
         self.schedule_generator: ScheduleGenerator = schedule_generator
-        
+
     def reset(self, regen_rail=True, replace_agents=True):
         rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
@@ -119,7 +124,7 @@ The environment's `reset` takes care of applying the two generators:
             if optionals and 'agents_hints' in optionals:
                 agents_hints = optionals['agents_hints']
             self.agents_static = EnvAgentStatic.from_lists(
-                *self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints))
+                self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints))
 ```
 
 
diff --git a/examples/custom_railmap_example.py b/examples/custom_railmap_example.py
index d33d4925..bd662aa5 100644
--- a/examples/custom_railmap_example.py
+++ b/examples/custom_railmap_example.py
@@ -7,7 +7,8 @@ from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct
-from flatland.envs.schedule_generators import ScheduleGenerator, ScheduleGeneratorProduct
+from flatland.envs.schedule_generators import ScheduleGenerator, compute_max_episode_steps
+from flatland.envs.schedule_utils import Schedule
 from flatland.utils.rendertools import RenderTool
 
 random.seed(100)
@@ -30,12 +31,15 @@ def custom_rail_generator() -> RailGenerator:
 
 
 def custom_schedule_generator() -> ScheduleGenerator:
-    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> Schedule:
         agents_positions = []
         agents_direction = []
         agents_target = []
         speeds = []
-        return agents_positions, agents_direction, agents_target, speeds
+        max_episode_steps = compute_max_episode_steps(width=rail.width, height=rail.height)
+        return Schedule(agent_positions=agents_positions, agent_directions=agents_direction,
+                        agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None,
+                        max_episode_steps=max_episode_steps)
 
     return generator
 
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index d8c05c20..6a0e595b 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -6,6 +6,7 @@ import numpy as np
 from attr import attrs, attrib, Factory
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
+from flatland.envs.schedule_utils import Schedule
 
 
 class RailAgentStatus(IntEnum):
@@ -45,30 +46,30 @@ class EnvAgentStatic(object):
     position = attrib(default=None, type=Optional[Tuple[int, int]])
 
     @classmethod
-    def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None):
+    def from_lists(cls, schedule: Schedule):
         """ Create a list of EnvAgentStatics from lists of positions, directions and targets
         """
         speed_datas = []
 
-        for i in range(len(positions)):
+        for i in range(len(schedule.agent_positions)):
             speed_datas.append({'position_fraction': 0.0,
-                                'speed': speeds[i] if speeds is not None else 1.0,
+                                'speed': schedule.agent_speeds[i] if schedule.agent_speeds is not None else 1.0,
                                 'transition_action_on_cellexit': 0})
 
         # TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set
         # some as broken?
 
         malfunction_datas = []
-        for i in range(len(positions)):
+        for i in range(len(schedule.agent_positions)):
             malfunction_datas.append({'malfunction': 0,
-                                      'malfunction_rate': malfunction_rates[i] if malfunction_rates is not None else 0.,
+                                      'malfunction_rate': schedule.agent_malfunction_rates[i] if schedule.agent_malfunction_rates is not None else 0.,
                                       'next_malfunction': 0,
                                       'nr_malfunctions': 0})
 
-        return list(starmap(EnvAgentStatic, zip(positions,
-                                                directions,
-                                                targets,
-                                                [False] * len(positions),
+        return list(starmap(EnvAgentStatic, zip(schedule.agent_positions,
+                                                schedule.agent_directions,
+                                                schedule.agent_targets,
+                                                [False] * len(schedule.agent_positions),
                                                 speed_datas,
                                                 malfunction_datas)))
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index eb8a4e8c..10908ddd 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -115,7 +115,6 @@ class RailEnv(Environment):
                  schedule_generator: ScheduleGenerator = random_schedule_generator(),
                  number_of_agents=1,
                  obs_builder_object: ObservationBuilder = GlobalObsForRailEnv(),
-                 max_episode_steps=None,
                  stochastic_data=None,
                  remove_agents_at_target=False,
                  random_seed=1
@@ -148,7 +147,6 @@ class RailEnv(Environment):
         obs_builder_object: ObservationBuilder object
             ObservationBuilder-derived object that takes builds observation
             vectors for each agent.
-        max_episode_steps : int or None
         remove_agents_at_target : bool
             If remove_agents_at_target is set to true then the agents will be removed by placing to
             RailEnv.DEPOT_POSITION when the agent has reach it's target position.
@@ -171,7 +169,7 @@ class RailEnv(Environment):
         self.obs_builder = obs_builder_object
         self.obs_builder.set_env(self)
 
-        self._max_episode_steps = max_episode_steps
+        self._max_episode_steps = None
         self._elapsed_steps = 0
 
         self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False)
@@ -266,20 +264,8 @@ class RailEnv(Environment):
 
             self.rail = rail
             self.height, self.width = self.rail.grid.shape
-            # NOTE : Ignore Validation on every reset. rail_generator should ensure that
-            #        only valid grids are generated.
-            #
-            # for r in range(self.height):
-            #     for c in range(self.width):
-            #         rc_pos = (r, c)
-            #         check = self.rail.cell_neighbours_valid(rc_pos, True)
-            #         if not check:
-            #             print(self.rail.grid[rc_pos])
-            #             warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check))
-        # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/172
-        #  hacky: we must re-compute the distance map and not use the initial distance_map loaded from file by
-        #  rail_from_file!!!
-        elif optionals and 'distance_map' in optionals:
+
+        if optionals and 'distance_map' in optionals:
             self.distance_map.set(optionals['distance_map'])
 
         if replace_agents:
@@ -289,8 +275,9 @@ class RailEnv(Environment):
 
             # TODO https://gitlab.aicrowd.com/flatland/flatland/issues/185
             #  why do we need static agents? could we it more elegantly?
-            self.agents_static = EnvAgentStatic.from_lists(
-                *self.schedule_generator(self.rail, self.get_num_agents(), agents_hints, self.num_resets))
+            schedule = self.schedule_generator(self.rail, self.get_num_agents(), agents_hints, self.num_resets)
+            self.agents_static = EnvAgentStatic.from_lists(schedule)
+            self._max_episode_steps = schedule.max_episode_steps
 
         self.restart_agents()
 
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 39bc1088..6c9ec67d 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -542,15 +542,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R
     return generator
 
 
-def compute_max_episode_steps(width: int,
-                              height: int,
-                              num_agents: int,
-                              num_cities: int = 1,
-                              timedelay_factor: int = 4,
-                              alpha: int = 2) -> int:
-    return int(timedelay_factor * alpha * (width + height + (num_agents/num_cities)))
-
-
 def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_rails_between_cities: int = 4,
                           max_rails_in_city: int = 4, seed: int = 1) -> RailGenerator:
     """
@@ -620,14 +611,11 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         # Generate start target pairs
         agent_start_targets_cities = _generate_start_target_pairs(num_agents, num_cities, train_stations,
                                                                   city_orientations)
-        max_episode_steps = compute_max_episode_steps(width=width, height=height, num_agents=num_agents, num_cities=num_cities)
-
         return grid_map, {'agents_hints': {
             'num_agents': num_agents,
             'agent_start_targets_cities': agent_start_targets_cities,
             'train_stations': train_stations,
-            'city_orientations': city_orientations,
-            'max_episode_steps': max_episode_steps
+            'city_orientations': city_orientations
         }}
 
     def _generate_random_city_positions(num_cities: int, city_radius: int, width: int,
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index ab1e16cf..d4f8ed7e 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -8,10 +8,10 @@ import numpy as np
 from flatland.core.grid.grid4_utils import get_new_position
 from flatland.core.transition_map import GridTransitionMap
 from flatland.envs.agent_utils import EnvAgentStatic
+from flatland.envs.schedule_utils import Schedule
 
 AgentPosition = Tuple[int, int]
-ScheduleGeneratorProduct = Tuple[List[AgentPosition], List[AgentPosition], List[AgentPosition], List[float]]
-ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any]], ScheduleGeneratorProduct]
+ScheduleGenerator = Callable[[GridTransitionMap, int, Optional[Any], Optional[int]], Schedule]
 
 
 def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float, float] = None,
@@ -42,8 +42,31 @@ def speed_initialization_helper(nb_agents: int, speed_ratio_map: Mapping[float,
     return list(map(lambda index: speeds[index], np.random.choice(nb_classes, nb_agents, p=speed_ratios)))
 
 
+def compute_max_episode_steps(width: int,
+                              height: int,
+                              ratio_nr_agents_to_nr_cities: float = 20.0,
+                              timedelay_factor: int = 4,
+                              alpha: int = 2) -> int:
+    """
+
+    The method computes the max number of episode steps allowed
+    Parameters
+    ----------
+    width: width of environment
+    height: height of environment
+    ratio_nr_agents_to_nr_cities: number_of_agents/number_of_cities (default is 20)
+    timedelay_factor
+    alpha
+
+    Returns
+    -------
+    max number of episode steps
+    """
+    return int(timedelay_factor * alpha * (width + height + ratio_nr_agents_to_nr_cities))
+
+
 def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator:
-    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0):
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0) -> Schedule:
 
         _runtime_seed = seed + num_resets
         np.random.seed(_runtime_seed)
@@ -59,14 +82,17 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, se
         else:
             speeds = [1.0] * len(agents_position)
 
-        return agents_position, agents_direction, agents_target, speeds
+        max_episode_steps = compute_max_episode_steps(width=rail.width, height=rail.height)
+        return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
+                        agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None,
+                        max_episode_steps=max_episode_steps)
 
     return generator
 
 
 def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, seed: int = 1) -> ScheduleGenerator:
 
-    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0):
+    def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None, num_resets: int = 0) -> Schedule:
 
         _runtime_seed = seed + num_resets
         np.random.seed(_runtime_seed)
@@ -74,7 +100,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see
         train_stations = hints['train_stations']
         agent_start_targets_cities = hints['agent_start_targets_cities']
         max_num_agents = hints['num_agents']
-        # city_orientations = hints['city_orientations']
+        city_orientations = hints['city_orientations']
         if num_agents > max_num_agents:
             num_agents = max_num_agents
             warnings.warn("Too many agents! Changes number of agents.")
@@ -125,7 +151,11 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see
         else:
             speeds = [1.0] * len(agents_position)
 
-        return agents_position, agents_direction, agents_target, speeds, None
+        max_episode_steps = compute_max_episode_steps(width=rail.width, height=rail.height,
+                                                      ratio_nr_agents_to_nr_cities=num_agents/len(city_orientations))
+        return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
+                        agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None,
+                        max_episode_steps=max_episode_steps)
 
     return generator
 
@@ -147,22 +177,28 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] =
     """
 
     def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None,
-                num_resets: int = 0) -> ScheduleGeneratorProduct:
+                num_resets: int = 0) -> Schedule:
         _runtime_seed = seed + num_resets
 
         np.random.seed(_runtime_seed)
 
+        max_episode_steps = compute_max_episode_steps(width=rail.width, height=rail.height)
+
         valid_positions = []
         for r in range(rail.height):
             for c in range(rail.width):
                 if rail.get_full_transitions(r, c) > 0:
                     valid_positions.append((r, c))
         if len(valid_positions) == 0:
-            return [], [], [], []
+            return Schedule(agent_positions=[], agent_directions=[],
+                            agent_targets=[], agent_speeds=[], agent_malfunction_rates=None,
+                            max_episode_steps=max_episode_steps)
 
         if len(valid_positions) < num_agents:
             warnings.warn("schedule_generators: len(valid_positions) < num_agents")
-            return [], [], [], []
+            return Schedule(agent_positions=[], agent_directions=[],
+                            agent_targets=[], agent_speeds=[], agent_malfunction_rates=None,
+                            max_episode_steps=max_episode_steps)
 
         agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)]
         agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)]
@@ -220,7 +256,9 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] =
                         np.random.choice(len(valid_starting_directions), 1)[0]]
 
         agents_speed = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed)
-        return agents_position, agents_direction, agents_target, agents_speed, None
+        return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
+                        agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None,
+                        max_episode_steps=max_episode_steps)
 
     return generator
 
@@ -240,7 +278,7 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
     """
 
     def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None,
-                  num_resets: int = 0) -> ScheduleGeneratorProduct:
+                  num_resets: int = 0) -> Schedule:
         if load_from_package is not None:
             from importlib_resources import read_binary
             load_data = read_binary(load_from_package, filename)
@@ -265,7 +303,10 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
         else:
             agents_speed = None
             agents_malfunction = None
-        return agents_position, agents_direction, agents_target, agents_speed, agents_malfunction
+        max_episode_steps = compute_max_episode_steps(width=rail.width, height=rail.height)
+        return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
+                        agent_targets=agents_target, agent_speeds=agents_speed,
+                        agent_malfunction_rates=agents_malfunction, max_episode_steps=max_episode_steps)
 
     return generator
 
diff --git a/flatland/envs/schedule_utils.py b/flatland/envs/schedule_utils.py
new file mode 100644
index 00000000..b8f28a47
--- /dev/null
+++ b/flatland/envs/schedule_utils.py
@@ -0,0 +1,8 @@
+import collections
+
+Schedule = collections.namedtuple('Schedule',   'agent_positions '
+                                                'agent_directions '
+                                                'agent_targets '
+                                                'agent_speeds '
+                                                'agent_malfunction_rates '
+                                                'max_episode_steps')
diff --git a/flatland/evaluators/client.py b/flatland/evaluators/client.py
index 6ba7bffd..e6ea1c50 100644
--- a/flatland/evaluators/client.py
+++ b/flatland/evaluators/client.py
@@ -35,12 +35,12 @@ class FlatlandRemoteClient(object):
     """
         Redis client to interface with flatland-rl remote-evaluation-service
         The Docker container hosts a redis-server inside the container.
-        This client connects to the same redis-server, 
+        This client connects to the same redis-server,
         and communicates with the service.
-        The service eventually will reside outside the docker container, 
+        The service eventually will reside outside the docker container,
         and will communicate
         with the client only via the redis-server of the docker container.
-        On the instantiation of the docker container, one service will be 
+        On the instantiation of the docker container, one service will be
         instantiated parallely.
         The service will accepts commands at "`service_id`::commands"
         where `service_id` is either provided as an `env` variable or is
@@ -160,7 +160,7 @@ class FlatlandRemoteClient(object):
 
     def env_create(self, obs_builder_object):
         """
-            Create a local env and remote env on which the 
+            Create a local env and remote env on which the
             local agent can operate.
             The observation builder is only used in the local env
             and the remote env uses a DummyObservationBuilder
@@ -201,20 +201,6 @@ class FlatlandRemoteClient(object):
             obs_builder_object=obs_builder_object
         )
 
-        # Set max episode steps allowed
-        #
-        # the maximum number of episode steps is determined by : 
-        # 
-        # timedelay_factor * alpha * (grid_width + grid_height + (number_of_agents/number_of_cities))  # noqa
-        # 
-        # in the current sprase rail generator, the ratio of 
-        # `number_of_agents/number_of_cities` is roughly 20
-        #
-        # TODO: the serialized env should include the max allowed timesteps per 
-        # env, and should ideally be returned by the rail generator
-        self.env._max_episode_steps = \
-            int(4 * 2 * (self.env.width + self.env.height + 20))
-
         local_observation, info = self.env.reset(
                                 regen_rail=False,
                                 replace_agents=False,
@@ -222,7 +208,7 @@ class FlatlandRemoteClient(object):
                                 random_seed=random_seed
                             )
 
-        # Use the local observation 
+        # Use the local observation
         # as the remote server uses a dummy observation builder
         return local_observation, info
 
@@ -256,7 +242,7 @@ class FlatlandRemoteClient(object):
         # Return local_observation instead of remote_observation
         # as the remote_observation is build using a dummy observation
         # builder
-        # We return the remote rewards and done as they are the 
+        # We return the remote rewards and done as they are the
         # once used by the evaluator
         return [local_observation, remote_reward, remote_done, remote_info]
 
diff --git a/flatland/evaluators/service.py b/flatland/evaluators/service.py
index ef69fc53..5c978d29 100644
--- a/flatland/evaluators/service.py
+++ b/flatland/evaluators/service.py
@@ -55,15 +55,15 @@ class FlatlandRemoteEvaluationService:
     - env_step
     and an additional `env_submit` to cater to score computation and on-episode-complete post processings.
 
-    This service is designed to be used in conjunction with 
-    `FlatlandRemoteClient` and both the srevice and client maintain a 
+    This service is designed to be used in conjunction with
+        `FlatlandRemoteClient` and both the service and client maintain a
     local instance of the RailEnv instance, and in case of any unexpected
-    divergences in the state of both the instances, the local RailEnv 
-    instance of the `FlatlandRemoteEvaluationService` is supposed to act 
+    divergences in the state of both the instances, the local RailEnv
+    instance of the `FlatlandRemoteEvaluationService` is supposed to act
     as the single source of truth.
 
-    Both the client and remote service communicate with each other 
-    via Redis as a message broker. The individual messages are packed and 
+    Both the client and remote service communicate with each other
+    via Redis as a message broker. The individual messages are packed and
     unpacked with `msgpack` (a patched version of msgpack which also supports
     numpy arrays).
     """
@@ -172,8 +172,8 @@ class FlatlandRemoteEvaluationService:
                 "*/*.pkl"
             )
         ))
-        # Remove the root folder name from the individual 
-        # lists, so that we only have the path relative 
+        # Remove the root folder name from the individual
+        # lists, so that we only have the path relative
         # to the test root folder
         env_paths = sorted([os.path.relpath(
             x, self.test_env_folder
@@ -183,7 +183,7 @@ class FlatlandRemoteEvaluationService:
 
     def instantiate_redis_connection_pool(self):
         """
-        Instantiates a Redis connection pool which can be used to 
+        Instantiates a Redis connection pool which can be used to
         communicate with the message broker
         """
         if self.verbose or self.report:
@@ -220,7 +220,7 @@ class FlatlandRemoteEvaluationService:
 
     def _error_template(self, payload):
         """
-        Simple helper function to pass a payload as a part of a 
+        Simple helper function to pass a payload as a part of a
         flatland comms error template.
         """
         _response = {}
@@ -233,9 +233,9 @@ class FlatlandRemoteEvaluationService:
         use_signals=use_signals_in_timeout)  # timeout for each command
     def _get_next_command(self, _redis):
         """
-        A low level wrapper for obtaining the next command from a 
+        A low level wrapper for obtaining the next command from a
         pre-agreed command channel.
-        At the momment, the communication protocol uses lpush for pushing 
+        At the moment, the communication protocol uses lpush for pushing
         in commands, and brpop for reading out commands.
         """
         command = _redis.brpop(self.command_channel)[1]
@@ -243,9 +243,9 @@ class FlatlandRemoteEvaluationService:
 
     def get_next_command(self):
         """
-        A helper function to obtain the next command, which transparently 
-        also deals with things like unpacking of the command from the 
-        packed message, and consider the timeouts, etc when trying to 
+        A helper function to obtain the next command, which transparently
+        also deals with things like unpacking of the command from the
+        packed message, and consider the timeouts, etc when trying to
         fetch a new command.
         """
         try:
@@ -306,7 +306,7 @@ class FlatlandRemoteEvaluationService:
                 "[ Server Version : {} ] ".format(service_version)
             self.send_response(_command_response, command)
             raise Exception(_command_response['payload']['message'])
-        
+
         self.send_response(_command_response, command)
 
     def handle_env_create(self, command):
@@ -339,22 +339,8 @@ class FlatlandRemoteEvaluationService:
                     del self.env_renderer
                 self.env_renderer = RenderTool(self.env, gl="PILSVG", )
 
-            # Set max episode steps allowed
-            #
-            # the maximum number of episode steps is determined by : 
-            # 
-            # timedelay_factor * alpha * (grid_width + grid_height + (number_of_agents/number_of_cities))  # noqa
-            # 
-            # in the current sprase rail generator, the ratio of 
-            # `number_of_agents/number_of_cities` is roughly 20
-            #
-            # TODO: the serialized env should include the max allowed timesteps per 
-            # env, and should ideally be returned by the rail generator
-            self.env._max_episode_steps = \
-                int(4 * 2 * (self.env.width + self.env.height + 20))
-
             if self.begin_simulation:
-                # If begin simulation has already been initialized 
+                # If begin simulation has already been initialized
                 # atleast once
                 self.simulation_times.append(time.time() - self.begin_simulation)
             self.begin_simulation = time.time()
@@ -365,7 +351,7 @@ class FlatlandRemoteEvaluationService:
             self.simulation_steps.append(0)
 
             self.current_step = 0
-            
+
             _observation, _info = self.env.reset(
                                 regen_rail=False,
                                 replace_agents=False,
@@ -507,8 +493,8 @@ class FlatlandRemoteEvaluationService:
         if self.visualize and len(os.listdir(self.vizualization_folder_name)) > 0:
             # Generate the video
             #
-            # Note, if you had depdency issues due to ffmpeg, you can 
-            # install it by : 
+            # Note, if you had dependency issues due to ffmpeg, you can
+            # install it by :
             #
             # conda install -c conda-forge x264 ffmpeg
 
-- 
GitLab