From 5f8a5de1e35d74b1fd1501d6b2f70c3f83355019 Mon Sep 17 00:00:00 2001
From: Dipam Chakraborty <dipam@aicrowd.com>
Date: Wed, 4 Aug 2021 21:01:34 +0530
Subject: [PATCH] update calculation of speed normalized path lengths

---
 flatland/envs/rail_env_utils.py      |  2 +-
 flatland/envs/schedule_generators.py | 20 +++++++-------------
 2 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py
index 94f1c2d4..22f73f98 100644
--- a/flatland/envs/rail_env_utils.py
+++ b/flatland/envs/rail_env_utils.py
@@ -33,7 +33,7 @@ def load_flatland_environment_from_file(file_name: str,
             max_depth=2,
             predictor=ShortestPathPredictorForRailEnv(max_depth=10))
     environment = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name, load_from_package),
-                          schedule_generator=line_from_file(file_name, load_from_package),
+                          line_generator=line_from_file(file_name, load_from_package),
                           number_of_agents=1,
                           obs_builder_object=obs_builder_object,
                           record_steps=record_steps,
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index ac25118d..a8bd42b9 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -32,6 +32,7 @@ def schedule_generator(agents: List[EnvAgent], config_speeds: List[float],  dist
     old_max_episode_steps_multiplier = 3.0
     new_max_episode_steps_multiplier = 1.5
     travel_buffer_multiplier = 1.3 # must be strictly lesser than new_max_episode_steps_multiplier
+    assert new_max_episode_steps_multiplier > travel_buffer_multiplier
     end_buffer_multiplier = 0.05
     mean_shortest_path_multiplier = 0.2
     
@@ -39,20 +40,14 @@ def schedule_generator(agents: List[EnvAgent], config_speeds: List[float],  dist
     shortest_paths_lengths = [len(v) for k,v in shortest_paths.items()]
 
     # Find mean_shortest_path_time
-    agent_shortest_path_times = []
-    for agent in agents:
-        speed = agent.speed_data['speed']
-        distance = shortest_paths_lengths[agent.handle]
-        agent_shortest_path_times.append(int(np.ceil(distance / speed)))
-
+    agent_speeds = [agent.speed_data['speed'] for agent in agents]
+    agent_shortest_path_times = np.array(shortest_paths_lengths)/ np.array(agent_speeds)
     mean_shortest_path_time = np.mean(agent_shortest_path_times)
 
     # Deciding on a suitable max_episode_steps
-    max_sp_len = max(shortest_paths_lengths) # longest path
-    min_speed = min(config_speeds)           # slowest possible speed in config
-    
-    longest_sp_time = max_sp_len / min_speed
-    max_episode_steps_new = int(np.ceil(longest_sp_time * new_max_episode_steps_multiplier))
+    longest_speed_normalized_time = np.max(agent_shortest_path_times)
+    mean_path_delay = mean_shortest_path_time * mean_shortest_path_multiplier
+    max_episode_steps_new = int(np.ceil(longest_speed_normalized_time * new_max_episode_steps_multiplier) + mean_path_delay)
     
     max_episode_steps_old = int(max_episode_steps * old_max_episode_steps_multiplier)
 
@@ -67,8 +62,7 @@ def schedule_generator(agents: List[EnvAgent], config_speeds: List[float],  dist
 
     for agent in agents:
         agent_shortest_path_time = agent_shortest_path_times[agent.handle]
-        agent_travel_time_max = int(np.ceil((agent_shortest_path_time * travel_buffer_multiplier) \
-                                            + (mean_shortest_path_time * mean_shortest_path_multiplier)))
+        agent_travel_time_max = int(np.ceil((agent_shortest_path_time * travel_buffer_multiplier) + mean_path_delay))
         
         departure_window_max = max(latest_arrival_max - agent_travel_time_max, 1)
         
-- 
GitLab