Skip to content
Snippets Groups Projects
Commit 95ba74c8 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

first commit of improved level generator. Now cities are not connected from...

first commit of improved level generator. Now cities are not connected from the center but rather from the boarders. NESW
parent 15db8d20
No related branches found
No related tags found
No related merge requests found
This diff is collapsed.
import time
import numpy as np import numpy as np
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator from flatland.envs.schedule_generators import random_schedule_generator
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
np.random.seed(1) np.random.seed(1)
...@@ -30,20 +32,20 @@ speed_ration_map = {1.: 0.25, # Fast passenger train ...@@ -30,20 +32,20 @@ speed_ration_map = {1.: 0.25, # Fast passenger train
env = RailEnv(width=50, env = RailEnv(width=50,
height=50, height=50,
rail_generator=sparse_rail_generator(num_cities=25, # Number of cities in map (where train stations are) rail_generator=sparse_rail_generator(num_cities=9, # Number of cities in map (where train stations are)
num_intersections=10, # Number of intersections (no start / target) num_intersections=0, # Number of intersections (no start / target)
num_trainstations=50, # Number of possible start/targets on map num_trainstations=50, # Number of possible start/targets on map
min_node_dist=3, # Minimal distance of nodes min_node_dist=3, # Minimal distance of nodes
node_radius=4, # Proximity of stations to city center node_radius=5, # Proximity of stations to city center
num_neighb=4, # Number of connections to other cities/intersections num_neighb=3, # Number of connections to other cities/intersections
seed=15, # Random seed seed=15, # Random seed
grid_mode=True, grid_mode=True,
enhance_intersection=False enhance_intersection=False
), ),
schedule_generator=sparse_schedule_generator(speed_ration_map), schedule_generator=random_schedule_generator(),
number_of_agents=20, number_of_agents=0,
stochastic_data=stochastic_data, # Malfunction data generator stochastic_data=stochastic_data, # Malfunction data generator
obs_builder_object=TreeObservation) obs_builder_object=GlobalObsForRailEnv())
env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer = RenderTool(env, gl="PILSVG", )
...@@ -111,6 +113,7 @@ for step in range(500): ...@@ -111,6 +113,7 @@ for step in range(500):
# reward and whether their are done # reward and whether their are done
next_obs, all_rewards, done, _ = env.step(action_dict) next_obs, all_rewards, done, _ = env.step(action_dict)
env_renderer.render_env(show=True, show_observations=False, show_predictions=False) env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
time.sleep(50)
frame_step += 1 frame_step += 1
# Update replay buffer and train agent # Update replay buffer and train agent
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
......
...@@ -9,7 +9,7 @@ from flatland.envs.rail_generators_city_generator import city_generator ...@@ -9,7 +9,7 @@ from flatland.envs.rail_generators_city_generator import city_generator
from flatland.envs.schedule_generators import city_schedule_generator from flatland.envs.schedule_generators import city_schedule_generator
from flatland.utils.rendertools import RenderTool, AgentRenderVariant from flatland.utils.rendertools import RenderTool, AgentRenderVariant
os.mkdir("./../render_output/") # os.mkdir("./../render_output/")
for itrials in np.arange(1, 15, 1): for itrials in np.arange(1, 15, 1):
print(itrials, "generate new city") print(itrials, "generate new city")
......
...@@ -295,4 +295,4 @@ def coordinate_to_position(depth, coords): ...@@ -295,4 +295,4 @@ def coordinate_to_position(depth, coords):
def distance_on_rail(pos1, pos2): def distance_on_rail(pos1, pos2):
return abs(pos1[0] - pos2[0]) + abs(pos1[1] - pos2[1]) return np.sqrt(np.power(pos1[0] - pos2[0], 2) + np.power(pos1[1] - pos2[1], 2))
...@@ -570,8 +570,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -570,8 +570,8 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
nodes_ratio = height / width nodes_ratio = height / width
nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio))) nodes_per_row = int(np.ceil(np.sqrt(nb_nodes * nodes_ratio)))
nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row)) nodes_per_col = int(np.ceil(nb_nodes / nodes_per_row))
x_positions = np.linspace(node_radius, height - node_radius, nodes_per_row, dtype=int) x_positions = np.linspace(node_radius, height - node_radius - 1, nodes_per_row, dtype=int)
y_positions = np.linspace(node_radius, width - node_radius, nodes_per_col, dtype=int) y_positions = np.linspace(node_radius, width - node_radius - 1, nodes_per_col, dtype=int)
city_idx = np.random.choice(np.arange(nb_nodes), num_cities, False) city_idx = np.random.choice(np.arange(nb_nodes), num_cities, False)
node_positions = _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, node_positions = _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions,
...@@ -598,11 +598,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -598,11 +598,15 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
available_cities = np.arange(_num_cities) available_cities = np.arange(_num_cities)
available_intersections = np.arange(_num_cities, nb_nodes) available_intersections = np.arange(_num_cities, nb_nodes)
# Set up connection points
connection_points = _generate_node_connection_points(node_positions, node_radius, max_nr_connection_points=4)
print(connection_points)
# Start at some node # Start at some node
current_node = np.random.randint(len(available_nodes_full)) current_node = np.random.randint(len(available_nodes_full))
node_stack = [current_node] node_stack = [current_node]
allowed_connections = num_neighb allowed_connections = len(connection_points[current_node])
first_node = True first_node = True
i = 0
while len(node_stack) > 0: while len(node_stack) > 0:
current_node = node_stack[0] current_node = node_stack[0]
delete_idx = np.where(available_nodes_full == current_node) delete_idx = np.where(available_nodes_full == current_node)
...@@ -645,37 +649,60 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -645,37 +649,60 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
for neighb in connected_neighb_idx: for neighb in connected_neighb_idx:
if neighb not in node_stack: if neighb not in node_stack:
node_stack.append(neighb) node_stack.append(neighb)
connect_nodes(rail_trans, grid_map, node_positions[current_node], node_positions[neighb])
dist_from_center = distance_on_rail(node_positions[current_node], node_positions[neighb])
for tmp_out_connection_point in connection_points[current_node]:
tmp_dist_to_node = distance_on_rail(tmp_out_connection_point, node_positions[neighb])
# Check if this connection node is on the city side facing the neighbour
print("Current node", current_node, "Neigh", neighb, "Distance", tmp_dist_to_node, dist_from_center)
if tmp_dist_to_node < dist_from_center - 1:
min_connection_dist = np.inf
# Find closes connection point
for tmp_in_connection_point in connection_points[neighb]:
tmp_dist = distance_on_rail(tmp_out_connection_point, tmp_in_connection_point)
if tmp_dist < min_connection_dist:
min_connection_dist = tmp_dist
neighb_connection_point = tmp_in_connection_point
center_distance = distance_on_rail(node_positions[current_node], tmp_in_connection_point)
if distance_on_rail(tmp_out_connection_point, neighb_connection_point) < center_distance:
i += 1
connect_nodes(rail_trans, grid_map, tmp_out_connection_point, neighb_connection_point)
node_stack.pop(0) node_stack.pop(0)
# Place train stations close to the node # Place train stations close to the node
# We currently place them uniformly distributed among all cities # We currently place them uniformly distributed among all cities
built_num_trainstation = 0 built_num_trainstation = 0
train_stations = [[] for i in range(_num_cities)] train_stations = [[] for i in range(_num_cities)]
if _num_cities > 1: if _num_cities > 1:
for station in range(num_trainstations): for station in range(num_trainstations):
spot_found = True spot_found = True
reduced_node_radius = node_radius - 1
trainstation_node = int(station / num_trainstations * _num_cities) trainstation_node = int(station / num_trainstations * _num_cities)
station_x = np.clip(node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), station_x = np.clip(
0, node_positions[trainstation_node][0] + np.random.randint(-reduced_node_radius, reduced_node_radius),
height - 1) 0,
station_y = np.clip(node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), height - 1)
0, station_y = np.clip(
width - 1) node_positions[trainstation_node][1] + np.random.randint(-reduced_node_radius, reduced_node_radius),
0,
width - 1)
tries = 0 tries = 0
while (station_x, station_y) in train_stations[trainstation_node] \ while (station_x, station_y) in train_stations[trainstation_node] \
or (station_x, station_y) == node_positions[trainstation_node] \ or (station_x, station_y) == node_positions[trainstation_node] \
or rail_array[(station_x, station_y)] != 0: # noqa: E125 or rail_array[(station_x, station_y)] != 0: # noqa: E125
station_x = np.clip( station_x = np.clip(
node_positions[trainstation_node][0] + np.random.randint(-node_radius, node_radius), node_positions[trainstation_node][0] + np.random.randint(-reduced_node_radius,
reduced_node_radius),
0, 0,
height - 1) height - 1)
station_y = np.clip( station_y = np.clip(
node_positions[trainstation_node][1] + np.random.randint(-node_radius, node_radius), node_positions[trainstation_node][1] + np.random.randint(-reduced_node_radius,
reduced_node_radius),
0, 0,
width - 1) width - 1)
tries += 1 tries += 1
...@@ -687,11 +714,17 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -687,11 +714,17 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
if spot_found: if spot_found:
train_stations[trainstation_node].append((station_x, station_y)) train_stations[trainstation_node].append((station_x, station_y))
# Connect train station to the correct node # Connect train station to random nodes
connection = connect_from_nodes(rail_trans, grid_map, node_positions[trainstation_node],
rand_corner_nodes = np.random.choice(range(len(connection_points[trainstation_node])), 2, replace=False)
connection_1 = connect_from_nodes(rail_trans, grid_map,
connection_points[trainstation_node][rand_corner_nodes[0]],
(station_x, station_y)) (station_x, station_y))
connection_2 = connect_from_nodes(rail_trans, grid_map,
connection_points[trainstation_node][rand_corner_nodes[1]],
(station_x, station_y))
# Check if connection was made # Check if connection was made
if len(connection) == 0: if len(connection_1) == 0 and len(connection_2) == 0:
if len(train_stations[trainstation_node]) > 0: if len(train_stations[trainstation_node]) > 0:
train_stations[trainstation_node].pop(-1) train_stations[trainstation_node].pop(-1)
else: else:
...@@ -702,39 +735,10 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -702,39 +735,10 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
num_agents = built_num_trainstation num_agents = built_num_trainstation
warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents") warnings.warn("sparse_rail_generator: num_agents > nr_start_goal, changing num_agents")
# Place passing lanes at intersections
# We currently place them uniformly distirbuted among all cities
if enhance_intersection:
for intersection in range(_num_intersections):
intersect_x_1 = np.clip(intersection_positions[intersection][0] + np.random.randint(1, 3),
1,
height - 2)
intersect_y_1 = np.clip(intersection_positions[intersection][1] + np.random.randint(-3, 3),
2,
width - 2)
intersect_x_2 = np.clip(
intersection_positions[intersection][0] + np.random.randint(-3, -1),
1,
height - 2)
intersect_y_2 = np.clip(
intersection_positions[intersection][1] + np.random.randint(-3, 3),
1,
width - 2)
# Connect train station to the correct node
connect_nodes(rail_trans, grid_map, (intersect_x_1, intersect_y_1),
(intersect_x_2, intersect_y_2))
connect_nodes(rail_trans, grid_map, intersection_positions[intersection],
(intersect_x_1, intersect_y_1))
connect_nodes(rail_trans, grid_map, intersection_positions[intersection],
(intersect_x_2, intersect_y_2))
grid_map.fix_transitions((intersect_x_1, intersect_y_1))
grid_map.fix_transitions((intersect_x_2, intersect_y_2))
# Fix all nodes with illegal transition maps # Fix all nodes with illegal transition maps
for current_node in node_positions: flat_list = [item for sublist in connection_points for item in sublist]
grid_map.fix_transitions(current_node) for cell_to_fix in flat_list:
grid_map.fix_transitions(cell_to_fix)
# Generate start and target node directory for all agents. # Generate start and target node directory for all agents.
# Assure that start and target are not in the same node # Assure that start and target are not in the same node
...@@ -773,9 +777,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -773,9 +777,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
num_agents -= 1 num_agents -= 1
return grid_map, {'agents_hints': { return grid_map, {'agents_hints': {
'num_agents': num_agents, 'num_agents': num_agents
'agent_start_targets_nodes': agent_start_targets_nodes,
'train_stations': train_stations
}} }}
def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes, def _generate_node_positions_not_grid_mode(city_positions, height, intersection_positions, nb_nodes,
...@@ -820,15 +822,46 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2 ...@@ -820,15 +822,46 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes, def _generate_node_positions_grid_mode(city_idx, city_positions, intersection_positions, nb_nodes,
nodes_per_row, x_positions, y_positions): nodes_per_row, x_positions, y_positions):
for node_idx in range(nb_nodes): for node_idx in range(nb_nodes):
x_tmp = x_positions[node_idx % nodes_per_row] x_tmp = x_positions[node_idx % nodes_per_row]
y_tmp = y_positions[node_idx // nodes_per_row] y_tmp = y_positions[node_idx // nodes_per_row]
if node_idx in city_idx: if node_idx in city_idx:
city_positions.append((x_tmp, y_tmp)) city_positions.append((x_tmp, y_tmp))
else: else:
intersection_positions.append((x_tmp, y_tmp)) intersection_positions.append((x_tmp, y_tmp))
node_positions = city_positions + intersection_positions node_positions = city_positions + intersection_positions
return node_positions return node_positions
def _generate_node_connection_points(node_positions, node_size, max_nr_connection_points=2):
connection_points = []
for node_position in node_positions:
n_connection_points = max_nr_connection_points # np.random.randint(1, max_nr_connection_points)
connection_per_direction = n_connection_points // 4
connection_point_vector = [connection_per_direction, connection_per_direction, connection_per_direction,
n_connection_points - 3 * connection_per_direction]
print(connection_point_vector)
connection_points_coordinates = []
for direction in range(4):
rnd_points = np.random.choice(np.arange(-node_size, node_size), size=connection_point_vector[direction],
replace=False)
for connection_idx in range(connection_point_vector[direction]):
if direction == 0:
connection_points_coordinates.append(
(node_position[0] - node_size, node_position[1] + rnd_points[connection_idx]))
if direction == 1:
connection_points_coordinates.append(
(node_position[0] + rnd_points[connection_idx], node_position[1] + node_size))
if direction == 2:
connection_points_coordinates.append(
(node_position[0] + node_size, node_position[1] + rnd_points[connection_idx]))
if direction == 3:
connection_points_coordinates.append(
(node_position[0] + rnd_points[connection_idx], node_position[1] - node_size))
connection_points.append(connection_points_coordinates)
return connection_points
return generator return generator
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment