From 0b4c3f901d233abb5bc2d7d83c221cd2a4736d43 Mon Sep 17 00:00:00 2001 From: "Egli Adrian (IT-SCI-API-PFI)" <adrian.egli@sbb.ch> Date: Tue, 17 Sep 2019 22:10:33 +0200 Subject: [PATCH] v0.2 realistic generator --- .../Simple_Realistic_Railway_Generator.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/examples/Simple_Realistic_Railway_Generator.py b/examples/Simple_Realistic_Railway_Generator.py index ec47bdc5..082e9b67 100644 --- a/examples/Simple_Realistic_Railway_Generator.py +++ b/examples/Simple_Realistic_Railway_Generator.py @@ -12,7 +12,7 @@ from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct from flatland.envs.schedule_generators import sparse_schedule_generator -from flatland.utils.rendertools import AgentRenderVariant, RenderTool +from flatland.utils.rendertools import RenderTool, AgentRenderVariant FloatArrayType = [] @@ -124,8 +124,9 @@ def realistic_rail_generator(num_cities=5, end_nodes_added[city_loop].append(end_node) # place in the center of path a station slot - #station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))]) - station_slots[city_loop].extend(connection) + # station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))]) + for c_loop in range(len(connection)): + station_slots[city_loop].append(connection[c_loop]) station_slots_cnt += len(connection) station_tracks[city_loop][track_id] = connection @@ -139,7 +140,8 @@ def realistic_rail_generator(num_cities=5, return nodes_added, station_slots, start_nodes_added, end_nodes_added, station_tracks - def create_switches_at_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + def create_switches_at_stations(rail_trans: RailEnvTransitions, + grid_map: GridTransitionMap, station_tracks: IntVector2DArrayType, nodes_added: IntVector2DArrayType, intern_nbr_of_switches_per_station_track: int) -> IntVector2DArrayType: @@ -361,6 +363,20 @@ def realistic_rail_generator(num_cities=5, if print_out_info: print("connect_random_stations : connect_nodes -> no path found") + def remove_switch_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap, + train_stations: IntVector2DArrayType): + tmp_train_stations = copy.deepcopy(train_stations) + for city_loop in range(len(train_stations)): + for n in tmp_train_stations[city_loop]: + do_remove = True + trans = rail_trans.transition_list[1] + for _ in range(4): + trans = rail_trans.rotate_transition(trans, rotation=90) + if grid_map.grid[n] == trans: + do_remove = False + if do_remove: + train_stations[city_loop].remove(n) + def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct: rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) @@ -435,6 +451,10 @@ def realistic_rail_generator(num_cities=5, for i in range(len(nodes_added)): grid_map.fix_transitions(nodes_added[i]) + # ---------------------------------------------------------------------------------- + # remove stations where rail is a switch + remove_switch_stations(rail_trans, grid_map, train_stations) + # ---------------------------------------------------------------------------------- # Slot availability in node node_available_start = [] @@ -481,7 +501,7 @@ def realistic_rail_generator(num_cities=5, if os.path.exists("./../render_output/"): - for itrials in np.arange(1,1000,1): + for itrials in np.arange(1, 1000, 1): print(itrials, "generate new city") np.random.seed(itrials) env = RailEnv(width=40 + np.random.choice(100), -- GitLab