Commit 21223866 authored by u229589's avatar u229589
Browse files

refactor initial variables in sparse_rail_generator

parent 844f0cdb
......@@ -30,11 +30,11 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=40,
height=40,
rail_generator=sparse_rail_generator(num_cities=8, # Number of cities in map (where train stations are)
rail_generator=sparse_rail_generator(max_num_cities=8, # Number of cities in map (where train stations are)
seed=1, # Random seed
grid_mode=False,
max_inter_city_rails=2,
max_tracks_in_city=4,
max_rails_between_cities=2,
max_rails_in_city=4,
),
schedule_generator=sparse_schedule_generator(speed_ration_map),
number_of_agents=20,
......
......@@ -4,7 +4,6 @@ import warnings
from typing import Callable, Tuple, Optional, Dict, List, Any
import msgpack
import networkx as nx
import numpy as np
from flatland.core.grid.grid4_utils import get_direction, mirror
......@@ -534,88 +533,74 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11) -> RailGener
return generator
def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4, max_tracks_in_city=4,
seed=0) -> RailGenerator:
def sparse_rail_generator(max_num_cities: int = 5, grid_mode: bool = False, max_rails_between_cities: int = 4,
max_rails_in_city: int = 4, seed: int = 0) -> RailGenerator:
"""
Generates railway networks with cities and inner city rails
:param num_cities: Number of city centers in the map
:param grid_mode: Arange cities in a grid or randomly
:param max_inter_city_rails: Maximum number of connecting rails going out from a city
:param max_tracks_in_city: maximum number of internal rails
:param max_num_cities: Number of city centers in the map
:param grid_mode: arrange cities in a grid or randomly
:param max_rails_between_cities: Maximum number of connecting rails going out from a city
:param max_rails_in_city: maximum number of internal rails
:param seed: Random seed to initiate rail
:return: generator
"""
G = nx.DiGraph()
DEBUG_PRINT_TIMING = False
def generator(width, height, num_agents, num_resets=0) -> RailGeneratorProduct:
def generator(width: int, height: int, num_agents: int, num_resets: int = 0) -> RailGeneratorProduct:
np.random.seed(seed + num_resets)
rail_trans = RailEnvTransitions()
grid_map = GridTransitionMap(width=width, height=height, transitions=rail_trans)
rail_array = grid_map.grid
rail_array.fill(0)
np.random.seed(seed + num_resets)
# Graph to be able to create correct start/end pairs for schedule
city_radius = int(np.ceil((max_rails_in_city + 2) / 2.0)) + 1
node_radius = int(np.ceil((max_tracks_in_city + 2) / 2.0)) + 1
if 3 > max_tracks_in_city:
rail_in_city = 3
else:
rail_in_city = max_tracks_in_city
max_inter_city_rails_allowed = max_inter_city_rails
if max_inter_city_rails_allowed > rail_in_city:
max_inter_city_rails_allowed = rail_in_city
# Generate a set of nodes for the sparse network
# Try to connect cities to nodes first
city_positions = []
intersection_positions = []
# Evenly distribute cities and intersections
min_nr_rails_in_city = 3
rails_in_city = min_nr_rails_in_city if max_rails_in_city < min_nr_rails_in_city else max_rails_in_city
rails_between_cities = rails_in_city if max_rails_between_cities > rails_in_city else max_rails_between_cities
# Evenly distribute cities
node_time_start = time.time()
node_positions: List[Any] = None
nb_nodes = num_cities
if grid_mode:
node_positions, city_cells = _generate_node_positions_grid_mode(nb_nodes, node_radius, height, width)
city_positions, city_cells = _generate_evenly_distr_city_positions(max_num_cities, city_radius, width, height)
else:
node_positions, city_cells = _generate_random_node_positions(nb_nodes, node_radius, height, width)
city_positions, city_cells = _generate_random_city_positions(max_num_cities, city_radius, width, height)
# reduce nb_nodes, _num_cities, _num_intersections if less were generated in not_grid_mode
nb_nodes = len(node_positions)
# reduce num_cities, _num_cities, _num_intersections if less were generated in not_grid_mode
num_cities = len(city_positions)
if DEBUG_PRINT_TIMING:
print("City position time", time.time() - node_time_start, "Seconds")
# Set up connection points for all cities
node_connection_time = time.time()
inner_connection_points, outer_connection_points, connection_info, city_orientations = _generate_node_connection_points(
node_positions, node_radius, max_inter_city_rails_allowed,
rail_in_city)
city_positions, city_radius, rails_between_cities,
rails_in_city)
if DEBUG_PRINT_TIMING:
print("Connection points", time.time() - node_connection_time)
# Connect the cities through the connection points
city_connection_time = time.time()
inter_city_lines = _connect_cities(node_positions, outer_connection_points, connection_info, city_cells,
inter_city_lines = _connect_cities(city_positions, outer_connection_points, connection_info, city_cells,
rail_trans, grid_map)
if DEBUG_PRINT_TIMING:
print("City connection time", time.time() - city_connection_time)
# Build inner cities
city_build_time = time.time()
through_tracks, free_tracks = _build_inner_cities(node_positions, inner_connection_points,
through_tracks, free_tracks = _build_inner_cities(city_positions, inner_connection_points,
outer_connection_points,
node_radius,
city_radius,
rail_trans,
grid_map)
if DEBUG_PRINT_TIMING:
print("City build time", time.time() - city_build_time)
# Populate cities
train_station_time = time.time()
train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, node_radius, free_tracks,
train_stations, built_num_trainstation = _set_trainstation_positions(city_positions, city_radius, free_tracks,
grid_map)
if DEBUG_PRINT_TIMING:
print("Trainstation placing time", time.time() - train_station_time)
# Fix all transition elements
grid_fix_time = time.time()
_fix_transitions(city_cells, inter_city_lines, grid_map)
......@@ -624,7 +609,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
# Generate start target pairs
schedule_time = time.time()
agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, nb_nodes, train_stations,
agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, num_cities, train_stations,
city_orientations)
if DEBUG_PRINT_TIMING:
print("Schedule time", time.time() - schedule_time)
......@@ -636,52 +621,50 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
'city_orientations': city_orientations
}}
def _generate_random_node_positions(nb_nodes, node_radius, height, width):
def _generate_random_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]):
node_positions = []
city_cells = []
for node_idx in range(nb_nodes):
for node_idx in range(num_cities):
to_close = True
tries = 0
while to_close:
x_tmp = node_radius + 1 + np.random.randint(height - 2 * (node_radius + 1))
y_tmp = node_radius + 1 + np.random.randint(width - 2 * (node_radius + 1))
x_tmp = city_radius + 1 + np.random.randint(height - 2 * (city_radius + 1))
y_tmp = city_radius + 1 + np.random.randint(width - 2 * (city_radius + 1))
to_close = False
# Check distance to nodes
for node_pos in node_positions:
if _city_overlap((x_tmp, y_tmp), node_pos, 2 * (node_radius + 1) + 1):
if _city_overlap((x_tmp, y_tmp), node_pos, 2 * (city_radius + 1) + 1):
to_close = True
if not to_close:
node_positions.append((x_tmp, y_tmp))
city_cells.extend(_city_cells(node_positions[-1], node_radius))
city_cells.extend(_city_cells(node_positions[-1], city_radius))
tries += 1
if tries > 200:
warnings.warn(
"Could not only set {} nodes after {} tries, although {} of nodes required to be generated!".format(
len(node_positions),
tries, nb_nodes))
tries, num_cities))
break
G.add_node(node_idx)
return node_positions, city_cells
def _generate_node_positions_grid_mode(nb_nodes, node_radius, height, width):
def _generate_evenly_distr_city_positions(num_cities: int, city_radius: int, width: int, height: int) -> (List[Tuple[int, int]], List[Tuple[int, int]]):
nodes_ratio = height / width
nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row))
x_positions = np.linspace(node_radius + 1, height - node_radius - 2, nodes_per_row, dtype=int)
y_positions = np.linspace(node_radius + 1, width - node_radius - 2, nodes_per_col, dtype=int)
nodes_per_row = int(np.ceil(np.sqrt(num_cities * nodes_ratio)))
nodes_per_col = int(np.ceil(num_cities / nodes_per_row))
x_positions = np.linspace(city_radius + 1, height - city_radius - 2, nodes_per_row, dtype=int)
y_positions = np.linspace(city_radius + 1, width - city_radius - 2, nodes_per_col, dtype=int)
node_positions = []
city_cells = []
for node_idx in range(nb_nodes):
for node_idx in range(num_cities):
x_tmp = x_positions[node_idx % nodes_per_row]
y_tmp = y_positions[node_idx // nodes_per_row]
node_positions.append((x_tmp, y_tmp))
city_cells.extend(_city_cells(node_positions[-1], node_radius))
G.add_node(node_idx)
city_cells.extend(_city_cells(node_positions[-1], city_radius))
return node_positions, city_cells
def _generate_node_connection_points(node_positions, node_size, max_inter_city_rails_allowed, tracks_in_city=2):
......@@ -745,7 +728,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
rail_trans, grid_map):
"""
Function to connect the different cities through their connection points
:param node_positions: Positions of city centers
:param city_positions: Positions of city centers
:param connection_points: Boarder connection points of cities
:param connection_info: Number of connection points per direction NESW
:param rail_trans: Transitions
......@@ -778,9 +761,6 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
new_line = connect_cities(rail_trans, grid_map, tmp_out_connection_point,
neighb_connection_point,
city_cells)
G.add_edge(current_node, neighb_idx, direction=out_direction, length=len(new_line))
G.add_edge(neighb_idx, current_node, direction=neighbour_direction, length=len(new_line))
all_paths.extend(new_line)
return all_paths
......@@ -789,7 +769,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
grid_map):
"""
Builds inner city tracks. This current version connects all incoming connections to all outgoing connections
:param node_positions: Positions of the cities
:param city_positions: Positions of the cities
:param inner_connection_points: Points on city boarder that are used to generate inner city track
:param outer_connection_points: Points where the city is connected to neighboring cities
:param rail_trans:
......@@ -833,7 +813,7 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
def _set_trainstation_positions(node_positions, node_radius, free_tracks, grid_map):
"""
:param node_positions:
:param city_positions:
:param num_trainstations:
:return:
"""
......@@ -906,8 +886,8 @@ def sparse_rail_generator(num_cities=5, grid_mode=False, max_inter_city_rails=4,
def _closest_neigh_in_direction(current_node, node_positions):
"""
Returns indices of closest neighbours in every direction NESW
:param current_node: Index of node in node_positions list
:param node_positions: list of all points being considered
:param current_node: Index of node in city_positions list
:param city_positions: list of all points being considered
:return: list of index of closest neighbours in all directions
"""
node_dist = []
......
......@@ -13,7 +13,7 @@ from flatland.utils.rendertools import RenderTool
def test_sparse_rail_generator():
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map
rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map
num_intersections=10, # Number of interesections in map
num_trainstations=50, # Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
......@@ -733,7 +733,7 @@ def test_sparse_rail_generator_deterministic():
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(num_cities=5,
rail_generator=sparse_rail_generator(max_num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
......@@ -1509,7 +1509,7 @@ def test_rail_env_action_required_info():
1. / 4.: 0.25} # Slow freight train
env_always_action = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map
rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map
num_intersections=10,
# Number of interesections in map
num_trainstations=50,
......@@ -1528,7 +1528,7 @@ def test_rail_env_action_required_info():
np.random.seed(0)
env_only_if_action_required = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map
rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map
num_intersections=10,
# Number of interesections in map
num_trainstations=50,
......@@ -1592,7 +1592,7 @@ def test_rail_env_malfunction_speed_info():
}
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map
rail_generator=sparse_rail_generator(max_num_cities=10, # Number of cities in map
num_intersections=10,
# Number of interesections in map
num_trainstations=50,
......@@ -1640,7 +1640,7 @@ def test_sparse_generator_with_too_man_cities_does_not_break_down():
RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(
num_cities=100, # Number of cities in map
max_num_cities=100, # Number of cities in map
num_intersections=10, # Number of interesections in map
num_trainstations=50, # Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
......
......@@ -165,7 +165,7 @@ def test_initial_malfunction():
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(num_cities=5,
rail_generator=sparse_rail_generator(max_num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
......@@ -247,7 +247,7 @@ def test_initial_malfunction_stop_moving():
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(num_cities=5,
rail_generator=sparse_rail_generator(max_num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
......@@ -339,7 +339,7 @@ def test_initial_malfunction_do_nothing():
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(num_cities=5,
rail_generator=sparse_rail_generator(max_num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
......@@ -430,7 +430,7 @@ def test_initial_nextmalfunction_not_below_zero():
env = RailEnv(width=25,
height=30,
rail_generator=sparse_rail_generator(num_cities=5,
rail_generator=sparse_rail_generator(max_num_cities=5,
# Number of cities in map (where train stations are)
num_intersections=4,
# Number of intersections (no start / target)
......
......@@ -23,7 +23,7 @@ def test_get_global_observation():
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=25,
rail_generator=sparse_rail_generator(max_num_cities=25,
# Number of cities in map (where train stations are)
num_intersections=10,
# Number of intersections (no start / target)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment