diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 53105ab15dbbbf911be688d05b0cad4fcc4a82ef..fc358a2a7da1367c05fe2a52f4ce35a6c551c588 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -32,15 +32,15 @@ speed_ration_map = {1.: 0.25, # Fast passenger train env = RailEnv(width=50, height=50, - rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are) + rail_generator=sparse_rail_generator(num_cities=3, # Number of cities in map (where train stations are) num_trainstations=100, # Number of possible start/targets on map min_node_dist=10, # Minimal distance of nodes node_radius=4, # Proximity of stations to city center num_neighb=3, # Number of connections to other cities/intersections seed=15, # Random seed - grid_mode=True, + grid_mode=False, nr_parallel_tracks=2, - connectin_points_per_side=2, + connectin_points_per_side=100, max_nr_connection_directions=3, ), schedule_generator=sparse_schedule_generator(), diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 853a6376126f8becde8a1b0821bc3ba778910edf..b7ca0f4d94d2bb1ce7e9bbfd8930662ec0893aa1 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -683,11 +683,11 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n max_nr_connection_directions=2): connection_points = [] connection_info = [] - for node_position in node_positions: - - connection_sides_idx = np.sort( - np.random.choice(np.arange(4), size=max_nr_connection_directions, replace=False)) + max_nr_connection_directions = np.clip(max_nr_connection_directions, 0, 4) + if max_nr_connection_points > 2 * node_size + 1: + max_nr_connection_points = 2 * node_size + 1 + for node_position in node_positions: # Chose the directions where close cities are situated neighb_dist = [] for neighb_node in node_positions: @@ -696,7 +696,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n # Store the directions to these neighbours connection_sides_idx = [] - for idx in range(1, max_nr_connection_directions + 1): + for idx in range(1, min(len(neighb_dist) - 1, max_nr_connection_directions) + 1): connection_sides_idx.append(closest_direction(node_position, node_positions[closest_neighb_idx[idx]])) # set the number of connection points for each direction @@ -918,6 +918,35 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n else: num_agents -= 1 return agent_start_targets_nodes + + def _closest_neigh_in_direction(current_node, direction, node_positions): + # Sort available neighbors according to their distance. + available_nodes = np.arange(node_positions) + node_dist = [] + for av_node in available_nodes: + node_dist.append(distance_on_rail(node_positions[current_node], node_positions[av_node])) + sorted_neighbours = available_nodes[np.argsort(node_dist)] + + for neighb in sorted_neighbours[1:]: + distance_0 = np.abs(node_positions[current_node][0] - node_positions[neighb][0]) + distance_1 = np.abs(node_positions[current_node][1] - node_positions[neighb][1]) + if direction == 0: + if node_positions[neighb][0] < node_positions[current_node][0] and distance_1 <= distance_0: + return neighb + + if direction == 1: + if node_positions[neighb][1] > node_positions[current_node][1] and distance_0 <= distance_1: + return neighb + + if direction == 2: + if node_positions[neighb][0] > node_positions[current_node][0] and distance_1 <= distance_0: + return neighb + + if direction == 3: + if node_positions[neighb][0] < node_positions[current_node][0] and distance_0 <= distance_1: + return neighb + return None + def argsort(seq): # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python return sorted(range(len(seq)), key=seq.__getitem__)