diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index f7920c29729b30153d153ce55d4ce45ea0894082..3bfbf1eead06728799b19bfaaf04edf1acf47632 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -688,6 +688,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n max_nr_connection_points = 2 * node_size + 1 for node_position in node_positions: + # Chose the directions where close cities are situated neighb_dist = [] for neighb_node in node_positions: @@ -696,9 +697,17 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n # Store the directions to these neighbours connection_sides_idx = [] - - for idx in range(1, min(len(neighb_dist) - 1, max_nr_connection_directions) + 1): - connection_sides_idx.append(closest_direction(node_position, node_positions[closest_neighb_idx[idx]])) + idx = 1 + # TODO: Change the way this code works! Check that we get sufficient direction. + # TODO: Check if this works as expected + while len(connection_sides_idx) < max_nr_connection_directions and idx < len(neighb_dist): + if closest_direction(node_position, + node_positions[closest_neighb_idx[idx]]) not in connection_sides_idx: + connection_sides_idx.append( + closest_direction(node_position, node_positions[closest_neighb_idx[idx]])) + idx += 1 + else: + idx += 1 # set the number of connection points for each direction connections_per_direction = np.zeros(4, dtype=int)