Skip to content
Snippets Groups Projects
Commit 0b4c3f90 authored by Egli Adrian (IT-SCI-API-PFI)'s avatar Egli Adrian (IT-SCI-API-PFI)
Browse files

v0.2 realistic generator

parent 4f8fbdb8
No related branches found
No related tags found
No related merge requests found
......@@ -12,7 +12,7 @@ from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import RailGenerator, RailGeneratorProduct
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import AgentRenderVariant, RenderTool
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
FloatArrayType = []
......@@ -124,8 +124,9 @@ def realistic_rail_generator(num_cities=5,
end_nodes_added[city_loop].append(end_node)
# place in the center of path a station slot
#station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))])
station_slots[city_loop].extend(connection)
# station_slots[city_loop].append(connection[int(np.floor(len(connection) / 2))])
for c_loop in range(len(connection)):
station_slots[city_loop].append(connection[c_loop])
station_slots_cnt += len(connection)
station_tracks[city_loop][track_id] = connection
......@@ -139,7 +140,8 @@ def realistic_rail_generator(num_cities=5,
return nodes_added, station_slots, start_nodes_added, end_nodes_added, station_tracks
def create_switches_at_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
def create_switches_at_stations(rail_trans: RailEnvTransitions,
grid_map: GridTransitionMap,
station_tracks: IntVector2DArrayType,
nodes_added: IntVector2DArrayType,
intern_nbr_of_switches_per_station_track: int) -> IntVector2DArrayType:
......@@ -361,6 +363,20 @@ def realistic_rail_generator(num_cities=5,
if print_out_info:
print("connect_random_stations : connect_nodes -> no path found")
def remove_switch_stations(rail_trans: RailEnvTransitions, grid_map: GridTransitionMap,
train_stations: IntVector2DArrayType):
tmp_train_stations = copy.deepcopy(train_stations)
for city_loop in range(len(train_stations)):
for n in tmp_train_stations[city_loop]:
do_remove = True
trans = rail_trans.transition_list[1]
for _ in range(4):
trans = rail_trans.rotate_transition(trans, rotation=90)
if grid_map.grid[n] == trans:
do_remove = False
if do_remove:
train_stations[city_loop].remove(n)
def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
......@@ -435,6 +451,10 @@ def realistic_rail_generator(num_cities=5,
for i in range(len(nodes_added)):
grid_map.fix_transitions(nodes_added[i])
# ----------------------------------------------------------------------------------
# remove stations where rail is a switch
remove_switch_stations(rail_trans, grid_map, train_stations)
# ----------------------------------------------------------------------------------
# Slot availability in node
node_available_start = []
......@@ -481,7 +501,7 @@ def realistic_rail_generator(num_cities=5,
if os.path.exists("./../render_output/"):
for itrials in np.arange(1,1000,1):
for itrials in np.arange(1, 1000, 1):
print(itrials, "generate new city")
np.random.seed(itrials)
env = RailEnv(width=40 + np.random.choice(100),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment