Commit 4dc054cf authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

changed behavior of fix transition. If only one inbound connection is found to...

changed behavior of fix transition. If only one inbound connection is found to empty cell, let cell be empty and instead fix the incoming line to a dead end in seperate step
parent 53f27b77
Pipeline #2213 failed with stages
in 60 minutes
......@@ -33,15 +33,15 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are)
min_node_dist=10, # Minimal distance of nodes
min_node_dist=12, # Minimal distance of nodes
node_radius=4, # Proximity of stations to city center
seed=0, # Random seed
grid_mode=False,
max_inter_city_rails=2,
tracks_in_city=4,
tracks_in_city=50,
),
schedule_generator=sparse_schedule_generator(),
number_of_agents=100,
number_of_agents=50,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=GlobalObsForRailEnv())
......
......@@ -516,7 +516,6 @@ class GridTransitionMap(TransitionMap):
double_slip = cells[5]
three_way_transitions = [simple_switch_east_south, simple_switch_west_south]
# loop over available outbound directions (indices) for rcPos
self.set_transitions(rcPos, 0)
incoming_connections = np.zeros(4)
for iDirOut in np.arange(4):
......@@ -539,19 +538,28 @@ class GridTransitionMap(TransitionMap):
incoming_connections[iDirOut] = 1
number_of_incoming = np.sum(incoming_connections)
# Only one incoming direction --> Straight line
# Only one incoming direction --> Straight line set deadend
if number_of_incoming == 1:
if self.get_full_transitions(*rcPos) == 0:
self.set_transitions(rcPos, 0)
else:
self.set_transitions(rcPos, 0)
for direction in range(4):
if incoming_connections[direction] > 0:
self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1)
# Connect all incoming connections
if number_of_incoming == 2:
self.set_transitions(rcPos, 0)
connect_directions = np.argwhere(incoming_connections > 0)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
# Find feasible connection for three entries
if number_of_incoming == 3:
self.set_transitions(rcPos, 0)
transition = np.random.choice(three_way_transitions, 1)
hole = np.argwhere(incoming_connections < 1)[0][0]
transition = transitions.rotate_transition(transition, int(hole * 90))
......
......@@ -659,8 +659,8 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
def _generate_node_connection_points(node_positions, node_size, tracks_in_city=2):
connection_points = []
connection_info = []
if tracks_in_city > 2 * node_size + 1:
tracks_in_city = 2 * node_size + 1
if tracks_in_city > 2 * node_size - 1:
tracks_in_city = 2 * node_size - 1
for node_position in node_positions:
......@@ -811,32 +811,6 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
train_stations[current_city].append(possible_location)
return train_stations, built_num_trainstations
def _fix_transitions(grid_map):
"""
Function to fix all transition elements in environment
"""
# Fix all nodes with illegal transition maps
empty_to_fix = []
rails_to_fix = []
height, width = np.shape(grid_map.grid)
for r in range(height):
for c in range(width):
rc_pos = (r, c)
check = grid_map.cell_neighbours_valid(rc_pos, True)
if not check:
if grid_map.grid[rc_pos] == 0:
empty_to_fix.append(rc_pos)
else:
rails_to_fix.append(rc_pos)
# Fix empty cells first to avoid cutting the network
for cell in empty_to_fix:
grid_map.fix_transitions(cell)
# Fix all other cells
for cell in rails_to_fix:
grid_map.fix_transitions(cell)
def _generate_start_target_pairs(num_agents, nb_nodes, train_stations):
# Generate start and target node directory for all agents.
......@@ -876,6 +850,32 @@ def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
num_agents -= 1
return agent_start_targets_nodes, num_agents
def _fix_transitions(grid_map):
"""
Function to fix all transition elements in environment
"""
# Fix all nodes with illegal transition maps
empty_to_fix = []
rails_to_fix = []
height, width = np.shape(grid_map.grid)
for r in range(height):
for c in range(width):
rc_pos = (r, c)
check = grid_map.cell_neighbours_valid(rc_pos, True)
if not check:
if grid_map.grid[rc_pos] == 0:
empty_to_fix.append(rc_pos)
else:
rails_to_fix.append(rc_pos)
# Fix empty cells first to avoid cutting the network
for cell in empty_to_fix:
grid_map.fix_transitions(cell)
# Fix all other cells
for cell in rails_to_fix:
grid_map.fix_transitions(cell)
def _closest_neigh_in_direction(current_node, direction, node_positions):
# Sort available neighbors according to their distance.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment