Commit e1c4df4d authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

fixed connection bug when not enough neighbors are present

parent bd052de6
Pipeline #2192 failed with stages
in 17 minutes and 39 seconds
......@@ -32,15 +32,15 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are)
rail_generator=sparse_rail_generator(num_cities=3, # Number of cities in map (where train stations are)
num_trainstations=100, # Number of possible start/targets on map
min_node_dist=10, # Minimal distance of nodes
node_radius=4, # Proximity of stations to city center
num_neighb=3, # Number of connections to other cities/intersections
seed=15, # Random seed
grid_mode=True,
grid_mode=False,
nr_parallel_tracks=2,
connectin_points_per_side=2,
connectin_points_per_side=100,
max_nr_connection_directions=3,
),
schedule_generator=sparse_schedule_generator(),
......
......@@ -683,11 +683,11 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
max_nr_connection_directions=2):
connection_points = []
connection_info = []
for node_position in node_positions:
connection_sides_idx = np.sort(
np.random.choice(np.arange(4), size=max_nr_connection_directions, replace=False))
max_nr_connection_directions = np.clip(max_nr_connection_directions, 0, 4)
if max_nr_connection_points > 2 * node_size + 1:
max_nr_connection_points = 2 * node_size + 1
for node_position in node_positions:
# Chose the directions where close cities are situated
neighb_dist = []
for neighb_node in node_positions:
......@@ -696,7 +696,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
# Store the directions to these neighbours
connection_sides_idx = []
for idx in range(1, max_nr_connection_directions + 1):
for idx in range(1, min(len(neighb_dist) - 1, max_nr_connection_directions) + 1):
connection_sides_idx.append(closest_direction(node_position, node_positions[closest_neighb_idx[idx]]))
# set the number of connection points for each direction
......@@ -918,6 +918,35 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
else:
num_agents -= 1
return agent_start_targets_nodes
def _closest_neigh_in_direction(current_node, direction, node_positions):
# Sort available neighbors according to their distance.
available_nodes = np.arange(node_positions)
node_dist = []
for av_node in available_nodes:
node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node]))
sorted_neighbours = available_nodes[np.argsort(node_dist)]
for neighb in sorted_neighbours[1:]:
distance_0 = np.abs(node_positions[current_node][0] - node_positions[neighb][0])
distance_1 = np.abs(node_positions[current_node][1] - node_positions[neighb][1])
if direction == 0:
if node_positions[neighb][0] < node_positions[current_node][0] and distance_1 <= distance_0:
return neighb
if direction == 1:
if node_positions[neighb][1] > node_positions[current_node][1] and distance_0 <= distance_1:
return neighb
if direction == 2:
if node_positions[neighb][0] > node_positions[current_node][0] and distance_1 <= distance_0:
return neighb
if direction == 3:
if node_positions[neighb][0] < node_positions[current_node][0] and distance_0 <= distance_1:
return neighb
return None
def argsort(seq):
# http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
return sorted(range(len(seq)), key=seq.__getitem__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment