# Forked from Flatland / Flatland
# (1102 commits behind the upstream repository).
# Erik Nygren authored
# rail_generators.py (39.98 KiB)
"""Rail generators (infrastructure manager, "Infrastrukturbetreiber")."""
import warnings
from typing import Callable, Tuple, Optional, Dict, List, Any
import msgpack
import numpy as np
from flatland.core.grid.grid4_utils import get_direction, mirror
from flatland.core.grid.grid_utils import distance_on_rail, direction_to_point
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail, connect_cities
# What a single generator call hands back: the rail grid plus an optional
# dict of extra information (None when the generator has nothing to add).
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
# Common signature of the generator closures built in this module; they are
# invoked as (width, height, num_agents, num_resets).
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
def empty_rail_generator() -> RailGenerator:
    """
    Build a generator that produces an empty rail map with no agents.

    Primarily used by the editor.
    """

    def generator(width: int, height: int, num_agents: int = 0, num_resets: int = 0) -> RailGeneratorProduct:
        # Fresh transition map of the requested size; the grid is explicitly
        # zeroed so that no cell carries any rail.
        blank_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions())
        blank_map.grid.fill(0)
        return blank_map, None

    return generator
def complex_rail_generator(nr_start_goal=1,
                           nr_extra=100,
                           min_dist=20,
                           max_dist=99999,
                           seed=0) -> RailGenerator:
    """
    Build a generator that draws random start/goal rails plus extra links.

    Parameters
    ----------
    nr_start_goal : int
        Number of start/goal pairs to connect by a rail. The requested number
        of agents is capped at this value.
    nr_extra : int
        Number of additional connections to draw between cells that already
        carry rail.
    min_dist : int
        Smallest accepted start-to-goal distance (as computed by
        `distance_on_rail`).
    max_dist : int
        Largest accepted start-to-goal distance (as computed by
        `distance_on_rail`).
    seed : int
        Base random seed; each call seeds NumPy with `seed + num_resets`.

    Returns
    -------
    function
        Generator function `(width, height, num_agents, num_resets=0)` that
        returns the GridTransitionMap together with a dict whose
        'agents_hints' entry holds the 'start_goal' pairs and the matching
        'start_dir' list.
    """

    def generator(width, height, num_agents, num_resets=0):
        if num_agents > nr_start_goal:
            num_agents = nr_start_goal
            warnings.warn("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
        grid_map = GridTransitionMap(width=width, height=height, transitions=RailEnvTransitions())
        rail_array = grid_map.grid
        rail_array.fill(0)

        np.random.seed(seed + num_resets)

        # generate rail array
        # step 1:
        # - generate a start and goal position
        #   - validate min/max distance allowed
        #   - validate that start/goals are not placed too close to other start/goals
        #   - draw a rail from [start,goal]
        #     - if rail crosses existing rail then validate new connection
        #     - possibility that this fails to create a path to goal
        #     - on failure generate new start/goal
        #
        # step 2:
        # - add more rails to map randomly between cells that have rails
        #   - validate all new rails, on failure don't add new rails
        #
        # step 3:
        # - return transition map + list of [start_pos, start_dir, goal_pos] points
        #

        rail_trans = grid_map.transitions
        start_goal = []
        start_dir = []
        sanity_max = 9000

        def far_from_existing(candidates):
            """Return True if no candidate point is closer than 2 to any stored start/goal."""
            return all(distance_on_rail(new_point, old_point) >= 2
                       for old_pair in start_goal
                       for new_point in candidates
                       for old_point in old_pair)

        nr_created = 0
        created_sanity = 0
        while nr_created < nr_start_goal and created_sanity < sanity_max:
            all_ok = False
            for _ in range(sanity_max):
                start = (np.random.randint(0, height), np.random.randint(0, width))
                goal = (np.random.randint(0, height), np.random.randint(0, width))

                # both end points must still be free of rail
                if rail_array[goal] != 0 or rail_array[start] != 0:
                    continue
                # check min/max distance
                dist_sg = distance_on_rail(start, goal)
                if dist_sg < min_dist:
                    continue
                if dist_sg > max_dist:
                    continue
                # check distance to existing points
                if far_from_existing([start, goal]):
                    all_ok = True
                    break

            if not all_ok:
                # we might as well give up at this point
                break

            new_path = connect_rail(rail_trans, grid_map, start, goal)
            if len(new_path) >= 2:
                nr_created += 1
                start_goal.append([start, goal])
                start_dir.append(mirror(get_direction(new_path[0], new_path[1])))
            else:
                # after too many failures we will give up
                created_sanity += 1

        # add extra connections between existing rail
        created_sanity = 0
        nr_created = 0
        while nr_created < nr_extra and created_sanity < sanity_max:
            all_ok = False
            for _ in range(sanity_max):
                start = (np.random.randint(0, height), np.random.randint(0, width))
                goal = (np.random.randint(0, height), np.random.randint(0, width))
                # both end points must already carry rail
                if rail_array[goal] == 0 or rail_array[start] == 0:
                    continue
                all_ok = True
                break

            if not all_ok:
                break

            new_path = connect_rail(rail_trans, grid_map, start, goal)
            if len(new_path) >= 2:
                nr_created += 1
            else:
                # Count failed attempts as well; without this the loop never
                # terminates when no further connection can be drawn.
                created_sanity += 1

        return grid_map, {'agents_hints': {
            'start_goal': start_goal,
            'start_dir': start_dir
        }}

    return generator
def rail_from_manual_specifications_generator(rail_spec):
    """
    Utility to convert a rail given by manual specification as a map of tuples
    (cell_type, rotation), to a transition map with the correct 16-bit
    transitions specifications.

    Parameters
    ----------
    rail_spec : list of list of tuples
        List (rows) of lists (columns) of tuples, each specifying a
        rail_spec_of_cell for the RailEnv environment as (cell_type, rotation),
        with rotation being clock-wise and in [0, 90, 180, 270]. The number of
        columns is taken from the first row.

    Returns
    -------
    function
        Generator function that always returns a GridTransitionMap object with
        the matrix of correct 16-bit bitmaps for each rail_spec_of_cell,
        paired with None.

    Raises
    ------
    ValueError
        Raised by the generator when a cell_type is not a valid index into the
        basic transitions of RailEnvTransitions.
    """

    def generator(width, height, num_agents, num_resets=0):
        rail_env_transitions = RailEnvTransitions()

        # The requested size is ignored: the grid takes its shape from the
        # specification itself.
        height = len(rail_spec)
        width = len(rail_spec[0])
        rail = GridTransitionMap(width=width, height=height, transitions=rail_env_transitions)

        basic_transitions = rail_env_transitions.transitions
        for r in range(height):
            for c in range(width):
                rail_spec_of_cell = rail_spec[r][c]
                index_basic_type_of_cell_ = rail_spec_of_cell[0]
                rotation_cell_ = rail_spec_of_cell[1]
                if index_basic_type_of_cell_ < 0 or index_basic_type_of_cell_ >= len(basic_transitions):
                    # An empty list used to be returned here, which made every
                    # caller fail later while unpacking the two-element
                    # product; fail right away with a meaningful message.
                    raise ValueError(
                        "invalid rail_spec_of_cell type={} at cell {}".format(index_basic_type_of_cell_, (r, c)))
                effective_transition_cell = rail_env_transitions.rotate_transition(
                    basic_transitions[index_basic_type_of_cell_], rotation_cell_)
                rail.set_transitions((r, c), effective_transition_cell)

        # Same (map, optionals) tuple shape as the other generators here.
        return rail, None

    return generator
def rail_from_file(filename, load_from_package=None) -> RailGenerator:
    """
    Utility to load a rail map from a msgpack-encoded file.

    Parameters
    ----------
    filename : str
        File to load (per the original note: generated by env.save() or the
        editor). The decoded data must provide a b"grid" entry and may
        provide a b"distance_map" entry.
    load_from_package : str, optional
        When given, `filename` is read as a binary resource of this package
        (via `importlib_resources.read_binary`) instead of from the file
        system.

    Returns
    -------
    function
        Generator function that always returns a GridTransitionMap object
        whose grid is the stored one. The second element of the product is
        `{'distance_map': ...}` when a non-empty distance map was stored,
        otherwise None.
    """

    def generator(width, height, num_agents, num_resets=0):
        # width/height are ignored; the map takes its shape from the file.
        rail_env_transitions = RailEnvTransitions()
        if load_from_package is not None:
            from importlib_resources import read_binary
            load_data = read_binary(load_from_package, filename)
        else:
            with open(filename, "rb") as file_in:
                load_data = file_in.read()
        data = msgpack.unpackb(load_data, use_list=False)

        grid = np.array(data[b"grid"])
        grid_shape = np.shape(grid)
        rail = GridTransitionMap(width=grid_shape[1], height=grid_shape[0], transitions=rail_env_transitions)
        rail.grid = grid

        if b"distance_map" in data:
            distance_map = data[b"distance_map"]
            if len(distance_map) > 0:
                return rail, {'distance_map': distance_map}
        # Tuple (not list) so both return paths have the same shape.
        return rail, None

    return generator
def rail_from_grid_transition_map(rail_map) -> RailGenerator:
    """
    Wrap an existing GridTransitionMap in a generator.

    Parameters
    ----------
    rail_map : GridTransitionMap object
        The map to hand out whenever the generator is called.

    Returns
    -------
    function
        Generator function that ignores all of its arguments and always
        returns the given `rail_map` object together with None.
    """

    def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
        # Nothing is generated: the very same map object is returned each time.
        product = (rail_map, None)
        return product

    return generator
def random_rail_generator(cell_type_relative_proportion=None) -> RailGenerator:
    """
    Dummy random level generator:
    - fill in cells at random in [width-2, height-2]
    - keep filling cells in among the unfilled ones, such that all transitions
      are legit; if no cell can be filled in without violating some
      transitions, pick one among those that can satisfy most transitions
      (1,2,3 or 4), and delete (+mark to be re-filled) the cells that were
      incompatible.
    - keep trying for a total number of insertions
      (e.g., (W-2)*(H-2)*MAX_REPETITIONS ); if no solution is found, empty the
      board and try again from scratch.
    - finally pad the border of the map with dead-ends to avoid border issues.

    Dead-ends are not allowed inside the grid, only at the border; however, if
    no cell type can be inserted in a given cell (because of the neighboring
    transitions), deadends are allowed if they solve the problem. This was
    found to turn most un-generatable levels into valid ones.

    Parameters
    ----------
    cell_type_relative_proportion : sequence of float, optional
        Relative weight of each basic transition of RailEnvTransitions,
        indexed like `RailEnvTransitions().transitions`. Defaults to equal
        weights, `[1.0] * 11`.

    Returns
    -------
    function
        Generator function `(width, height, num_agents, num_resets=0)` that
        returns a GridTransitionMap whose grid is a numpy.uint16 array of
        16-bit transition bitmaps, paired with None.
    """
    if cell_type_relative_proportion is None:
        # Built per call instead of being a shared mutable default argument.
        cell_type_relative_proportion = [1.0] * 11

    def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
        t_utils = RailEnvTransitions()

        dead_end = int('0010000000000000', 2)
        empty_cell = int('0000000000000000', 2)

        # (template index, direction queried on the neighbour, (row, col) offset)
        # for the N, E, S, W neighbours of a cell.
        neighbour_probes = ((0, 2, (-1, 0)),
                            (1, 3, (0, 1)),
                            (2, 0, (1, 0)),
                            (3, 1, (0, -1)))

        transition_probability = cell_type_relative_proportion

        # For every non-dead-end basic transition and each of its rotations,
        # store a 4-entry 0/1 template of the directions it offers, the rotated
        # transition itself and its weight.
        transitions_templates_ = []
        transition_probabilities = []
        for i, basic_transition in enumerate(t_utils.transitions):  # don't include dead-ends
            if basic_transition == dead_end:
                continue

            all_transitions = 0
            for dir_ in range(4):
                trans = t_utils.get_transitions(basic_transition, dir_)
                all_transitions |= (trans[0] << 3) | (trans[1] << 2) | (trans[2] << 1) | (trans[3])

            template = [int(x) for x in bin(all_transitions)[2:]]
            template = [0] * (4 - len(template)) + template

            # add all rotations
            for rot in [0, 90, 180, 270]:
                transitions_templates_.append((template, t_utils.rotate_transition(basic_transition, rot)))
                transition_probabilities.append(transition_probability[i])
                template = [template[-1]] + template[:-1]

        def get_matching_templates(template):
            """Return (transition, weight) pairs compatible with `template` (-1 entries match anything)."""
            ret = []
            for (candidate, transition), probability in zip(transitions_templates_, transition_probabilities):
                if all(template[j] < 0 or template[j] == candidate[j] for j in range(4)):
                    ret.append((transition, probability))
            return ret

        def pick_weighted(options):
            """Draw one transition from (transition, weight) pairs, weights normalised to sum to 1."""
            possible_transitions, possible_probabilities = zip(*options)
            total = sum(possible_probabilities)
            return np.random.choice(possible_transitions,
                                    p=[p / total for p in possible_probabilities])

        def leaves_towards(cell_transitions, direction_bit):
            """Return True if any of the four 4-bit groups of the cell has `direction_bit` set."""
            if cell_transitions is None:
                return False
            for k in range(4):
                transitions_from_direction = (cell_transitions >> ((3 - k) * 4)) & (2 ** 4 - 1)
                if transitions_from_direction & direction_bit:
                    return True
            return False

        max_insertions = (width - 2) * (height - 2) * 10
        max_attempts_from_scratch = 10

        attempt_number = 0
        while attempt_number < max_attempts_from_scratch:
            cells_to_fill = []
            rail = []
            for r in range(height):
                rail.append([None] * width)
                if 0 < r < height - 1:
                    cells_to_fill = cells_to_fill + [(r, c) for c in range(1, width - 1)]

            num_insertions = 0
            while num_insertions < max_insertions and len(cells_to_fill) > 0:
                cell = cells_to_fill[np.random.choice(len(cells_to_fill), 1)[0]]
                cells_to_fill.remove(cell)
                row, col = cell

                # look at its neighbors and see what are the possible transitions
                # that can be chosen from, if any.
                valid_template = [-1, -1, -1, -1]
                for side, facing, (d_row, d_col) in neighbour_probes:
                    neigh_trans = rail[row + d_row][col + d_col]
                    if neigh_trans is not None:
                        # does the neighbour offer any transition into
                        # direction `facing`, whatever its entry direction?
                        max_bit = 0
                        for k in range(4):
                            max_bit |= t_utils.get_transition(neigh_trans, k, facing)
                        valid_template[side] = 1 if max_bit else 0

                possible_cell_transitions = get_matching_templates(valid_template)

                if len(possible_cell_transitions) == 0:  # NO VALID TRANSITIONS
                    # no cell can be filled in without violating some transitions
                    # can a dead-end solve the problem?
                    if valid_template.count(1) == 1:
                        # rotation of the dead-end for the N, E, S, W side
                        # that carries the single required connection
                        rot = (180, 270, 0, 90)[valid_template.index(1)]
                        rail[row][col] = t_utils.rotate_transition(dead_end, rot)
                        num_insertions += 1
                    else:
                        # can I get valid transitions by removing a single
                        # neighboring cell?
                        bestk = -1
                        besttrans = []
                        for k in range(4):
                            tmp_template = valid_template[:]
                            tmp_template[k] = -1
                            possible_cell_transitions = get_matching_templates(tmp_template)
                            if len(possible_cell_transitions) > len(besttrans):
                                besttrans = possible_cell_transitions
                                bestk = k

                        if bestk >= 0:
                            # Replace the corresponding cell with None, append it
                            # to cells to fill, fill in a transition in the current
                            # cell.
                            d_row, d_col = neighbour_probes[bestk][2]
                            replace_row = row + d_row
                            replace_col = col + d_col

                            cells_to_fill.append((replace_row, replace_col))
                            rail[replace_row][replace_col] = None

                            rail[row][col] = pick_weighted(besttrans)
                            num_insertions += 1
                        else:
                            print('WARNING: still nothing!')
                            rail[row][col] = empty_cell
                            num_insertions += 1
                else:
                    rail[row][col] = pick_weighted(possible_cell_transitions)
                    num_insertions += 1

            if num_insertions == max_insertions:
                # Failed to generate a valid level; try again for a number of times
                attempt_number += 1
            else:
                break

        if attempt_number == max_attempts_from_scratch:
            print('ERROR: failed to generate level')

        # Finally pad the border of the map with dead-ends to avoid border issues;
        # at most 1 transition in the neigh cell
        for r in range(height):
            # Check for transitions coming from [r][1] to WEST
            if leaves_towards(rail[r][1], 1):
                rail[r][0] = t_utils.rotate_transition(dead_end, 270)
            else:
                rail[r][0] = empty_cell

            # Check for transitions coming from [r][-2] to EAST
            if leaves_towards(rail[r][-2], 1 << 2):
                rail[r][-1] = t_utils.rotate_transition(dead_end, 90)
            else:
                rail[r][-1] = empty_cell

        for c in range(width):
            # Check for transitions coming from [1][c] to NORTH
            if leaves_towards(rail[1][c], 1 << 3):
                rail[0][c] = dead_end
            else:
                rail[0][c] = empty_cell

            # Check for transitions coming from [-2][c] to SOUTH
            if leaves_towards(rail[-2][c], 1 << 1):
                rail[-1][c] = t_utils.rotate_transition(dead_end, 180)
            else:
                rail[-1][c] = empty_cell

        # For display only, wrong levels
        for r in range(height):
            for c in range(width):
                if rail[r][c] is None:
                    rail[r][c] = empty_cell

        tmp_rail = np.asarray(rail, dtype=np.uint16)

        return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
        return_rail.grid = tmp_rail

        return return_rail, None

    return generator
def sparse_rail_generator(num_cities=5, min_node_dist=20, node_radius=2,
                          grid_mode=False, max_inter_city_rails=4, tracks_in_city=4,
                          seed=0) -> RailGenerator:
    """
    This is a level generator which generates complex sparse rail configurations

    :param num_cities: Number of city nodes (cities can hold train stations)
    :type num_cities: int
    :param min_node_dist: Minimal distance between nodes (used for random placement only)
    :param node_radius: Half side length of the square city area around a node center
    :param grid_mode: True -> nodes evenly distributed in env, False -> random distribution of nodes
    :param max_inter_city_rails: Upper bound for the number of rails leaving one city side towards a
        neighbouring city (capped at tracks_in_city)
    :param tracks_in_city: Upper bound for the number of connection points per connected city side
    :param seed: Random seed; each call seeds NumPy with seed + num_resets
    :return: generator function returning the GridTransitionMap and a dict whose 'agents_hints' entry holds
        'num_agents', 'agent_start_targets_nodes' and 'train_stations'
    """

    def generator(width, height, num_agents, num_resets=0) -> RailGeneratorProduct:
        rail_trans = RailEnvTransitions()
        grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
        grid_map.grid.fill(0)
        np.random.seed(seed + num_resets)

        # A city cannot send out more rails than it has tracks.
        max_inter_city_rails_allowed = min(max_inter_city_rails, tracks_in_city)

        # Place the city nodes, either evenly or at random
        if grid_mode:
            node_positions, city_cells = _generate_node_positions_grid_mode(num_cities, height, width)
        else:
            node_positions, city_cells = _generate_random_node_positions(num_cities, height, width)

        # fewer nodes than requested may have been generated in random mode
        nb_nodes = len(node_positions)

        # Set up connection points for all cities
        connection_points, connection_info = _generate_node_connection_points(node_positions, node_radius,
                                                                              tracks_in_city)

        # Connect the cities through the connection points
        outer_connection_points = _connect_cities(node_positions, connection_points, connection_info, city_cells,
                                                  max_inter_city_rails_allowed,
                                                  rail_trans, grid_map)

        # Build inner cities
        through_tracks = _build_inner_cities(node_positions, connection_points, outer_connection_points, rail_trans,
                                             grid_map)

        # Populate cities
        train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, through_tracks, grid_map)

        # Adjust the number of agents if you could not build enough trainstations
        if num_agents > built_num_trainstation:
            num_agents = built_num_trainstation
            warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents")

        # Fix all transition elements
        _fix_transitions(grid_map)

        # Generate start target pairs
        agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, nb_nodes, train_stations)

        return grid_map, {'agents_hints': {
            'num_agents': num_agents,
            'agent_start_targets_nodes': agent_start_targets_nodes,
            'train_stations': train_stations
        }}

    def _generate_random_node_positions(nb_nodes, height, width):
        """
        Randomly place up to nb_nodes city centers that keep min_node_dist from each other.

        :return: list of node positions and flat list of all cells covered by the cities
        """
        node_positions = []
        city_cells = []
        for _ in range(nb_nodes):
            to_close = True
            tries = 0
            while to_close:
                x_tmp = node_radius + 1 + np.random.randint(height - 2 * node_radius - 1)
                y_tmp = node_radius + 1 + np.random.randint(width - 2 * node_radius - 1)

                # Check distance to nodes
                to_close = any(distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist
                               for node_pos in node_positions)

                if not to_close:
                    node_positions.append((x_tmp, y_tmp))
                    city_cells.extend(_city_cells(node_positions[-1], node_radius))

                tries += 1
                # Only complain when this node really could not be placed.
                if to_close and tries > 100:
                    warnings.warn(
                        "Could only set {} nodes after {} tries, although {} nodes were required to be "
                        "generated!".format(len(node_positions), tries, nb_nodes))
                    break
        return node_positions, city_cells

    def _generate_node_positions_grid_mode(nb_nodes, height, width):
        """
        Place nb_nodes city centers on an evenly spaced lattice.

        :return: list of node positions and flat list of all cells covered by the cities
        """
        nodes_ratio = height / width
        nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
        nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row))
        x_positions = np.linspace(node_radius + 1, height - node_radius - 2, nodes_per_row, dtype=int)
        y_positions = np.linspace(node_radius + 1, width - node_radius - 2, nodes_per_col, dtype=int)
        node_positions = []
        city_cells = []
        for node_idx in range(nb_nodes):
            x_tmp = x_positions[node_idx % nodes_per_row]
            y_tmp = y_positions[node_idx // nodes_per_row]
            node_positions.append((x_tmp, y_tmp))
            city_cells.extend(_city_cells(node_positions[-1], node_radius))
        return node_positions, city_cells

    def _generate_node_connection_points(node_positions, node_size, tracks_in_city=2):
        """
        Choose the connection points on the border of every city.

        :param node_positions: Positions of city centers
        :param node_size: Distance of the connection points from the city center
        :param tracks_in_city: Upper bound for the number of connection points per connected side
        :return: per city the connection point coordinates for each direction, and per city the number of
            connection points for each direction
        """
        connection_points = []
        connection_info = []
        if tracks_in_city > 2 * node_size - 1:
            tracks_in_city = 2 * node_size - 1

        for node_position in node_positions:
            # Chose the directions where close cities are situated
            neighb_dist = [distance_on_rail(node_position, neighb_node) for neighb_node in node_positions]
            closest_neighb_idx = argsort(neighb_dist)

            # Orient the city to face its closest neighbour (index 0 is the
            # city itself). A lone city has no neighbour to face; fall back to
            # direction 0 instead of failing with an IndexError.
            if len(closest_neighb_idx) > 1:
                current_closest_direction = direction_to_point(node_position,
                                                               node_positions[closest_neighb_idx[1]])
            else:
                current_closest_direction = 0
            connection_sides_idx = [current_closest_direction, (current_closest_direction + 2) % 4]

            # set the number of tracks within a city, at least 2 tracks per city
            connections_per_direction = np.zeros(4, dtype=int)
            nr_of_connection_points = np.random.randint(2, tracks_in_city + 1)
            for idx in connection_sides_idx:
                connections_per_direction[idx] = nr_of_connection_points

            connection_points_coordinates = [[] for _ in range(4)]
            for direction in range(4):
                nr_slots = connections_per_direction[direction]
                # slot offsets along the side, arranged around the city center
                connection_slots = np.arange(nr_slots) - int(nr_slots / 2)
                for slot in connection_slots:
                    if direction == 0:
                        tmp_coordinates = (node_position[0] - node_size, node_position[1] + slot)
                    elif direction == 1:
                        tmp_coordinates = (node_position[0] + slot, node_position[1] + node_size)
                    elif direction == 2:
                        tmp_coordinates = (node_position[0] + node_size, node_position[1] + slot)
                    else:
                        tmp_coordinates = (node_position[0] + slot, node_position[1] - node_size)
                    connection_points_coordinates[direction].append(tmp_coordinates)

            connection_points.append(connection_points_coordinates)
            connection_info.append(connections_per_direction)
        return connection_points, connection_info

    def _connect_cities(node_positions, connection_points, connection_info, city_cells, max_inter_city_rails_allowed,
                        rail_trans, grid_map):
        """
        Function to connect the different cities through their connection points

        :param node_positions: Positions of city centers
        :param connection_points: Boarder connection points of cities
        :param connection_info: Number of connection points per direction NESW
        :param city_cells: Cells covered by cities, handed to connect_cities
        :param max_inter_city_rails_allowed: Upper bound for the number of rails leaving one city side
        :param rail_trans: Transitions
        :param grid_map: Grid map
        :return: per city the list of connection points that were used by an inter-city rail
        """
        boarder_connections = [[] for _ in range(len(node_positions))]

        for current_node in np.arange(len(node_positions)):
            connected_to_city = []
            for direction, nbr_connection_points in enumerate(connection_info[current_node]):
                if nbr_connection_points <= 0:
                    continue

                neighb_idx = _closest_neigh_in_direction(current_node, direction, node_positions)
                if neighb_idx is None or neighb_idx in connected_to_city:
                    # Fall back to the closest city that is not connected yet
                    # (the first sorted entry is the current city itself).
                    node_dist = [distance_on_rail(node_positions[current_node], av_node)
                                 for av_node in node_positions]
                    neighb_idx = next((candidate for candidate in np.argsort(node_dist)[1:]
                                       if candidate not in connected_to_city), None)
                    if neighb_idx is None:
                        # Every other city is already connected (e.g. only two
                        # cities exist); leave this side unconnected instead
                        # of running past the end of the neighbour list.
                        continue

                connected_to_city.append(neighb_idx)
                number_of_out_rails = np.random.randint(1, max_inter_city_rails_allowed + 1)

                all_neighb_connection_points = [item for sublist in connection_points[neighb_idx]
                                                for item in sublist]
                for tmp_out_connection_point in connection_points[current_node][direction][:number_of_out_rails]:
                    # Find closest connection point (first one on ties)
                    neighb_connection_point = min(
                        all_neighb_connection_points,
                        key=lambda tmp_in: distance_on_rail(tmp_out_connection_point, tmp_in))

                    connect_cities(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point,
                                   city_cells)
                    if tmp_out_connection_point not in boarder_connections[current_node]:
                        boarder_connections[current_node].append(tmp_out_connection_point)
                    if neighb_connection_point not in boarder_connections[neighb_idx]:
                        boarder_connections[neighb_idx].append(neighb_connection_point)
        return boarder_connections

    def _build_inner_cities(node_positions, connection_points, outer_connection_points, rail_trans, grid_map):
        """
        Builds inner city tracks. This current version connects all incoming connections to all outgoing connections

        :param node_positions: Positions of city centers
        :param connection_points: Connection points of every city, per direction
        :param outer_connection_points: Per city the connection points used by inter-city rails
        :param rail_trans: Transitions
        :param grid_map: Grid map
        :return: per city the cells of the first track that joins two outer connection points
        """
        through_path_cells = [[] for _ in range(len(node_positions))]
        for current_city in range(len(node_positions)):
            city_sides = connection_points[current_city]
            outer_points = outer_connection_points[current_city]
            for boarder in range(4):
                for source in city_sides[boarder]:
                    for other_boarder in range(4):
                        if boarder == other_boarder:
                            continue
                        for target in city_sides[other_boarder]:
                            city_boarder = _city_boarder(node_positions[current_city], node_radius)
                            current_track = connect_cities(rail_trans, grid_map, source, target, city_boarder)
                            # remember only the first track between two outer points
                            if target in outer_points and source in outer_points \
                                    and len(through_path_cells[current_city]) < 1:
                                through_path_cells[current_city].extend(current_track)
        return through_path_cells

    def _set_trainstation_positions(node_positions, through_tracks, grid_map):
        """
        Collect the cells inside every city that can serve as train stations.

        :param node_positions: Positions of city centers
        :param through_tracks: Per city the cells that must stay free of stations
        :param grid_map: Grid map
        :return: per city the list of station cells, and the total number of stations
        """
        train_stations = [[] for _ in range(len(node_positions))]
        built_num_trainstations = 0

        for current_city, city_center in enumerate(node_positions):
            for possible_location in _city_cells(city_center, node_radius - 1):
                if possible_location in through_tracks[current_city]:
                    continue
                cell_type = grid_map.get_full_transitions(*possible_location)
                # number of transition bits set in the cell
                nbits = bin(int(cell_type)).count("1")
                if 1 <= nbits <= 2:
                    built_num_trainstations += 1
                    train_stations[current_city].append(possible_location)
        return train_stations, built_num_trainstations

    def _generate_start_target_pairs(num_agents, nb_nodes, train_stations):
        """
        Draw a (start node, target node) pair for every agent; start and target are never the same node.

        :return: list of (start_node, target_node) pairs and the number of agents that got a pair
        """
        agent_start_targets_nodes = []

        # Slot availability in node
        node_available_start = [len(train_stations[node_idx]) for node_idx in range(nb_nodes)]
        node_available_target = [len(train_stations[node_idx]) for node_idx in range(nb_nodes)]

        # Assign agents to slots
        for _ in range(num_agents):
            avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
            avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
            start_node = np.random.choice(avail_start_nodes)
            target_node = np.random.choice(avail_target_nodes)
            tries = 0
            found_agent_pair = True
            while target_node == start_node:
                target_node = np.random.choice(avail_target_nodes)
                tries += 1
                # Test again with new start node if no pair is found (This code needs to be improved)
                if (tries + 1) % 10 == 0:
                    start_node = np.random.choice(avail_start_nodes)
                if tries > 100:
                    warnings.warn("Could not set trainstations, removing agent!")
                    found_agent_pair = False
                    break
            if found_agent_pair:
                node_available_start[start_node] -= 1
                node_available_target[target_node] -= 1
                agent_start_targets_nodes.append((start_node, target_node))
            else:
                num_agents -= 1
        return agent_start_targets_nodes, num_agents

    def _fix_transitions(grid_map):
        """
        Function to fix all transition elements in environment
        """
        # Collect all cells for which cell_neighbours_valid reports a problem
        empty_to_fix = []
        rails_to_fix = []
        height, width = np.shape(grid_map.grid)
        for r in range(height):
            for c in range(width):
                rc_pos = (r, c)
                if grid_map.cell_neighbours_valid(rc_pos, True):
                    continue
                if grid_map.grid[rc_pos] == 0:
                    empty_to_fix.append(rc_pos)
                else:
                    rails_to_fix.append(rc_pos)

        # Fix empty cells first to avoid cutting the network
        for cell in empty_to_fix:
            grid_map.fix_transitions(cell)

        # Fix all other cells
        for cell in rails_to_fix:
            grid_map.fix_transitions(cell)

    def _closest_neigh_in_direction(current_node, direction, node_positions):
        """
        Return the index of the closest node lying in `direction` (0-3) of `current_node`, or None.
        """
        own_position = node_positions[current_node]
        # Sort available neighbors according to their distance.
        node_dist = [distance_on_rail(own_position, other_position) for other_position in node_positions]
        sorted_neighbours = np.argsort(node_dist)

        for neighb in sorted_neighbours[1:]:
            other_position = node_positions[neighb]
            distance_0 = np.abs(own_position[0] - other_position[0])
            distance_1 = np.abs(own_position[1] - other_position[1])
            if direction == 0:
                if other_position[0] < own_position[0] and distance_1 <= distance_0:
                    return neighb
            if direction == 1:
                if other_position[1] > own_position[1] and distance_0 <= distance_1:
                    return neighb
            if direction == 2:
                if other_position[0] > own_position[0] and distance_1 <= distance_0:
                    return neighb
            if direction == 3:
                if other_position[1] < own_position[1] and distance_0 <= distance_1:
                    return neighb
        return None

    def argsort(seq):
        """Return the indices that sort `seq` in ascending order."""
        # http://stackoverflow.com/questions/3071415/efficient-method-to-calculate-the-rank-vector-of-a-list-in-python
        return sorted(range(len(seq)), key=seq.__getitem__)

    def _city_cells(center, radius):
        """
        Function to return all cells within a city

        :param center: center coordinates of city
        :param radius: radius of city (it is a square)
        :return: returns flat list of all cell coordinates in the city
        """
        return [(center[0] + x, center[1] + y)
                for x in range(-radius, radius + 1)
                for y in range(-radius, radius + 1)]

    def _city_boarder(center, radius):
        """
        Function to return the cells on the outline of a city

        :param center: center coordinates of city
        :param radius: radius of city (it is a square)
        :return: returns flat list of the coordinates of all cells on the edge of the city square
        """
        return [(center[0] + x, center[1] + y)
                for x in range(-radius, radius + 1)
                for y in range(-radius, radius + 1)
                if abs(x) == radius or abs(y) == radius]

    return generator