diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index 7a857fe5e8d570ce34a32ba53bcaadf603546081..ac25118d8a3647924c3fb5fc3880ec83a4f9c03c 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -13,33 +13,20 @@ from flatland.envs.distance_map import DistanceMap
 from flatland.envs.rail_env_shortest_paths import get_shortest_paths
 
 
-# #### DATA COLLECTION *************************
-# import termplotlib as tpl
-# import matplotlib.pyplot as plt
-# root_path = 'C:\\Users\\nimish\\Programs\\AIcrowd\\flatland\\flatland\\playground'
-# dir_name = 'TEMP'
-# os.mkdir(os.path.join(root_path, dir_name))
-
-# # Histogram 1
-# dist_resolution = 50
-# schedule_dist = np.zeros(shape=(dist_resolution))
-# # Volume dist
-# route_dist = None
-# # Dist - shortest path
-# shortest_paths_len_dist = []
-# # City positions
-# city_positions = []
-# #### DATA COLLECTION *************************
-
 def schedule_generator(agents: List[EnvAgent], config_speeds: List[float],  distance_map: DistanceMap, 
                             agents_hints: dict, np_random: RandomState = None) -> Schedule:
 
     # max_episode_steps calculation
-    city_positions = agents_hints['city_positions']
+    if agents_hints:
+        city_positions = agents_hints['city_positions']
+        num_cities = len(city_positions)
+    else:
+        num_cities = 2
+
     timedelay_factor = 4
     alpha = 2
     max_episode_steps = int(timedelay_factor * alpha * \
-        (distance_map.rail.width + distance_map.rail.height + (len(agents) / len(city_positions))))
+        (distance_map.rail.width + distance_map.rail.height + (len(agents) / num_cities)))
     
     # Multipliers
     old_max_episode_steps_multiplier = 3.0
@@ -71,64 +58,20 @@ def schedule_generator(agents: List[EnvAgent], config_speeds: List[float],  dist
 
     max_episode_steps = min(max_episode_steps_new, max_episode_steps_old)
     
-    end_buffer = max_episode_steps * end_buffer_multiplier
+    end_buffer = int(max_episode_steps * end_buffer_multiplier)
     latest_arrival_max = max_episode_steps-end_buffer
 
     # Useless unless needed by returning
     earliest_departures = []
     latest_arrivals = []
 
-    # #### DATA COLLECTION *************************
-    # # Create info.txt
-    # with open(os.path.join(root_path, dir_name, 'INFO.txt'), 'w') as f:
-    #     f.write('COPY FROM main.py')
-
-    # # Volume dist
-    # route_dist = np.zeros(shape=(max_episode_steps, distance_map.rail.width, distance_map.rail.height), dtype=np.int8)
-
-    # # City positions
-    # # Dummy distance map for shortest path pairs between cities
-    # city_positions = agents_hints['city_positions']
-    # d_rail = distance_map.rail
-    # d_dmap = DistanceMap([], d_rail.height, d_rail.width)
-    # d_city_permutations = list(itertools.permutations(city_positions, 2))
-
-    # d_positions = []
-    # d_targets = []
-    # for position, target in d_city_permutations:
-    #     d_positions.append(position)
-    #     d_targets.append(target)
-    
-    # d_schedule = Schedule(d_positions,
-    #                       [0] * len(d_positions),
-    #                       d_targets,
-    #                       [1.0] * len(d_positions),
-    #                       [None] * len(d_positions),
-    #                       1000)
-    
-    # d_agents = EnvAgent.from_schedule(d_schedule)
-    # d_dmap.reset(d_agents, d_rail)
-    # d_map = d_dmap.get()
-
-    # d_data = {
-    #     'city_positions': city_positions,
-    #     'start': d_positions,
-    #     'end': d_targets,
-    # }
-    # with open(os.path.join(root_path, dir_name, 'city_data.json'), 'w') as f:
-    #     json.dump(d_data, f)
-
-    # with open(os.path.join(root_path, dir_name, 'distance_map.npy'), 'wb') as f:
-    #     np.save(f, d_map)
-    # #### DATA COLLECTION *************************
-
     for agent in agents:
         agent_shortest_path_time = agent_shortest_path_times[agent.handle]
         agent_travel_time_max = int(np.ceil((agent_shortest_path_time * travel_buffer_multiplier) \
                                             + (mean_shortest_path_time * mean_shortest_path_multiplier)))
         
-        departure_window_max = latest_arrival_max - agent_travel_time_max
-
+        departure_window_max = max(latest_arrival_max - agent_travel_time_max, 1)
+        
         earliest_departure = np_random.randint(0, departure_window_max)
         latest_arrival = earliest_departure + agent_travel_time_max
         
@@ -138,45 +81,5 @@ def schedule_generator(agents: List[EnvAgent], config_speeds: List[float],  dist
         agent.earliest_departure = earliest_departure
         agent.latest_arrival = latest_arrival
 
-    # #### DATA COLLECTION *************************
-    #     # Histogram 1
-    #     dist_bounds = get_dist_window(earliest_departure, latest_arrival, latest_arrival_max)
-    #     schedule_dist[dist_bounds[0]: dist_bounds[1]] += 1
-
-    #     # Volume dist
-    #     for waypoint in agent_shortest_path:
-    #         pos = waypoint.position
-    #         route_dist[earliest_departure:latest_arrival, pos[0], pos[1]] += 1
-
-    #     # Dist - shortest path
-    #     shortest_paths_len_dist.append(agent_shortest_path_len)
-
-    # np.save(os.path.join(root_path, dir_name, 'volume.npy'), route_dist)
-    
-    # shortest_paths_len_dist.sort()
-    # save_sp_fig()
-    # #### DATA COLLECTION *************************
-
-    # returns schedule
     return Schedule(earliest_departures=earliest_departures, latest_arrivals=latest_arrivals,
                     max_episode_steps=max_episode_steps)
-
-
-# #### DATA COLLECTION *************************
-# # Histogram 1
-# def get_dist_window(departure_t, arrival_t, latest_arrival_max):
-#     return (int(np.round(np.interp(departure_t, [0, latest_arrival_max], [0, dist_resolution]))),
-#             int(np.round(np.interp(arrival_t, [0, latest_arrival_max], [0, dist_resolution]))))
-
-# def plot_dist():
-#     counts, bin_edges = schedule_dist, [i for i in range(0, dist_resolution+1)]
-#     fig = tpl.figure()
-#     fig.hist(counts, bin_edges, orientation="horizontal", force_ascii=False)
-#     fig.show()
-
-# # Shortest path dist
-# def save_sp_fig():
-#     fig = plt.figure(figsize=(15, 7))
-#     plt.bar(np.arange(len(shortest_paths_len_dist)), shortest_paths_len_dist)
-#     plt.savefig(os.path.join(root_path, dir_name, 'shortest_paths_sorted.png'))
-# #### DATA COLLECTION *************************