From f3f49ca65228604774183f6087fcedf47c6ca402 Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Thu, 17 Oct 2019 11:26:08 +0200
Subject: [PATCH] add method to compute the max_episode_steps to
 sparse_rail_generator

---
 flatland/envs/rail_generators.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index 6c9ec67d..39bc1088 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -542,6 +542,15 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11, seed=1) -> R
     return generator
 
 
+def compute_max_episode_steps(width: int,
+                              height: int,
+                              num_agents: int,
+                              num_cities: int = 1,
+                              timedelay_factor: int = 4,
+                              alpha: int = 2) -> int:
+    return int(timedelay_factor * alpha * (width + height + (num_agents/num_cities)))
+
+
 def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_rails_between_cities: int = 4,
                           max_rails_in_city: int = 4, seed: int = 1) -> RailGenerator:
     """
@@ -611,11 +620,14 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
         # Generate start target pairs
         agent_start_targets_cities = _generate_start_target_pairs(num_agents, num_cities, train_stations,
                                                                   city_orientations)
+        max_episode_steps = compute_max_episode_steps(width=width, height=height, num_agents=num_agents, num_cities=num_cities)
+
         return grid_map, {'agents_hints': {
             'num_agents': num_agents,
             'agent_start_targets_cities': agent_start_targets_cities,
             'train_stations': train_stations,
-            'city_orientations': city_orientations
+            'city_orientations': city_orientations,
+            'max_episode_steps': max_episode_steps
         }}
 
     def _generate_random_city_positions(num_cities: int, city_radius: int, width: int,
-- 
GitLab