From c2aef18b179f66700cb641ca49fb17c66726d2a8 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Thu, 12 Sep 2019 16:38:20 +0200 Subject: [PATCH] generator --- .../Simple_Realistic_Railway_Generator.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index 7e17b394..f23ca73e 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -316,10 +316,14 @@ def realistic_rail_generator(num_cities=5, return nodes_added def connect_stations(rail_trans, rail_array, org_s_nodes, org_e_nodes, nodes_added, - inter_connect_max_nbr_of_shortes_city): + inter_connect_max_nbr_of_shortes_city, start_to_end=True): - s_nodes = org_s_nodes.copy() - e_nodes = org_e_nodes.copy() + if start_to_end: + s_nodes = org_s_nodes.copy() + e_nodes = org_e_nodes.copy() + else: + e_nodes = org_s_nodes.copy() + s_nodes = org_e_nodes.copy() for city_loop in range(len(s_nodes)): old_cl = [] @@ -358,6 +362,9 @@ def realistic_rail_generator(num_cities=5, rail_array[start_node] = tmp_trans_sn rail_array[end_node] = tmp_trans_en + if start_to_end: + connect_stations(rail_trans, rail_array, org_s_nodes, org_e_nodes, nodes_added, + inter_connect_max_nbr_of_shortes_city, start_to_end=False) def connect_random_stations(rail_trans, rail_array, start_nodes_added, end_nodes_added, nodes_added, inter_connect_max_nbr_of_shortes_city): @@ -525,14 +532,14 @@ def realistic_rail_generator(num_cities=5, for itrials in range(100): print(itrials, "generate new city") np.random.seed(int(time.time())) - env = RailEnv(width=100, # 20+np.random.choice(100), - height=100, # 20+np.random.choice(100), + env = RailEnv(width=30+np.random.choice(100), + height=30+np.random.choice(100), rail_generator=realistic_rail_generator(num_cities=np.random.choice(10) + 2, city_size=np.random.choice(10) + 10, allowed_rotation_angles=np.arange(0, 360, 90), max_number_of_station_tracks=4, nbr_of_switches_per_station_track=2, - connect_max_nbr_of_shortes_city=4, + connect_max_nbr_of_shortes_city=2, do_random_connect_stations=False, # Number of cities in map seed=int(time.time()) # Random seed -- GitLab