diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 3a3783878edb7344928121acb21962f864f4fc39..85fb380a3f244e870c7d233c585e8b04fe339e01 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -70,13 +70,14 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> agents_position = [] agents_target = [] agents_direction = [] - + start_slots = train_stations + target_slots = train_stations for agent_idx in range(num_agents): # Set target for agent current_target_node = agent_start_targets_nodes[agent_idx][1] - target_station_idx = np.random.randint(len(train_stations[current_target_node])) - target = train_stations[current_target_node][target_station_idx] - train_stations[current_target_node].pop(target_station_idx) + target_station_idx = np.random.randint(len(target_slots[current_target_node])) + target = target_slots[current_target_node][target_station_idx] + target_slots[current_target_node].pop(target_station_idx) agents_target.append((target[0][0], target[0][1])) # Set start for agent and corresponding orientation @@ -88,13 +89,12 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> track_to_use = 0 else: track_to_use = 1 - for i in range(len(train_stations[current_start_node])): - if train_stations[current_start_node][i][1] == track_to_use: + for i in range(len(start_slots[current_start_node])): + if start_slots[current_start_node][i][1] == track_to_use: start_station_idx = i break - - start = train_stations[current_start_node][start_station_idx] - train_stations[current_start_node].pop(start_station_idx) + start = start_slots[current_start_node][start_station_idx] + start_slots[current_start_node].pop(start_station_idx) agents_position.append((start[0][0], start[0][1])) agents_direction.append(agent_start_orientation)