Skip to content
Snippets Groups Projects
Commit 40d892dd authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

refactoring of new generator code.

Much more stable level generation.
Checking for invalid transitions objects and fixing them.
parent 43d9ef66
No related branches found
No related tags found
No related merge requests found
......@@ -103,4 +103,5 @@ def a_star(rail_trans, rail_array, start, end):
# no full path found
if len(open_nodes) == 0:
print("could not make path")
return []
......@@ -403,6 +403,62 @@ class GridTransitionMap(TransitionMap):
return True
def fix_transitions(self, rcPos):
"""
Fixes broken transitions
"""
gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc]
grcPos = array(rcPos)
grcMax = self.grid.shape
# loop over available outbound directions (indices) for rcPos
incomping_connections = np.zeros(4)
for iDirOut in np.arange(4):
gdRC = gDir2dRC[iDirOut] # row,col increment
gPos2 = grcPos + gdRC # next cell in that direction
# Check the adjacent cell is within bounds
# if not, then this transition is invalid!
if np.any(gPos2 < 0):
return False
if np.any(gPos2 >= grcMax):
return False
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
connected = 0
for orientation in range(4):
connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut))
if connected > 0:
incomping_connections[iDirOut] = 1
number_of_incoming = np.sum(incomping_connections)
# Only one incoming direction --> Straight line
if number_of_incoming == 1:
for direction in range(4):
if incomping_connections[direction] > 0:
self.set_transition((rcPos[0], rcPos[1], mirror(direction)), direction, 1)
# Connect all incoming connections
if number_of_incoming == 2:
connect_directions = np.argwhere(incomping_connections > 0)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
# Find feasible connection fro three entries
if number_of_incoming == 3:
hole = np.argwhere(incomping_connections < 1)[0][0]
connect_directions = [(hole + 1) % 4, (hole + 2) % 4, (hole + 3) % 4]
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[1], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[0])), connect_directions[2], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[1])), connect_directions[0], 1)
self.set_transition((rcPos[0], rcPos[1], mirror(connect_directions[2])), connect_directions[0], 1)
# Make a cross
if number_of_incoming == 4:
for direction in range(4):
self.set_transition((grcPos[0], grcPos[1], direction), direction, 1)
return True
def mirror(dir):
return (dir + 2) % 4
......
......@@ -9,7 +9,7 @@ from flatland.core.grid.grid_utils import distance_on_rail
from flatland.core.grid.rail_env_grid import RailEnvTransitions
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.grid4_generators_utils import connect_rail, connect_from_nodes, connect_nodes
from flatland.envs.grid4_generators_utils import connect_rail, connect_nodes, connect_from_nodes
from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail
......@@ -863,7 +863,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
if tries > 100:
warnings.warn("Could not set nodes, please change initial parameters!!!!")
break
print(node_positions)
# Chose node connection
available_nodes_full = np.arange(num_cities + num_intersections)
available_cities = np.arange(num_cities)
......@@ -881,7 +881,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
delete_idx = np.where(available_cities == current_node)
available_cities = np.delete(available_cities, delete_idx, 0)
elif len(available_intersections) > 0:
elif len(available_intersections) > 0 and len(available_cities) > 0:
available_nodes = available_cities
delete_idx = np.where(available_intersections == current_node)
available_intersections = np.delete(available_intersections, delete_idx, 0)
......@@ -894,8 +894,6 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
available_nodes = available_nodes[np.argsort(node_dist)]
# Set number of neighboring nodes
print(current_node, allowed_connections)
if len(available_nodes) >= allowed_connections:
connected_neighb_idx = available_nodes[:allowed_connections]
else:
......@@ -903,6 +901,7 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
if current_node == 0:
allowed_connections -= 1
# Connect to the neighboring nodes
for neighb in connected_neighb_idx:
if neighb not in node_stack:
......@@ -938,12 +937,14 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
train_stations[trainstation_node].append((station_x, station_y))
# Connect train station to the correct node
connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], (station_x, station_y))
connection = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node],
(station_x, station_y))
# Check if connection was made
if len(connection) == 0:
train_stations[trainstation_node].pop(-1)
# Fix all nodes with illegal transition maps
for current_node in node_positions:
if not grid_map.cell_neighbours_valid(current_node):
grid_map.fix_neighbours(current_node)
grid_map.fix_transitions(current_node)
# Generate start and target node directory for all agents.
# Assure that start and target are not in the same node
......
......@@ -76,8 +76,8 @@ def connect_nodes(rail_trans, rail_array, start, end):
if index == 0:
if new_trans == 0:
# end-point
# need to flip direction because of how end points are defined
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
# don't set any transition at node yet
new_trans = 0
else:
# into existing rail
new_trans = rail_trans.set_transition(new_trans, current_dir, new_dir, 1)
......@@ -93,7 +93,9 @@ def connect_nodes(rail_trans, rail_array, start, end):
new_trans_e = rail_array[end_pos]
if new_trans_e == 0:
# end-point
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
# don't set any transition at node yet
new_trans_e = 0
else:
# into existing rail
new_trans_e = rail_trans.set_transition(new_trans_e, new_dir, new_dir, 1)
......
import time
import numpy as np
from flatland.envs.generators import sparse_rail_generator, realistic_rail_generator
......@@ -23,11 +21,11 @@ def test_realistic_rail_generator():
def test_sparse_rail_generator():
env = RailEnv(width=50,
height=50,
rail_generator=sparse_rail_generator(num_cities=3, # Number of cities in map
env = RailEnv(width=20,
height=20,
rail_generator=sparse_rail_generator(num_cities=5, # Number of cities in map
num_intersections=2, # Number of interesections in map
num_trainstations=15, # Number of possible start/targets on map
num_trainstations=20, # Number of possible start/targets on map
min_node_dist=6, # Minimal distance of nodes
node_radius=3, # Proximity of stations to city center
num_neighb=2, # Number of connections to other cities
......@@ -38,4 +36,3 @@ def test_sparse_rail_generator():
# reset to initialize agents_static
env_renderer = RenderTool(env, gl="PILSVG", )
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
time.sleep(10)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment