Skip to content
Snippets Groups Projects
Commit b9092b25 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

fixed bug in generator related to random sampling

parent f3471970
No related branches found
No related tags found
No related merge requests found
......@@ -87,13 +87,18 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see
# Set target for agent
start_city = agent_start_targets_cities[agent_idx][0]
target_city = agent_start_targets_cities[agent_idx][1]
start = np.random.choice(train_stations[start_city])
target = np.random.choice(train_stations[target_city])
start_idx = np.random.choice(np.arange(len(train_stations[start_city])))
target_idx = np.random.choice(np.arange(len(train_stations[target_city])))
start = train_stations[start_city][start_idx]
target = train_stations[target_city][target_idx]
while start[1] % 2 != 0:
start = np.random.choice(train_stations[start_city])
start_idx = np.random.choice(np.arange(len(train_stations[start_city])))
start = train_stations[start_city][start_idx]
while target[1] % 2 != 1:
target = np.random.choice(train_stations[target_city])
target_idx = np.random.choice(np.arange(len(train_stations[target_city])))
target = train_stations[target_city][target_idx]
agent_orientation = (agent_start_targets_cities[agent_idx][2] + 2 * start[1]) % 4
if not rail.check_path_exists(start[0], agent_orientation, target[0]):
agent_orientation = (agent_orientation + 2) % 4
......
......@@ -20,7 +20,7 @@ def test_random_seeding():
# Test generation print
env.agents[0].target = (0, 0)
for step in range(100):
for step in range(10):
actions = {}
actions[0] = 2
env.step(actions)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment