Skip to content
Snippets Groups Projects
Commit 67181fcb authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Merge branch '188_refining_generator' of gitlab.aicrowd.com:flatland/flatland...

Merge branch '188_refining_generator' of gitlab.aicrowd.com:flatland/flatland into 188_refining_generator
parents c380d754 9bbf2ed6
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,8 @@ from typing import Tuple, Callable, List, Type
import numpy as np
from flatland.core.grid.grid4 import Grid4TransitionsEnum
Vector2D: Type = Tuple[float, float]
IntVector2D: Type = Tuple[int, int]
......@@ -296,7 +298,7 @@ def distance_on_rail(pos1, pos2, metric="Euclidean"):
return np.abs(pos1[0] - pos2[0]) + np.abs(pos1[1] - pos2[1])
def direction_to_point(pos1, pos2):
def direction_to_city(pos1: IntVector2D, pos2: IntVector2D) -> Grid4TransitionsEnum:
"""
Returns the closest direction orientation of position 2 relative to position 1
:param pos1: position we are interested in
......@@ -308,11 +310,11 @@ def direction_to_point(pos1, pos2):
direction = np.sign(diff_vec[axis])
if axis == 0:
if direction > 0:
return 0
return Grid4TransitionsEnum.NORTH
else:
return 2
return Grid4TransitionsEnum.SOUTH
else:
if direction > 0:
return 3
return Grid4TransitionsEnum.WEST
else:
return 1
return Grid4TransitionsEnum.EAST
"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
import time
import warnings
from typing import Callable, Tuple, Optional, Dict, List
from typing import Callable, Tuple, Optional, Dict, List, Any
import msgpack
import numpy as np
from flatland.core.grid.grid4_utils import get_direction, mirror
from flatland.core.grid.grid_utils import distance_on_rail, direction_to_point
from flatland.core.grid.grid_utils import distance_on_rail, direction_to_city, IntVector2DArray, IntVector2D, \
Vec2dOperations
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail, connect_cities, connect_straigt_line
......@@ -545,8 +546,6 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
:return: generator
"""
DEBUG_PRINT_TIMING = False
def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
np.random.seed(seed + num_resets)
......@@ -560,7 +559,6 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
rails_between_cities = rails_in_city if max_rails_between_cities > rails_in_city else max_rails_between_cities
# Evenly distribute cities
city_time_start = time.time()
if grid_mode:
city_positions, city_cells = _generate_evenly_distr_city_positions(max_num_cities, city_radius, width, height)
else:
......@@ -568,62 +566,40 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
# reduce num_cities if less were generated in random mode
num_cities = len(city_positions)
if DEBUG_PRINT_TIMING:
print("City position time", time.time() - city_time_start, "Seconds")
# Set up connection points for all cities
city_connection_time = time.time()
inner_connection_points, outer_connection_points, connection_info, city_orientations = _generate_city_connection_points(
city_positions, city_radius, rails_between_cities,
rails_in_city)
if DEBUG_PRINT_TIMING:
print("Connection points", time.time() - city_connection_time)
# Connect the cities through the connection points
city_connection_time = time.time()
inter_city_lines = _connect_cities(city_positions, outer_connection_points, city_cells,
rail_trans, grid_map)
if DEBUG_PRINT_TIMING:
print("City connection time", time.time() - city_connection_time)
# Build inner cities
city_build_time = time.time()
through_tracks, free_tracks = _build_inner_cities(city_positions, inner_connection_points,
outer_connection_points,
city_radius,
rail_trans,
grid_map)
if DEBUG_PRINT_TIMING:
print("City build time", time.time() - city_build_time)
through_tracks, free_rails = _build_inner_cities(city_positions, inner_connection_points,
outer_connection_points,
rail_trans,
grid_map)
# Populate cities
train_station_time = time.time()
train_stations, built_num_trainstation = _set_trainstation_positions(city_positions, city_radius, free_tracks,
grid_map)
if DEBUG_PRINT_TIMING:
print("Trainstation placing time", time.time() - train_station_time)
train_stations, built_num_trainstation = _set_trainstation_positions(city_positions, city_radius, free_rails)
# Fix all transition elements
grid_fix_time = time.time()
_fix_transitions(city_cells, inter_city_lines, grid_map)
if DEBUG_PRINT_TIMING:
print("Grid fix time", time.time() - grid_fix_time)
# Generate start target pairs
schedule_time = time.time()
agent_start_targets_cities, num_agents = _generate_start_target_pairs(num_agents, num_cities, train_stations,
city_orientations)
if DEBUG_PRINT_TIMING:
print("Schedule time", time.time() - schedule_time)
return grid_map, {'agents_hints': {
'num_agents': num_agents,
'agent_start_targets_cities': agent_start_targets_cities,
'agent_start_targets_nodes': agent_start_targets_cities,
'train_stations': train_stations,
'city_orientations': city_orientations
}}
def _generate_random_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]):
city_positions: List[Tuple[int, int]] = []
city_cells: List[Tuple[int, int]] = []
def _generate_random_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (IntVector2DArray, IntVector2DArray):
city_positions: IntVector2DArray = []
city_cells: IntVector2DArray = []
for city_idx in range(num_cities):
too_close = True
tries = 0
......@@ -650,7 +626,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
break
return city_positions, city_cells
def _generate_evenly_distr_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]):
def _generate_evenly_distr_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (IntVector2DArray, IntVector2DArray):
aspect_ratio = height / width
cities_per_row = int(np.ceil(np.sqrt(num_cities * aspect_ratio)))
cities_per_col = int(np.ceil(num_cities / cities_per_row))
......@@ -665,8 +641,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
city_cells.extend(_get_cells_in_city(city_positions[-1], city_radius))
return city_positions, city_cells
def _generate_city_connection_points(city_positions: List[Tuple[int, int]], city_radius: int,
rails_between_cities: int, rails_in_city: int = 2):
def _generate_city_connection_points(city_positions: IntVector2DArray, city_radius: int, rails_between_cities: int, rails_in_city: int = 2):
inner_connection_points = []
outer_connection_points = []
connection_info = []
......@@ -675,8 +650,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
# Chose the directions where close cities are situated
neighb_dist = []
for neighb_city in city_positions:
neighb_dist.append(distance_on_rail(city_position, neighb_city, metric="Manhattan"))
for neighbour_city in city_positions:
neighb_dist.append(Vec2dOperations.get_manhattan_distance(city_position, neighbour_city))
closest_neighb_idx = argsort(neighb_dist)
# Store the directions to these neighbours and orient city to face closest neighbour
......@@ -685,7 +660,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
if grid_mode:
current_closest_direction = np.random.randint(4)
else:
current_closest_direction = direction_to_point(city_position, city_positions[closest_neighb_idx[idx]])
current_closest_direction = direction_to_city(city_position, city_positions[closest_neighb_idx[idx]])
connection_sides_idx.append(current_closest_direction)
connection_sides_idx.append((current_closest_direction + 2) % 4)
city_orientations.append(current_closest_direction)
......@@ -723,7 +698,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
connection_info.append(connections_per_direction)
return inner_connection_points, outer_connection_points, connection_info, city_orientations
def _connect_cities(city_positions: List[Tuple[int, int]], connection_points, city_cells: List[Tuple[int, int]],
def _connect_cities(city_positions: IntVector2DArray, connection_points, city_cells: IntVector2DArray,
rail_trans, grid_map):
"""
Function to connect the different cities through their connection points
......@@ -750,8 +725,8 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
for dir in range(4):
current_points = connection_points[neighb_idx][dir]
for tmp_in_connection_point in current_points:
tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point,
metric="Manhattan")
tmp_dist = Vec2dOperations.get_manhattan_distance(tmp_out_connection_point,
tmp_in_connection_point)
if tmp_dist < min_connection_dist:
min_connection_dist = tmp_dist
neighb_connection_point = tmp_in_connection_point
......@@ -763,7 +738,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
return all_paths
def _build_inner_cities(city_positions, inner_connection_points, outer_connection_points, city_radius, rail_trans,
def _build_inner_cities(city_positions, inner_connection_points, outer_connection_points, rail_trans,
grid_map):
"""
Builds inner city tracks. This current version connects all incoming connections to all outgoing connections
......@@ -806,29 +781,27 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
free_tracks[current_city].append(current_track)
return through_path_cells, free_tracks
def _set_trainstation_positions(city_positions, city_radius, free_tracks, grid_map):
def _set_trainstation_positions(city_positions: IntVector2DArray, city_radius: int, free_rails):
"""
:param city_positions:
:param num_trainstations:
:return:
"""
nb_cities = len(city_positions)
train_stations = [[] for i in range(nb_cities)]
left = 0
right = 0
num_cities = len(city_positions)
train_stations = [[] for i in range(num_cities)]
built_num_trainstations = 0
for current_city in range(len(city_positions)):
for track_nbr in range(len(free_tracks[current_city])):
possible_location = free_tracks[current_city][track_nbr][city_radius]
for track_nbr in range(len(free_rails[current_city])):
possible_location = free_rails[current_city][track_nbr][city_radius]
train_stations[current_city].append((possible_location, track_nbr))
return train_stations, built_num_trainstations
def _generate_start_target_pairs(num_agents, nb_cities, train_stations, city_orientation):
def _generate_start_target_pairs(num_agents, num_cities, train_stations, city_orientation):
"""
Fill the trainstation positions with targets and goals
:param num_agents:
:param nb_cities:
:param num_cities:
:param train_stations:
:return:
"""
......@@ -839,7 +812,7 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
# Slot availability in city
city_available_start = []
city_available_target = []
for city_idx in range(nb_cities):
for city_idx in range(num_cities):
city_available_start.append(len(train_stations[city_idx]))
city_available_target.append(len(train_stations[city_idx]))
......@@ -879,35 +852,34 @@ def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_
for cell in range(rails_to_fix_cnt):
grid_map.fix_transitions((rails_to_fix[2 * cell], rails_to_fix[2 * cell + 1]))
def _closest_neighbour_in_direction(current_city_idx: int, city_positions: List[Tuple[int, int]]):
def _closest_neighbour_in_direction(current_city_idx: int, city_positions: IntVector2DArray) -> List[int]:
"""
Returns indices of closest neighbours in every direction NESW
Returns indices of closest neighbour in every direction NESW
:param current_city_idx: Index of city in city_positions list
:param city_positions: list of all points being considered
:return: list of index of closest neighbours in all directions
:return: list of index of closest neighbour in all directions
"""
city_dist = []
closest_neighb = [None for i in range(4)]
for av_city in range(len(city_positions)):
city_dist.append(
distance_on_rail(city_positions[current_city_idx], city_positions[av_city], metric="Manhattan"))
sorted_neighbours = np.argsort(city_dist)
city_distances = []
closest_neighbour: List[int] = [None for i in range(4)]
for city_idx in range(len(city_positions)):
city_distances.append(Vec2dOperations.get_manhattan_distance(city_positions[current_city_idx], city_positions[city_idx]))
sorted_neighbours = np.argsort(city_distances)
direction_set = 0
for neighb in sorted_neighbours[1:]:
direction_to_neighb = direction_to_point(city_positions[current_city_idx], city_positions[neighb])
if closest_neighb[direction_to_neighb] == None:
closest_neighb[direction_to_neighb] = neighb
for neighbour in sorted_neighbours[1:]:
direction_to_neighbour = direction_to_city(city_positions[current_city_idx], city_positions[neighbour])
if closest_neighbour[direction_to_neighbour] == None:
closest_neighbour[direction_to_neighbour] = neighbour
direction_set += 1
if direction_set == 4:
return closest_neighb
return closest_neighb
return closest_neighbour
return closest_neighbour
def argsort(seq):
# http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
return sorted(range(len(seq)), key=seq.__getitem__)
def _get_cells_in_city(center: Tuple[int, int], radius: int) -> List[Tuple[int, int]]:
def _get_cells_in_city(center: IntVector2D, radius: int) -> IntVector2DArray:
"""
Parameters
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment