diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index ca9346fc71a8d5485d7b71098dc91558d09bd929..44a086b770fb3b276ab87e825f7d10802a86904b 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -40,7 +40,7 @@ env = RailEnv(width=50, seed=15, # Random seed grid_mode=True, nr_parallel_tracks=2, - connectin_points_per_side=3, + connectin_points_per_side=5, max_nr_connection_directions=2, ), schedule_generator=sparse_schedule_generator(), diff --git a/flatland/core/grid/grid_utils.py b/flatland/core/grid/grid_utils.py index d39fc8a771cf7f78d9203f2f92632694be26ed91..7c34796cb045165819cc442590565c69d17cad91 100644 --- a/flatland/core/grid/grid_utils.py +++ b/flatland/core/grid/grid_utils.py @@ -296,3 +296,25 @@ def coordinate_to_position(depth, coords): def distance_on_rail(pos1, pos2): return np.sqrt(np.power(pos1[0] - pos2[0], 2) + np.power(pos1[1] - pos2[1], 2)) + + +def closest_direction(pos1, pos2): + """ + Returns the closest direction orientation of position 2 relative to position 1 + :param pos1: position we are interested in + :param pos2: position we want to know it is facing + :return: direction NESW as int N:0 E:1 S:2 W:3 + """ + diff_vec = np.array((pos1[0] - pos2[0], pos1[1] - pos2[1])) + axis = np.argmax(np.power(diff_vec, 2)) + direction = np.sign(diff_vec[axis]) + if axis == 0: + if direction > 0: + return 2 + else: + return 0 + else: + if direction > 0: + return 3 + else: + return 1 diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 117862a71593f8483928323574d8977898cfdc9c..33fc408d67e52557aa89b5f8c8526557dd608b2a 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -6,7 +6,7 @@ import msgpack import numpy as np from flatland.core.grid.grid4_utils import get_direction, mirror -from flatland.core.grid.grid_utils import distance_on_rail +from flatland.core.grid.grid_utils import distance_on_rail, closest_direction from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes @@ -871,10 +871,20 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n max_nr_connection_directions=2): connection_points = [] for node_position in node_positions: + connection_sides_idx = np.sort( np.random.choice(np.arange(4), size=max_nr_connection_directions, replace=False)) - connections_per_direction = np.zeros(4, dtype=int) + # Chose the directions where close cities are situated + neighb_dist = [] + for neighb_node in node_positions: + neighb_dist.append(distance_on_rail(node_position, neighb_node)) + closest_neighb_idx = argsort(neighb_dist) + connection_sides_idx = [] + for idx in range(1, max_nr_connection_directions + 1): + connection_sides_idx.append(closest_direction(node_position, node_positions[closest_neighb_idx[idx]])) + + connections_per_direction = np.zeros(4, dtype=int) # set the number of connection points for each direction for idx in connection_sides_idx: connections_per_direction[idx] = max_nr_connection_points