diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 2ac0b6219b10c183d57dd8604da752799c79a21d..8b83131bd1642fe484f4c4776be893b30ad9b99a 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -28,9 +28,9 @@ speed_ration_map = {1.: 1., # Fast passenger train 1. / 3.: 0., # Slow commuter train 1. / 4.: 0.} # Slow freight train -env = RailEnv(width=100, - height=100, - rail_generator=sparse_rail_generator(num_cities=20, # Number of cities in map (where train stations are) +env = RailEnv(width=50, + height=50, + rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map (where train stations are) seed=10, # Random seed grid_mode=False, max_inter_city_rails=2, diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 610276a66d10d8b2eaba4e83286cd227ac593ac8..b4d60713d0d4481f0e625e97bc9869b2d1927679 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -1,4 +1,5 @@ """Schedule generators (railway undertaking, "EVU").""" +import random import warnings from typing import Tuple, List, Callable, Mapping, Optional, Any @@ -74,11 +75,13 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> # Set target for agent start_city = agent_start_targets_nodes[agent_idx][0] target_city = agent_start_targets_nodes[agent_idx][1] + start = random.choice(train_stations[start_city]) + target = random.choice(train_stations[target_city]) + while start[1] % 2 != 0: + start = random.choice(train_stations[start_city]) + while target[1] % 2 != 1: + target = random.choice(train_stations[start_city]) - start_city_idx = np.random.randint(len(train_stations[start_city])) - start = train_stations[start_city][start_city_idx] - target_station_idx = np.random.randint(len(train_stations[target_city])) - target = train_stations[target_city][target_station_idx] agent_orientation = (agent_start_targets_nodes[agent_idx][2] + 2 * start[1]) % 4 if not rail.check_path_exists(start[0], agent_orientation, target[0]): agent_orientation = (agent_orientation + 2) % 4