Commit fe85e973 authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

added city boarder to as forbidden zone for inner city connection. This...

added city boarder to as forbidden zone for inner city connection. This strictly seperates inner from outer connections. should help us to get nicer infrastructures.
parent 1d8a6479
Pipeline #2209 failed with stages
in 60 minutes
......@@ -33,16 +33,16 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are)
num_trainstations=50, # Number of possible start/targets on map
min_node_dist=5, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
num_trainstations=45, # Number of possible start/targets on map
min_node_dist=10, # Minimal distance of nodes
node_radius=4, # Proximity of stations to city center
seed=15, # Random seed
grid_mode=True,
grid_mode=False,
max_connection_points_per_side=2,
max_nr_connection_directions=4
max_nr_connection_directions=2
),
schedule_generator=sparse_schedule_generator(),
number_of_agents=50,
number_of_agents=15,
stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=GlobalObsForRailEnv())
......
......@@ -9,7 +9,7 @@ from flatland.core.grid.grid4_utils import get_direction, mirror
from flatland.core.grid.grid_utils import distance_on_rail, direction_to_point
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_cities
from flatland.envs.grid4_generators_utils import connect_rail, connect_cities
RailGeneratorProduct = Tuple[GridTransitionMap, Optional[Dict]]
RailGenerator = Callable[[int, int, int, int], RailGeneratorProduct]
......@@ -589,7 +589,11 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
_connect_cities(node_positions, connection_points, connection_info, city_cells, rail_trans, grid_map)
# Build inner cities
train_stations, built_num_trainstation = _build_cities(node_positions, connection_points, rail_trans, grid_map)
_build_inner_cities(node_positions, connection_points, rail_trans, grid_map)
# Populate cities
train_stations, built_num_trainstation = _set_trainstation_positions(node_positions, city_cells,
num_trainstations, grid_map)
# Adjust the number of agents if you could not build enough trainstations
if num_agents > built_num_trainstation:
......@@ -600,6 +604,7 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
_fix_transitions(grid_map)
# Generate start target paris
print(train_stations)
agent_start_targets_nodes, num_agents = _generate_start_target_pairs(num_agents, nb_nodes, train_stations)
return grid_map, {'agents_hints': {
......@@ -688,25 +693,25 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
nr_of_connection_points = np.random.randint(1, max_nr_connection_points + 1)
connections_per_direction[idx] = nr_of_connection_points
connection_points_coordinates = []
connection_points_coordinates = [[] for i in range(4)]
for direction in range(4):
connection_slots = np.arange(connections_per_direction[direction]) - int(
connections_per_direction[direction] / 2)
for connection_idx in range(connections_per_direction[direction]):
if direction == 0:
connection_points_coordinates.append(
(node_position[0] - node_size, node_position[1] + connection_slots[connection_idx]))
tmp_coordinates = (
node_position[0] - node_size, node_position[1] + connection_slots[connection_idx])
if direction == 1:
connection_points_coordinates.append(
(node_position[0] + connection_slots[connection_idx], node_position[1] + node_size))
tmp_coordinates = (
node_position[0] + connection_slots[connection_idx], node_position[1] + node_size)
if direction == 2:
connection_points_coordinates.append(
(node_position[0] + node_size, node_position[1] + connection_slots[connection_idx]))
tmp_coordinates = (
node_position[0] + node_size, node_position[1] + connection_slots[connection_idx])
if direction == 3:
connection_points_coordinates.append(
(node_position[0] + connection_slots[connection_idx], node_position[1] - node_size))
tmp_coordinates = (
node_position[0] + connection_slots[connection_idx], node_position[1] - node_size)
connection_points_coordinates[direction].append(tmp_coordinates)
connection_points.append(connection_points_coordinates)
connection_info.append(connections_per_direction)
return connection_points, connection_info
......@@ -733,15 +738,13 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
if neighb_idx is not None:
connection_distances = []
for tmp_out_connection_point in connection_points[current_node]:
tmp_dist_to_node = distance_on_rail(tmp_out_connection_point, node_positions[neighb_idx])
connection_distances.append(tmp_dist_to_node)
possible_connection_points = argsort(connection_distances)
for sort_idx in possible_connection_points[:connection_info[current_node][direction]]:
for tmp_out_connection_point in connection_points[current_node][direction]:
# Find closest connection point
tmp_out_connection_point = connection_points[current_node][sort_idx]
min_connection_dist = np.inf
for tmp_in_connection_point in connection_points[neighb_idx]:
all_neighb_connection_points = [item for sublist in connection_points[neighb_idx] for item in
sublist]
for tmp_in_connection_point in all_neighb_connection_points:
tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point)
if tmp_dist < min_connection_dist:
min_connection_dist = tmp_dist
......@@ -753,71 +756,51 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
direction += 1
return boarder_connections
def _build_inner_cities(node_positions, connection_points, rail_trans, grid_map):
"""
Builds inner city tracks. This current version connects all incoming connections to all outgoing connections
:param node_positions:
:param connection_points:
:param rail_trans:
:param grid_map:
:return:
"""
for current_city in range(len(node_positions)):
for boarder in range(4):
for source in connection_points[current_city][boarder]:
for other_boarder in range(4):
if boarder != other_boarder and len(connection_points[current_city][other_boarder]) > 0:
for target in connection_points[current_city][other_boarder]:
city_boarder = _city_boarder(node_positions[current_city], node_radius)
connect_cities(rail_trans, grid_map, source, target, city_boarder)
else:
continue
def _build_cities(node_positions, connection_points, rail_trans, grid_map):
# Place train stations close to the node
# We currently place them uniformly distributed among all cities
built_num_trainstation = 0
nb_nodes = len(node_positions)
height, width = np.shape(grid_map.grid)
train_stations = [[] for i in range(nb_nodes)]
if nb_nodes > 1:
for station in range(num_trainstations):
spot_found = True
reduced_node_radius = node_radius - 1
trainstation_node = int(station / num_trainstations * nb_nodes)
station_x = np.clip(
node_positions[trainstation_node][0] + np.random.randint(-reduced_node_radius, reduced_node_radius),
0,
height - 1)
station_y = np.clip(
node_positions[trainstation_node][1] + np.random.randint(-reduced_node_radius, reduced_node_radius),
0,
width - 1)
tries = 0
while (station_x, station_y) in train_stations[trainstation_node]:
station_x = np.clip(
node_positions[trainstation_node][0] + np.random.randint(-reduced_node_radius,
reduced_node_radius),
0,
height - 1)
station_y = np.clip(
node_positions[trainstation_node][1] + np.random.randint(-reduced_node_radius,
reduced_node_radius),
0,
width - 1)
tries += 1
if tries > 100:
warnings.warn("Could not set trainstations, please change initial parameters!!!!")
spot_found = False
break
if spot_found:
train_stations[trainstation_node].append((station_x, station_y))
# Connect train station to random nodes
return
if len(connection_points[trainstation_node]) > 1:
rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2,
replace=False)
else:
rand_corner_nodes = [0]
for corner_node_idx in rand_corner_nodes:
connection = connect_nodes(rail_trans, grid_map,
connection_points[trainstation_node][corner_node_idx],
(station_x, station_y))
# Check if connection was made
if len(connection) == 0:
if len(train_stations[trainstation_node]) > 0:
train_stations[trainstation_node].pop(-1)
else:
def _set_trainstation_positions(node_positions, city_cells, num_trainstations, grid_map):
"""
built_num_trainstation += 1
return train_stations, built_num_trainstation
:param node_positions:
:param num_trainstations:
:return:
"""
nb_nodes = len(node_positions)
train_stations = [[] for i in range(nb_nodes)]
num_cities = len(node_positions)
built_num_trainstations = 0
stations_per_city = int(num_trainstations / num_cities)
for current_city in range(len(node_positions)):
for possible_location in _city_cells(node_positions[current_city], node_radius - 1):
cell_type = grid_map.get_full_transitions(*possible_location)
nbits = 0
while cell_type > 0:
nbits += (cell_type & 1)
cell_type = cell_type >> 1
if 1 <= nbits <= 2:
built_num_trainstations += 1
train_stations[current_city].append(possible_location)
return train_stations, built_num_trainstations
def _fix_transitions(grid_map):
"""
......@@ -924,10 +907,19 @@ def sparse_rail_generator(num_cities=5, num_trainstations=2, min_node_dist=20, n
:return: returns flat list of all cell coordinates in the city
"""
city_cells = []
for x in range(-radius, radius):
for y in range(-radius, radius):
for x in range(-radius, radius + 1):
for y in range(-radius, radius + 1):
city_cells.append((center[0] + x, center[1] + y))
return city_cells
def _city_boarder(center, radius):
city_boarder = []
for x in range(-radius, radius + 1):
for y in range(-radius, radius + 1):
print(x, y, radius)
if abs(x) == radius or abs(y) == radius:
city_boarder.append((center[0] + x, center[1] + y))
return city_boarder
return generator
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment