diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 085d6fd33e02342a89e7bb1d01eb873bf1763f5a..ca14667424d2c93d1466e3b7e96c2e5c1fbd41e5 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -75,8 +75,9 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= while nr_created < nr_start_goal and created_sanity < sanity_max: all_ok = False for _ in range(sanity_max): - start = (np.random.randint(0, width), np.random.randint(0, height)) - goal = (np.random.randint(0, height), np.random.randint(0, height)) + start = (np.random.randint(0, height), np.random.randint(0, width)) + goal = (np.random.randint(0, height), np.random.randint(0, width)) + # check to make sure start,goal pos is empty? if rail_array[goal] != 0 or rail_array[start] != 0: continue @@ -121,8 +122,8 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= while nr_created < nr_extra and created_sanity < sanity_max: all_ok = False for _ in range(sanity_max): - start = (np.random.randint(0, width), np.random.randint(0, height)) - goal = (np.random.randint(0, height), np.random.randint(0, height)) + start = (np.random.randint(0, height), np.random.randint(0, width)) + goal = (np.random.randint(0, height), np.random.randint(0, width)) # check to make sure start,goal pos are not empty if rail_array[goal] == 0 or rail_array[start] == 0: continue