Commit 7ca1ed3b authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

Fixed invalid transitions: now all nodes are valid

All starting positions are valid.
Check for feasibility of map is not yet done.
parent 0334070a
......@@ -77,7 +77,7 @@ for trials in range(1, n_trials + 1):
score = 0
# Run episode
for step in range(100):
for step in range(500):
# Chose an action for each agent in the environment
for a in range(env.get_num_agents()):
action = agent.act(obs[a])
......
......@@ -350,4 +350,60 @@ class GridTransitionMap(TransitionMap):
return True
def fix_neighbours(self, rcPos, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
Checks that:
- surrounding cells have inbound transitions for all the
outbound transitions of this cell.
These are NOT checked - see transition.is_valid:
- all transitions have the mirror transitions (N->E <=> W->S)
- Reverse transitions (N -> S) only exist for a dead-end
- a cell contains either no dead-ends or exactly one
Returns: True (valid) or False (invalid)
"""
cell_transition = self.grid[tuple(rcPos)]
if check_this_cell:
if not self.transitions.is_valid(cell_transition):
return False
gDir2dRC = self.transitions.gDir2dRC # [[-1,0] = N, [0,1]=E, etc]
grcPos = array(rcPos)
grcMax = self.grid.shape
binTrans = self.get_full_transitions(*rcPos) # 16bit integer - all trans in/out
lnBinTrans = array([binTrans >> 8, binTrans & 0xff], dtype=np.uint8) # 2 x uint8
g2binTrans = np.unpackbits(lnBinTrans).reshape(4, 4) # 4x4 x uint8 binary(0,1)
gDirOut = g2binTrans.any(axis=0) # outbound directions as boolean array (4)
giDirOut = np.argwhere(gDirOut)[:, 0] # valid outbound directions as array of int
# loop over available outbound directions (indices) for rcPos
for iDirOut in giDirOut:
gdRC = gDir2dRC[iDirOut] # row,col increment
gPos2 = grcPos + gdRC # next cell in that direction
# Check the adjacent cell is within bounds
# if not, then this transition is invalid!
if np.any(gPos2 < 0):
return False
if np.any(gPos2 >= grcMax):
return False
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
t4Trans2 = self.get_transitions(*gPos2, iDirOut)
if any(t4Trans2):
continue
else:
self.set_transition((gPos2[0], gPos2[1], iDirOut), mirror(iDirOut), 1)
return False
return True
def mirror(dir):
return (dir + 2) % 4
# TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?)
......@@ -841,8 +841,8 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
to_close = True
tries = 0
while to_close:
x_tmp = 1 + np.random.randint(height - 1)
y_tmp = 1 + np.random.randint(width - 1)
x_tmp = 1 + np.random.randint(height - 2)
y_tmp = 1 + np.random.randint(width - 2)
to_close = False
for node_pos in node_positions:
if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
......@@ -871,8 +871,6 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
available_nodes = available_nodes[np.argsort(node_dist)]
# Set number of neighboring nodes
# np.random.randint(1, max_neigbours)
if len(available_nodes) >= num_neighb:
connected_neighb_idx = available_nodes[
0:2] # np.random.choice(available_nodes, num_neighb, replace=False)
......@@ -887,6 +885,8 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
node_stack.pop(0)
# Place train stations close to the node
# We currently place them uniformly distirbuted among all cities
train_stations = [[] for i in range(num_cities)]
for station in range(num_trainstations):
......@@ -911,6 +911,11 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
new_path = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node],
(station_x, station_y))
# Fix all nodes with illegal transition maps
for current_node in node_positions:
if not grid_map.cell_neighbours_valid(current_node):
grid_map.fix_neighbours(current_node)
# Generate start and target node directory for all agents.
# Assure that start and target are not in the same node
agent_start_targets_nodes = []
......@@ -924,20 +929,24 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
# Assign agents to slots
for agent_idx in range(num_agents):
av_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
av_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
start_node = np.random.choice(av_start_nodes)
target_node = np.random.choice(av_target_nodes)
avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
start_node = np.random.choice(avail_start_nodes)
target_node = np.random.choice(avail_target_nodes)
tries = 0
while target_node == start_node:
target_node = np.random.choice(av_target_nodes)
target_node = np.random.choice(avail_target_nodes)
tries += 1
# Test again with new start node if no pair is found (This code needs to be improved)
if tries > 10:
start_node = np.random.choice(avail_start_nodes)
node_available_start[start_node] -= 1
node_available_target[target_node] -= 1
print(node_available_target, node_available_start)
agent_start_targets_nodes.append((start_node, target_node))
# Place agents and targets within available train stations
agents_position = []
agents_target = []
agents_direction = []
......@@ -956,7 +965,13 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
start = train_stations[current_start_node][start_station_idx]
agents_position.append((start[0], start[1]))
agents_target.append((target[0], target[1]))
agents_direction.append(0)
# Orient the agent correctly
for orientation in range(4):
transitions = grid_map.get_transitions(start[0], start[1], orientation)
if any(transitions) > 0:
agents_direction.append(orientation)
continue
agent_idx += 1
return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
......
......@@ -23,23 +23,22 @@ def test_realistic_rail_generator():
env_renderer.close_window()
def test_sparse_rail_generator():
env = RailEnv(width=50,
env = RailEnv(width=20,
height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map
num_intersections=3, # Number of interesections in map
num_trainstations=30, # Number of possible start/targets on map
num_trainstations=10, # Number of possible start/targets on map
min_node_dist=10, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center
num_neighb=4, # Number of connections to other cities
seed=15, # Random seed
),
number_of_agents=20,
number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv())
# reset to initialize agents_static
env_renderer = RenderTool(env, gl="PILSVG", )
env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
time.sleep(2)
time.sleep(20)
env_renderer.gl.save_image("flatalnd_2_0.png")
time.sleep(100)
time.sleep(1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment