Commit 40d892dd authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

refactoring of new generator code.

Much more stable level generation.
Checking for invalid transitions objects and fixing them.
parent 43d9ef66
Pipeline #1759 failed with stages
in 6 minutes and 6 seconds
......@@ -103,4 +103,5 @@ def a_star(rail_trans, rail_array, start, end):
# no full path found
if len(open_nodes) == 0:
print("could not make path")
return []
......@@ -403,6 +403,62 @@ class GridTransitionMap(TransitionMap):
return True
def fix_transitions(self, rcPos):
"""
Fixes broken transitions
"""
gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc]
grcPos = array(rcPos)
grcMax = self.grid.shape
# loop over available outbound directions (indices) for rcPos
incomping_connections = np.zeros(4)
for iDirOut in np.arange(4):
gdRC = gDir2dRC[iDirOut] # row,col increment
gPos2 = grcPos + gdRC # next cell in that direction
# Check the adjacent cell is within bounds
# if not, then this transition is invalid!
if np.any(gPos2 < 0):
return False
if np.any(gPos2 >= grcMax):
return False
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
connected = 0
for orientation in range(4):
connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut))
if connected > 0:
incomping_connections[iDirOut] = 1
number_of_incoming = np.sum(incomping_connections)
# Only one incoming direction --> Straight line
if number_of_incoming == 1:
for direction in range(4):
if incomping_connections[direction] > 0:
self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1)
# Connect all incoming connections
if number_of_incoming == 2:
connect_directions = np.argwhere(incomping_connections > 0)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
# Find feasible connection fro three entries
if number_of_incoming == 3:
hole = np.argwhere(incomping_connections < 1)[0][0]
connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4]
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[2])), connect_directions[0], 1)
# Make a cross
if number_of_incoming == 4:
for direction in range(4):
self.set_transition((grcPos[0], grcPos[1], direction), direction, 1)
return True
def mirror(dir):
return (dir + 2) % 4
......
......@@ -9,7 +9,7 @@ from flatland.core.grid.grid_utils import distance_on_rail
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.grid4_generators_utils import connect_rail, connect_from_nodes, connect_nodes
from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes
from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail
......@@ -863,7 +863,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
if tries > 100:
warnings.warn("Could not set nodes, please change initial parameters!!!!")
break
print(node_positions)
# Chose node connection
available_nodes_full = np.arange(num_cities + num_intersections)
available_cities = np.arange(num_cities)
......@@ -881,7 +881,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
delete_idx = np.where(available_cities == current_node)
available_cities = np.delete(available_cities, delete_idx, 0)
elif len(available_intersections) > 0:
elif len(available_intersections) > 0 and len(available_cities) > 0:
available_nodes = available_cities
delete_idx = np.where(available_intersections == current_node)
available_intersections = np.delete(available_intersections, delete_idx, 0)
......@@ -894,8 +894,6 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
available_nodes = available_nodes[np.argsort(node_dist)]
# Set number of neighboring nodes
print(current_node, allowed_connections)
if len(available_nodes) >= allowed_connections:
connected_neighb_idx = available_nodes[:allowed_connections]
else:
......@@ -903,6 +901,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
if current_node == 0:
allowed_connections -= 1
# Connect to the neighboring nodes
for neighb in connected_neighb_idx:
if neighb not in node_stack:
......@@ -938,12 +937,14 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
train_stations[trainstation_node].append((station_x, station_y))
# Connect train station to the correct node
connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], (station_x, station_y))
connection = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node],
(station_x, station_y))
# Check if connection was made
if len(connection) == 0:
train_stations[trainstation_node].pop(-1)
# Fix all nodes with illegal transition maps
for current_node in node_positions:
if not grid_map.cell_neighbours_valid(current_node):
grid_map.fix_neighbours(current_node)
grid_map.fix_transitions(current_node)
# Generate start and target node directory for all agents.
# Assure that start and target are not in the same node
......
......@@ -76,8 +76,8 @@ def connect_nodes(rail_trans, rail_array, start, end):
if index == 0:
if new_trans == 0:
# end-point
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# don't set any transition at node yet
new_trans = 0
else:
# into existing rail
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
......@@ -93,7 +93,9 @@ def connect_nodes(rail_trans, rail_array, start, end):
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# end-point
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
# don't set any transition at node yet
new_trans_e = 0
else:
# into existing rail
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
......
import time
import numpy as np
from flatland.envs.generators import sparse_rail_generator, realistic_rail_generator
......@@ -23,11 +21,11 @@ def test_realistic_rail_generator():
def test_sparse_rail_generator():
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=3, # Number of cities in map
env = RailEnv(width=20,
height=20,
rail_generator=sparse_rail_generator(num_cities=5, # Number of cities in map
num_intersections=2, # Number of interesections in map
num_trainstations=15, # Number of possible start/targets on map
num_trainstations=20, # Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
num_neighb=2, # Number of connections to other cities
......@@ -38,4 +36,3 @@ def test_sparse_rail_generator():
# reset to initialize agents_static
env_renderer = RenderTool(env, gl="PILSVG", )
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
time.sleep(10)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment