diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index ef87c96e995015f98022ea9e619276fd28ba49ad..720c9379ef0bb0d6536d786228e0cafc3339ddc1 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -39,7 +39,7 @@ class EnvAgent:
     target = attrib(type=Tuple[int, int])
     moving = attrib(default=False, type=bool)
 
-    # NEW - time scheduling
+    # NEW : Agent properties for scheduling
     earliest_departure = attrib(default=None, type=int)  # default None during _from_schedule()
     latest_arrival = attrib(default=None, type=int)  # default None during _from_schedule()
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index cef55ae732eba625a92a04fbdfc315fadda57aa7..958d228220253c203e7809c002f4c057f3ce7129 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -34,7 +34,7 @@ from gym.utils import seeding
 # from flatland.envs.rail_generators import random_rail_generator, RailGenerator
 # from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
 
-# NEW 
+# NEW : Imports 
 from flatland.envs.schedule_time_generators import schedule_time_generator
 
 # Adrian Egli performance fix (the fast methods brings more than 50%)
@@ -379,22 +379,31 @@ class RailEnv(Environment):
             if optionals and 'agents_hints' in optionals:
                 agents_hints = optionals['agents_hints']
 
-            schedule = self.schedule_generator(self.rail, self.number_of_agents, agents_hints, self.num_resets,
-                                               self.np_random)
+            schedule = self.schedule_generator(self.rail, self.number_of_agents, agents_hints, 
+                                               self.num_resets, self.np_random)
             self.agents = EnvAgent.from_schedule(schedule)
 
             # Get max number of allowed time steps from schedule generator
             # Look at the specific schedule generator used to see where this number comes from
-            self._max_episode_steps = schedule.max_episode_steps
+            self._max_episode_steps = schedule.max_episode_steps # NEW UPDATE THIS!
 
         self.agent_positions = np.zeros((self.height, self.width), dtype=int) - 1
 
-        # Reset agents to initial
-        self.reset_agents()
+        # Reset distance map - basically initializing
         self.distance_map.reset(self.agents, self.rail)
 
-        # NEW - time window scheduling
-        schedule_time_generator(self.agents, self.distance_map, schedule, self.np_random, temp_info=optionals)
+        # NEW : Time Schedule Generation
+        # find agent speeds (needed for max_ep_steps recalculation)
+        if (type(self.schedule_generator.speed_ratio_map) is dict):
+            config_speeds = list(self.schedule_generator.speed_ratio_map.keys())
+        else:
+            config_speeds = [1.0]
+
+        self._max_episode_steps = schedule_time_generator(self.agents, config_speeds, self.distance_map, 
+                                        self._max_episode_steps, self.np_random, temp_info=optionals)
+        
+        # Reset agents to initial states
+        self.reset_agents()
 
         for agent in self.agents:
             # Induce malfunctions
diff --git a/flatland/envs/schedule_time_generators.py b/flatland/envs/schedule_time_generators.py
index 1587ffc378b424b21cfb9e4f3660f631521e5f32..dafa0aca361037a7758ce3e3dff740b52879e8bd 100644
--- a/flatland/envs/schedule_time_generators.py
+++ b/flatland/envs/schedule_time_generators.py
@@ -30,17 +30,44 @@ from flatland.envs.distance_map import DistanceMap
 # city_positions = []
 # #### DATA COLLECTION *************************
 
-def schedule_time_generator(agents: List[EnvAgent], distance_map: DistanceMap, schedule: Schedule,
-                            np_random: RandomState = None, temp_info=None) -> None:
+def schedule_time_generator(agents: List[EnvAgent], config_speeds: List[float],  distance_map: DistanceMap, 
+                            max_episode_steps: int, np_random: RandomState = None, temp_info=None) -> int:
+    
+    # Multipliers
+    old_max_episode_steps_multiplier = 3.0
+    new_max_episode_steps_multiplier = 1.5
+    travel_buffer_multiplier = 1.3 # must be strictly lesser than new_max_episode_steps_multiplier
+    end_buffer_multiplier = 0.05
+    mean_shortest_path_multiplier = 0.2
     
     from flatland.envs.rail_env_shortest_paths import get_shortest_paths
     shortest_paths = get_shortest_paths(distance_map)
+    shortest_paths_lengths = [len(v) for k,v in shortest_paths.items()]
+
+    # Find mean_shortest_path_time
+    agent_shortest_path_times = []
+    for agent in agents:
+        speed = agent.speed_data['speed']
+        distance = shortest_paths_lengths[agent.handle]
+        agent_shortest_path_times.append(int(np.ceil(distance / speed)))
+
+    mean_shortest_path_time = np.mean(agent_shortest_path_times)
+
+    # Deciding on a suitable max_episode_steps
+    max_sp_len = max(shortest_paths_lengths) # longest path
+    min_speed = min(config_speeds)           # slowest possible speed in config
     
-    max_episode_steps = int(schedule.max_episode_steps * 1.0) #needs to be increased due to fractional speeds taking way longer (best - be calculated here)
-    end_buffer = max_episode_steps // 20                 #schedule.end_buffer
+    longest_sp_time = max_sp_len / min_speed
+    max_episode_steps_new = int(np.ceil(longest_sp_time * new_max_episode_steps_multiplier))
+    
+    max_episode_steps_old = int(max_episode_steps * old_max_episode_steps_multiplier)
+
+    max_episode_steps = min(max_episode_steps_new, max_episode_steps_old)
+    
+    end_buffer = max_episode_steps * end_buffer_multiplier
     latest_arrival_max = max_episode_steps-end_buffer
-    travel_buffer_multiplier = 1.7
 
+    # Useless unless needed by returning
     earliest_departures = []
     latest_arrivals = []
 
@@ -89,11 +116,10 @@ def schedule_time_generator(agents: List[EnvAgent], distance_map: DistanceMap, s
     # #### DATA COLLECTION *************************
 
     for agent in agents:
-        agent_speed = agent.speed_data['speed']
-        agent_shortest_path = shortest_paths[agent.handle]
-        agent_shortest_path_len = len(agent_shortest_path)
-        agent_shortest_path_time = int(np.ceil(agent_shortest_path_len / agent_speed)) # for fractional speeds 1/3 etc
-        agent_travel_time_max = min( int(np.ceil(agent_shortest_path_time * travel_buffer_multiplier)), latest_arrival_max) # min(this, latest_arrival_max), SHOULD NOT BE lesser than shortest path time
+        agent_shortest_path_time = agent_shortest_path_times[agent.handle]
+        agent_travel_time_max = int(np.ceil((agent_shortest_path_time * travel_buffer_multiplier) \
+                                            + (mean_shortest_path_time * mean_shortest_path_multiplier)))
+        
         departure_window_max = latest_arrival_max - agent_travel_time_max
 
         earliest_departure = np_random.randint(0, departure_window_max)
@@ -124,6 +150,9 @@ def schedule_time_generator(agents: List[EnvAgent], distance_map: DistanceMap, s
     # save_sp_fig()
     # #### DATA COLLECTION *************************
 
+    # returns max_episode_steps after deciding on the new value 
+    return max_episode_steps
+
 
 # #### DATA COLLECTION *************************
 # # Histogram 1