diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index c6166d5891dcf765fe29af2be76db5c78c37dfe8..129ab75a98c51c5561934f3e7c3747bd8b2bc954 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -122,10 +122,10 @@ def realistic_rail_generator(num_cities=5, # station main orientation (horizontal or vertical rot_angle = np.random.choice(allowed_rotation_angles) add_pos_val = vec2d.scale_pos(vec2d.rotate_pos((1, 0), rot_angle), - (max(1, (intern_city_size - 3) / 2))) + (max(1, (intern_city_size - 3) / 2))) generate_city_locations[i][0] = vec2d.add_pos(generate_city_locations[i][1], add_pos_val) add_pos_val = vec2d.scale_pos(vec2d.rotate_pos((1, 0), 180 + rot_angle), - (max(1, (intern_city_size - 3) / 2))) + (max(1, (intern_city_size - 3) / 2))) generate_city_locations[i][1] = vec2d.add_pos(generate_city_locations[i][1], add_pos_val) return generate_city_locations @@ -179,21 +179,26 @@ def realistic_rail_generator(num_cities=5, for city_loop in range(len(station_tracks)): datas = station_tracks[city_loop] - if len(datas)>1: + if len(datas) > 1: a = datas[0] - b = [] - for i in range(len(datas)): - tmp = datas[i] - if len(tmp)>0: - b = tmp - start_node = a[min(2,len(a))] - end_node = b[len(b)-1] - rail_array[start_node] = 0 - rail_array[end_node] = 0 - connection = connect_from_nodes(rail_trans, rail_array, start_node, end_node) - if len(connection) > 0: - nodes_added.append(start_node) - nodes_added.append(end_node) + if len(a) > 2: + j = 2 + start_node = a[j] + b = [] + for i in np.arange(1, len(datas)): + b = datas[i] + if len(b) > 2: + end_node = b[j + 2] + connection = connect_from_nodes(rail_trans, rail_array, start_node, end_node) + if len(connection) > 0: + nodes_added.append(start_node) + nodes_added.append(end_node) + if i % 2 == 0: + j = j - 2 + else: + j = j + 2 + start_node = end_node + return nodes_added def calc_nbr_of_graphs(graph): @@ -465,7 +470,7 @@ def realistic_rail_generator(num_cities=5, for itrials in range(1000): print(itrials, "generate new city") - np.random.seed(0*int(time.time())) + np.random.seed(0 * int(time.time())) env = RailEnv(width=40 + np.random.choice(100), height=40 + np.random.choice(100), rail_generator=realistic_rail_generator(num_cities=2 + np.random.choice(10), @@ -480,7 +485,7 @@ for itrials in range(1000): print_out_info=False ), schedule_generator=sparse_schedule_generator(), - number_of_agents=1 + np.random.choice(10), + number_of_agents=1, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static