diff --git a/examples/play_model.py b/examples/play_model.py index e69b312b1ceb2f450256d247f4b63c14a728acb5..7f92cb3c90fd95aeca097510ea593363ffa9568f 100644 --- a/examples/play_model.py +++ b/examples/play_model.py @@ -97,7 +97,7 @@ def main(render=True, delay=0.0): # Example generate a random rail env = RailEnv(width=15, height=15, - rail_generator=complex_rail_generator(nr_start_goal=20, min_dist=12), + rail_generator=complex_rail_generator(nr_start_goal=5, nr_extra=20, min_dist=12), number_of_agents=5) if render: diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 4f356e1c673d37af40d59a1b4c297bec59f5ce6c..7452d325530bb189084182f0bbc4bf26369e5881 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -9,7 +9,7 @@ from flatland.envs.env_utils import distance_on_rail, connect_rail, get_directio from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail -def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): +def complex_rail_generator(nr_start_goal=1, nr_extra=10, min_dist=2, max_dist=99999, seed=0): """ Parameters ------- @@ -123,7 +123,27 @@ def complex_rail_generator(nr_start_goal=1, min_dist=2, max_dist=99999, seed=0): # print("failed...") created_sanity += 1 - print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs") + # add extra connections between existing rail + created_sanity = 0 + nr_created = 0 + while nr_created < nr_extra and created_sanity < sanity_max: + all_ok = False + for _ in range(sanity_max): + start = (np.random.randint(0, width), np.random.randint(0, height)) + goal = (np.random.randint(0, height), np.random.randint(0, height)) + # check to make sure start,goal pos are not empty + if rail_array[goal] == 0 or rail_array[start] == 0: + continue + else: + all_ok = True + break + if not all_ok: + break + new_path = connect_rail(rail_trans, rail_array, start, goal) + if len(new_path) >= 2: + nr_created += 1 + + print("\n> Complex Rail Gen: Created #", len(start_goal), "pairs and #", nr_created, "extra connections") # print(start_goal) agents_position = [sg[0] for sg in start_goal]