From 59404cf63a7f09e935226cf9a9a531070e1e19d2 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Mon, 30 Sep 2019 19:52:27 -0400 Subject: [PATCH] refactoring of how we chose start&target for agents. and how we place them --- examples/flatland_2_0_example.py | 2 +- flatland/envs/rail_generators.py | 31 +++++++++--------- flatland/envs/schedule_generators.py | 49 ++++++++++------------------ 3 files changed, 35 insertions(+), 47 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 3bc31626..2a6d90b0 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -37,7 +37,7 @@ env = RailEnv(width=50, max_tracks_in_city=4, ), schedule_generator=sparse_schedule_generator(), - number_of_agents=15, + number_of_agents=50, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=GlobalObsForRailEnv()) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 0bdf5165..ac898d43 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -545,7 +545,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, :param seed: Random seed to initiate rail :return: generator """ - G = nx.Graph() + G = nx.DiGraph() def generator(width, height, num_agents, num_resets=0) -> RailGeneratorProduct: @@ -605,10 +605,10 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, print("City build time", time.time() - city_build_time) # Populate cities train_station_time = time.time() - train_stations, track_numbers, built_num_trainstation = _set_trainstation_positions(node_positions, - city_orientations, - through_tracks, - node_radius, grid_map) + train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, + city_orientations, + through_tracks, + node_radius, grid_map) print("Trainstation placing time", time.time() - train_station_time) # Adjust the number of agents if you could not build enough trainstations @@ -629,8 +629,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, 'num_agents': num_agents, 'agent_start_targets_nodes': agent_start_targets_nodes, 'train_stations': train_stations, - 'city_orientations': city_orientations, - 'track_numbers': track_numbers + 'city_orientations': city_orientations }} def _generate_random_node_positions(nb_nodes, node_radius, height, width): @@ -763,7 +762,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, tmp_direction = (direction - 1) % 4 while neighb_idx is None: neighb_idx = neighbours[tmp_direction] - tmp_direction = (tmp_direction + 1) % 4 + tmp_direction = (direction + 1) % 4 connected_to_city.append(neighb_idx) for tmp_out_connection_point in connection_points[current_node][direction]: @@ -780,7 +779,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, neighb_connection_point = tmp_in_connection_point new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point, city_cells) - G.add_edge(current_node, neighb_idx) + G.add_edge(current_node, neighb_idx, direction=direction, length=len(new_line)) all_paths.extend(new_line) direction += 1 @@ -847,7 +846,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, """ nb_nodes = len(node_positions) train_stations = [[] for i in range(nb_nodes)] - train_station_orientations = [[] for i in range(nb_nodes)] + built_num_trainstations = 0 for current_city in range(len(node_positions)): for possible_location in _city_cells(node_positions[current_city], node_radius - 1): @@ -862,9 +861,9 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, built_num_trainstations += 1 track_nbr = _track_number(node_positions[current_city], city_orientations[current_city], possible_location) - train_stations[current_city].append(possible_location) - train_station_orientations[current_city].append(track_nbr) - return train_stations, train_station_orientations, built_num_trainstations + train_stations[current_city].append((possible_location, track_nbr)) + + return train_stations, built_num_trainstations def _generate_start_target_pairs(num_agents, nb_nodes, train_stations): """ @@ -912,8 +911,10 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, if found_agent_pair: node_available_start[start_node] -= 1 node_available_target[target_node] -= 1 - agent_start_targets_nodes.append((start_node, target_node)) - print(agent_idx, "has connection", nx.astar_path(G, start_node, target_node)) + shortest_path = nx.astar_path(G, start_node, target_node, weight='length') + start_orientation = nx.get_edge_attributes(G, "direction")[(shortest_path[0], shortest_path[1])] + agent_start_targets_nodes.append((start_node, target_node, start_orientation)) + else: num_agents -= 1 return agent_start_targets_nodes, num_agents diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index 86a41af0..cd2f3996 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -63,7 +63,6 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> agent_start_targets_nodes = hints['agent_start_targets_nodes'] max_num_agents = hints['num_agents'] city_orientations = hints['city_orientations'] - track_numbers = hints['track_numbers'] if num_agents > max_num_agents: num_agents = max_num_agents warnings.warn("Too many agents! Changes number of agents.") @@ -77,43 +76,31 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> current_target_node = agent_start_targets_nodes[agent_idx][1] target_station_idx = np.random.randint(len(train_stations[current_target_node])) target = train_stations[current_target_node][target_station_idx] - tries = 0 - while (target[0], target[1]) in agents_target: - target_station_idx = np.random.randint(len(train_stations[current_target_node])) - target = train_stations[current_target_node][target_station_idx] - tries += 1 - if tries > 100: - warnings.warn("Could not set target position, removing an agent") - break - agents_target.append((target[0], target[1])) + train_stations[current_target_node].pop(target_station_idx) + agents_target.append((target[0][0], target[0][1])) - # Set start for agent + # Set start for agent and corresponding orientation current_start_node = agent_start_targets_nodes[agent_idx][0] - start_station_idx = np.random.randint(len(train_stations[current_start_node])) - start = train_stations[current_start_node][start_station_idx] - current_track_nbr = track_numbers[current_start_node][start_station_idx] - tries = 0 - while (start[0], start[1]) in agents_position: - tries += 1 - if tries > 100: - warnings.warn("Could not set start position, please change initial parameters!!!!") + agent_start_orientation = agent_start_targets_nodes[agent_idx][2] + + # Place the agent on the corresponding track + if city_orientations[current_start_node] == agent_start_orientation: + track_to_use = 0 + else: + track_to_use = 1 + + for i in range(len(train_stations[current_start_node])): + if train_stations[current_start_node][i][1] == track_to_use: + start_station_idx = i break - start_station_idx = np.random.randint(len(train_stations[current_start_node])) - start = train_stations[current_start_node][start_station_idx] - current_track_nbr = track_numbers[current_start_node][start_station_idx] - agents_position.append((start[0], start[1])) + start = train_stations[current_start_node][start_station_idx] + train_stations[current_start_node].pop(start_station_idx) + agents_position.append((start[0][0], start[0][1])) + agents_direction.append(agent_start_orientation) # Orient the agent correctly - if current_track_nbr % 2 != 0: - current_orientation = city_orientations[current_start_node] - agents_direction.append(current_orientation) - else: - current_orientation = (city_orientations[current_start_node] + 2) % 2 - agents_direction.append(current_orientation) - if not rail.check_path_exists(start, current_orientation, target): - print("No path") if speed_ratio_map: speeds = speed_initialization_helper(num_agents, speed_ratio_map) -- GitLab