diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index 6396f4123adf895c381c2a21f6d8dc6e4e823b92..2a6d90b02de79ca40c09bdcd0c80888386b4c8f8 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -1,5 +1,3 @@ -import time - import numpy as np from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv @@ -39,7 +37,7 @@ env = RailEnv(width=50, max_tracks_in_city=4, ), schedule_generator=sparse_schedule_generator(), - number_of_agents=10, + number_of_agents=50, stochastic_data=stochastic_data, # Malfunction data generator obs_builder_object=GlobalObsForRailEnv()) @@ -113,7 +111,6 @@ for step in range(500): # reward and whether their are done next_obs, all_rewards, done, _ = env.step(action_dict) env_renderer.render_env(show=True, show_observations=False, show_predictions=False) - time.sleep(100) frame_step += 1 # Update replay buffer and train agent for a in range(env.get_num_agents()): diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 5fd8ba37be1fa763b199655b5d29419deb422447..08982fc6e9786d980d95e820a22d472149d3b632 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -575,7 +575,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, print("City position time", time.time() - node_time_start, "Seconds") # Set up connection points for all cities node_connection_time = time.time() - inner_connection_points, outer_connection_points, connection_info = _generate_node_connection_points( + inner_connection_points, outer_connection_points, connection_info, city_orientations = _generate_node_connection_points( node_positions, node_radius, max_inter_city_rails_allowed, max_tracks_in_city) print("Connection points", time.time() - node_connection_time) @@ -594,8 +594,10 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, print("City build time", time.time() - city_build_time) # Populate cities train_station_time = time.time() - train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, through_tracks, - node_radius, grid_map) + train_stations, track_numbers, built_num_trainstation = _set_trainstation_positions(node_positions, + city_orientations, + through_tracks, + node_radius, grid_map) print("Trainstation placing time", time.time() - train_station_time) # Adjust the number of agents if you could not build enough trainstations @@ -615,7 +617,9 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, return grid_map, {'agents_hints': { 'num_agents': num_agents, 'agent_start_targets_nodes': agent_start_targets_nodes, - 'train_stations': train_stations + 'train_stations': train_stations, + 'city_orientations': city_orientations, + 'track_numbers': track_numbers }} def _generate_random_node_positions(nb_nodes, node_radius, height, width): @@ -669,6 +673,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, inner_connection_points = [] outer_connection_points = [] connection_info = [] + city_orientations = [] for node_position in node_positions: # Chose the directions where close cities are situated @@ -683,7 +688,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, current_closest_direction = direction_to_point(node_position, node_positions[closest_neighb_idx[idx]]) connection_sides_idx.append(current_closest_direction) connection_sides_idx.append((current_closest_direction + 2) % 4) - + city_orientations.append(current_closest_direction) # set the number of tracks within a city, at least 2 tracks per city connections_per_direction = np.zeros(4, dtype=int) nr_of_connection_points = np.random.randint(2, tracks_in_city + 1) @@ -716,7 +721,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, inner_connection_points.append(connection_points_coordinates_inner) outer_connection_points.append(connection_points_coordinates_outer) connection_info.append(connections_per_direction) - return inner_connection_points, outer_connection_points, connection_info + return inner_connection_points, outer_connection_points, connection_info, city_orientations def _connect_cities(node_positions, connection_points, connection_info, city_cells, rail_trans, grid_map): @@ -792,7 +797,6 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, break opposite_boarder = (boarder + 2) % 4 - track_direction = opposite_boarder boarder_one = inner_connection_points[current_city][boarder] boarder_two = inner_connection_points[current_city][opposite_boarder] @@ -820,7 +824,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, return through_path_cells - def _set_trainstation_positions(node_positions, through_tracks, node_radius, grid_map): + def _set_trainstation_positions(node_positions, city_orientations, through_tracks, node_radius, grid_map): """ :param node_positions: @@ -829,6 +833,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, """ nb_nodes = len(node_positions) train_stations = [[] for i in range(nb_nodes)] + train_station_orientations = [[] for i in range(nb_nodes)] built_num_trainstations = 0 for current_city in range(len(node_positions)): for possible_location in _city_cells(node_positions[current_city], node_radius - 1): @@ -841,8 +846,11 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, cell_type = cell_type >> 1 if 1 <= nbits <= 2: built_num_trainstations += 1 + track_nbr = _track_number(node_positions[current_city], city_orientations[current_city], + possible_location) train_stations[current_city].append(possible_location) - return train_stations, built_num_trainstations + train_station_orientations[current_city].append(track_nbr) + return train_stations, train_station_orientations, built_num_trainstations def _generate_start_target_pairs(num_agents, nb_nodes, train_stations): """ @@ -970,4 +978,17 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, def _city_overlap(center_1, center_2, radius): return (np.abs(center_1[0] - center_2[0]) < radius and np.abs(center_1[1] - center_2[1]) < radius) + def _track_number(city_position, city_orientation, position): + """ + FUnction that tells you if you are on even or uneven track number + :param city_position: + :param city_orientation: + :param position: + :return: + """ + if city_orientation % 2 == 0: + return np.abs(city_position[1] - position[1]) % 2 + else: + return np.abs(city_position[0] - position[0]) % 2 + return generator diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index f41db2b0c453adfb578a6de489f780071287dc1b..eb56cc56c434a6b53aff08b98a780a5195806319 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -62,6 +62,8 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> train_stations = hints['train_stations'] agent_start_targets_nodes = hints['agent_start_targets_nodes'] max_num_agents = hints['num_agents'] + city_orientations = hints['city_orientations'] + track_numbers = hints['track_numbers'] if num_agents > max_num_agents: num_agents = max_num_agents warnings.warn("Too many agents! Changes number of agents.") @@ -89,6 +91,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> current_start_node = agent_start_targets_nodes[agent_idx][0] start_station_idx = np.random.randint(len(train_stations[current_start_node])) start = train_stations[current_start_node][start_station_idx] + current_track_nbr = track_numbers[current_start_node][start_station_idx] tries = 0 while (start[0], start[1]) in agents_position: tries += 1 @@ -97,15 +100,16 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None) -> break start_station_idx = np.random.randint(len(train_stations[current_start_node])) start = train_stations[current_start_node][start_station_idx] + current_track_nbr = track_numbers[current_start_node][start_station_idx] agents_position.append((start[0], start[1])) # Orient the agent correctly - for orientation in range(4): - transitions = rail.get_transitions(start[0], start[1], orientation) - if any(transitions) > 0 and rail.check_path_exists(start, orientation, target): - agents_direction.append(orientation) - break + if current_track_nbr % 2 != 0: + agents_direction.append(city_orientations[current_start_node]) + else: + agents_direction.append((city_orientations[current_start_node] + 2) % 2) + if speed_ratio_map: speeds = speed_initialization_helper(num_agents, speed_ratio_map)