Skip to content
Snippets Groups Projects
Commit d3657c20 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

priority to connect to intersection instead of other city.

parent 979b637c
No related branches found
No related tags found
No related merge requests found
...@@ -838,7 +838,10 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation ...@@ -838,7 +838,10 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
np.random.seed(seed + num_resets) np.random.seed(seed + num_resets)
# Generate a set of nodes for the sparse network # Generate a set of nodes for the sparse network
# Try to connect cities to nodes first
node_positions = [] node_positions = []
city_positions = []
intersection_positions = []
for node_idx in range(num_cities + num_intersections): for node_idx in range(num_cities + num_intersections):
to_close = True to_close = True
tries = 0 tries = 0
...@@ -851,21 +854,34 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation ...@@ -851,21 +854,34 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
to_close = True to_close = True
if not to_close: if not to_close:
node_positions.append((x_tmp, y_tmp)) node_positions.append((x_tmp, y_tmp))
if node_idx < num_cities:
city_positions.append((x_tmp, y_tmp))
else:
intersection_positions.append((x_tmp, y_tmp))
tries += 1 tries += 1
if tries > 100: if tries > 100:
warnings.warn("Could not set nodes, please change initial parameters!!!!") warnings.warn("Could not set nodes, please change initial parameters!!!!")
break break
# Chose node connection # Chose node connection
available_nodes = np.arange(num_cities + num_intersections) available_nodes_full = np.arange(num_cities + num_intersections)
available_cities = np.arange(num_cities)
available_intersections = np.arange(num_cities, num_cities + num_intersections)
current_node = 0 current_node = 0
node_stack = [current_node] node_stack = [current_node]
while len(node_stack) > 0: while len(node_stack) > 0:
current_node = node_stack[0] current_node = node_stack[0]
delete_idx = np.where(available_nodes == current_node) delete_idx = np.where(available_nodes_full == current_node)
available_nodes = np.delete(available_nodes, delete_idx, 0) available_nodes_full = np.delete(available_nodes_full, delete_idx, 0)
if current_node < num_cities and len(available_intersections) > 0:
available_nodes = available_intersections
available_cities = np.delete(available_cities, delete_idx, 0)
elif len(available_intersections) > 0:
available_nodes = available_cities
available_intersections = np.delete(available_intersections, delete_idx, 0)
else:
available_nodes = available_nodes_full
# Sort available neighbors according to their distance. # Sort available neighbors according to their distance.
node_dist = [] node_dist = []
for av_node in available_nodes: for av_node in available_nodes:
...@@ -885,30 +901,36 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation ...@@ -885,30 +901,36 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
node_stack.append(neighb) node_stack.append(neighb)
connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb]) connect_nodes(rail_trans, rail_array, node_positions[current_node], node_positions[neighb])
node_stack.pop(0) node_stack.pop(0)
# Place train stations close to the node # Place train stations close to the node
# We currently place them uniformly distirbuted among all cities # We currently place them uniformly distirbuted among all cities
train_stations = [[] for i in range(num_cities)] if num_cities > 1:
train_stations = [[] for i in range(num_cities)]
for station in range(num_trainstations):
trainstation_node = int(station / num_trainstations * num_cities) for station in range(num_trainstations):
trainstation_node = int(station / num_trainstations * num_cities)
station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), 0,
height - 1)
station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), 0,
width - 1)
while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[
trainstation_node] or \
rail_array[(station_x, station_y)] != 0:
station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
0, 0,
height - 1) height - 1)
station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius),
0, 0,
width - 1) width - 1)
train_stations[trainstation_node].append((station_x, station_y)) while (station_x, station_y) in train_stations or (station_x, station_y) == node_positions[
trainstation_node] or \
# Connect train station to the correct node rail_array[(station_x, station_y)] != 0:
connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], (station_x, station_y)) station_x = np.clip(
node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius),
0,
height - 1)
station_y = np.clip(
node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius),
0,
width - 1)
train_stations[trainstation_node].append((station_x, station_y))
# Connect train station to the correct node
connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], (station_x, station_y))
# Fix all nodes with illegal transition maps # Fix all nodes with illegal transition maps
for current_node in node_positions: for current_node in node_positions:
......
import time
import numpy as np import numpy as np
from flatland.envs.generators import sparse_rail_generator, realistic_rail_generator from flatland.envs.generators import sparse_rail_generator, realistic_rail_generator
...@@ -23,17 +25,17 @@ def test_realistic_rail_generator(): ...@@ -23,17 +25,17 @@ def test_realistic_rail_generator():
def test_sparse_rail_generator(): def test_sparse_rail_generator():
env = RailEnv(width=50, env = RailEnv(width=50,
height=50, height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map rail_generator=sparse_rail_generator(num_cities=5, # Number of cities in map
num_intersections=3, # Number of interesections in map num_intersections=2, # Number of interesections in map
num_trainstations=10, # Number of possible start/targets on map num_trainstations=10, # Number of possible start/targets on map
min_node_dist=10, # Minimal distance of nodes min_node_dist=10, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center node_radius=2, # Proximity of stations to city center
num_neighb=4, # Number of connections to other cities num_neighb=2, # Number of connections to other cities
seed=15, # Random seed seed=15, # Random seed
), ),
number_of_agents=1, number_of_agents=0,
obs_builder_object=GlobalObsForRailEnv()) obs_builder_object=GlobalObsForRailEnv())
# reset to initialize agents_static # reset to initialize agents_static
env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer = RenderTool(env, gl="PILSVG", )
env_renderer.render_env(show=True, show_observations=True, show_predictions=False) env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
time.sleep(5)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment