diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index ab2d02912bf61b796b5c7ced312918b7cf56615c..6d1205b3b505f77dd57cd670e669c93401200b5d 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1,3 +1,4 @@ +import warnings from enum import IntEnum import msgpack @@ -8,7 +9,7 @@ from flatland.core.grid.grid_utils import distance_on_rail from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic -from flatland.envs.grid4_generators_utils import connect_rail +from flatland.envs.grid4_generators_utils import connect_rail, connect_from_nodes, connect_nodes, connect_to_nodes from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail @@ -692,7 +693,7 @@ def realistic_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dis return generator -def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, mean_node_neighbours=2, min_node_dist=20, node_radius=2, +def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, max_neigbours=2, min_node_dist=20, node_radius=2, seed=0): ''' @@ -708,7 +709,7 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, mean_node_neighbour if num_agents > nr_train_stations: num_agents = nr_train_stations - print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") + warnings.warn("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") rail_trans = RailEnvTransitions() grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans) rail_array = grid_map.grid @@ -719,21 +720,52 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, mean_node_neighbour node_positions = [] for node_idx in range(nr_nodes): to_close = True + tries = 0 while to_close: - x_tmp = np.random.randint(width) - y_tmp = np.random.randint(height) + x_tmp = np.random.randint(height) + y_tmp = np.random.randint(width) to_close = False for node_pos in node_positions: if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: to_close = True if not to_close: node_positions.append((x_tmp, y_tmp)) + tries += 1 + if tries > 100: + warnings.warn("Could not set nodes, please change initial parameters!!!!") + break + + # Chose node connection + available_nodes = np.arange(nr_nodes) + current_node = 0 + node_stack = [current_node] + + while len(node_stack) > 0: + current_node = node_stack[0] + delete_idx = np.where(available_nodes == current_node) + available_nodes = np.delete(available_nodes, delete_idx, 0) + + # Get random number of neighbors + num_neighb = 2 # np.random.randint(1, max_neigbours) + if len(available_nodes) >= num_neighb: + connected_neighb_idx = np.random.choice(available_nodes, num_neighb, replace=False) + else: + connected_neighb_idx = available_nodes + + for neighb in connected_neighb_idx: + if neighb not in node_stack: + node_stack.append(neighb) + new_path = connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb]) + node_stack.pop(0) + + # Generate start and target node directory for all agents agent_start_targets_nodes = [] for agent_idx in range(num_agents): start_target_tuple = np.random.choice(nr_nodes, 2, replace=False) agent_start_targets_nodes.append(start_target_tuple) + # Generate actual start and target locations from around nodes agents_position = [] agents_target = [] @@ -741,42 +773,39 @@ def sparse_rail_generator(nr_train_stations=1, nr_nodes=100, mean_node_neighbour agent_idx = 0 for start_target in agent_start_targets_nodes: start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius), 0, - width - 1) - start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), 0, height - 1) + start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), 0, + width - 1) target_x = np.clip(node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0, - width - 1) - target_y = np.clip(node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0, height - 1) + target_y = np.clip(node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0, + width - 1) if agent_idx == 0: - agents_position.append((start_y, start_x)) - agents_target.append((target_y, target_x)) + agents_position.append((start_x, start_y)) + agents_target.append((target_x, target_y)) else: - while ((start_x, start_y) in agents_position or (target_x, target_y) in agents_target): + # Make sure we don't put to starts or targets on same cell + while (start_x, start_y) in agents_position or (target_x, target_y) in agents_target: start_x = np.clip(node_positions[start_target[0]][0] + np.random.randint(-node_radius, node_radius), 0, - width - 1) + height - 1) start_y = np.clip(node_positions[start_target[0]][1] + np.random.randint(-node_radius, node_radius), 0, - height - 1) + width - 1) target_x = np.clip( node_positions[start_target[1]][0] + np.random.randint(-node_radius, node_radius), 0, - width - 1) + height - 1) target_y = np.clip( node_positions[start_target[1]][1] + np.random.randint(-node_radius, node_radius), 0, - height - 1) - agents_position.append((start_y, start_x)) - agents_target.append((target_y, target_x)) - + width - 1) + agents_position.append((start_x, start_y)) + agents_target.append((target_x, target_y)) + new_path = connect_to_nodes(rail_trans, rail_array, agents_position[agent_idx], + node_positions[start_target[0]]) + new_path = connect_from_nodes(rail_trans, rail_array, node_positions[start_target[1]], + agents_target[agent_idx]) agents_direction.append(0) agent_idx += 1 - - print(agents_position) - print(agents_target) - print(node_positions) - for n in node_positions: - for m in node_positions: - print(distance_on_rail(n, m)) return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator diff --git a/flatland/envs/grid4_generators_utils.py b/flatland/envs/grid4_generators_utils.py index dedd76b6bfd04c13ad59092adfefbde6ae98fc18..9116adb6639fc699d00b45b58241a4bdcdfe9c74 100644 --- a/flatland/envs/grid4_generators_utils.py +++ b/flatland/envs/grid4_generators_utils.py @@ -57,6 +57,143 @@ def connect_rail(rail_trans, rail_array, start, end): return path +def connect_nodes(rail_trans, rail_array, start, end): + """ + Creates a new path [start,end] in rail_array, based on rail_trans. + """ + # in the worst case we will need to do a A* search, so we might as well set that up + path = a_star(rail_trans, rail_array, start, end) + if len(path) < 2: + return [] + current_dir = get_direction(path[0], path[1]) + end_pos = path[-1] + for index in range(len(path) - 1): + current_pos = path[index] + new_pos = path[index + 1] + new_dir = get_direction(current_pos, new_pos) + + new_trans = rail_array[current_pos] + if index == 0: + if new_trans == 0: + # end-point + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # into existing rail + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + rail_array[current_pos] = new_trans + + if new_pos == end_pos: + # setup end pos setup + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # end-point + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + else: + # into existing rail + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + rail_array[end_pos] = new_trans_e + + current_dir = new_dir + return path + + +def connect_from_nodes(rail_trans, rail_array, start, end): + """ + Creates a new path [start,end] in rail_array, based on rail_trans. + """ + # in the worst case we will need to do a A* search, so we might as well set that up + path = a_star(rail_trans, rail_array, start, end) + if len(path) < 2: + return [] + current_dir = get_direction(path[0], path[1]) + end_pos = path[-1] + for index in range(len(path) - 1): + current_pos = path[index] + new_pos = path[index + 1] + new_dir = get_direction(current_pos, new_pos) + + new_trans = rail_array[current_pos] + if index == 0: + if new_trans == 0: + # end-point + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # into existing rail + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + rail_array[current_pos] = new_trans + + if new_pos == end_pos: + # setup end pos setup + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # end-point + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, mirror(new_dir), 1) + else: + # into existing rail + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + rail_array[end_pos] = new_trans_e + + current_dir = new_dir + return path + + +def connect_to_nodes(rail_trans, rail_array, start, end): + """ + Creates a new path [start,end] in rail_array, based on rail_trans. + """ + # in the worst case we will need to do a A* search, so we might as well set that up + path = a_star(rail_trans, rail_array, start, end) + if len(path) < 2: + return [] + current_dir = get_direction(path[0], path[1]) + end_pos = path[-1] + for index in range(len(path) - 1): + current_pos = path[index] + new_pos = path[index + 1] + new_dir = get_direction(current_pos, new_pos) + + new_trans = rail_array[current_pos] + if index == 0: + if new_trans == 0: + # end-point + # need to flip direction because of how end points are defined + new_trans = rail_trans.set_transition(new_trans, mirror(current_dir), new_dir, 1) + else: + # into existing rail + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + else: + # set the forward path + new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1) + # set the backwards path + new_trans = rail_trans.set_transition(new_trans, mirror(new_dir), mirror(current_dir), 1) + rail_array[current_pos] = new_trans + + if new_pos == end_pos: + # setup end pos setup + new_trans_e = rail_array[end_pos] + if new_trans_e == 0: + # end-point + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + else: + # into existing rail + new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1) + rail_array[end_pos] = new_trans_e + + current_dir = new_dir + return path + def get_rnd_agents_pos_tgt_dir_on_rail(rail, num_agents): """ Given a `rail' GridTransitionMap, return a random placement of agents (initial position, direction and target). diff --git a/tests/test_flatland_env_sparse_rail_generator.py b/tests/test_flatland_env_sparse_rail_generator.py index 8934a4022611364f0adc4df806376fd43e5bc8f4..b64aaa640d27f24b1ab2bd87e30a29784f7787e7 100644 --- a/tests/test_flatland_env_sparse_rail_generator.py +++ b/tests/test_flatland_env_sparse_rail_generator.py @@ -1,14 +1,20 @@ +import time + from flatland.envs.generators import sparse_rail_generator from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv +from flatland.utils.rendertools import RenderTool def test_sparse_rail_generator(): + env = RailEnv(width=20, height=20, - rail_generator=sparse_rail_generator(nr_train_stations=10, nr_nodes=5, min_node_dist=10, + rail_generator=sparse_rail_generator(nr_train_stations=3, nr_nodes=2, min_node_dist=5, node_radius=4), - number_of_agents=10, + number_of_agents=3, obs_builder_object=GlobalObsForRailEnv()) # reset to initialize agents_static - env.reset() + env_renderer = RenderTool(env, gl="PILSVG", ) + env_renderer.render_env(show=True, show_observations=True, show_predictions=False) + time.sleep(10)