Commit 6f9df191 authored by Erik Nygren's avatar Erik Nygren 🚅
Browse files

update the the way we fix cells in the map.

Now we first fix empty cells with incoming connections
then we fix cells with rails and wrong transition maps.
parent f5e750a2
Pipeline #2167 failed with stages
in 18 minutes and 39 seconds
......@@ -41,7 +41,7 @@ env = RailEnv(width=50,
seed=15, # Random seed
grid_mode=True,
nr_inter_connections=1,
max_nr_connection_points=8
max_nr_connection_points=12
),
schedule_generator=sparse_schedule_generator(),
number_of_agents=50,
......
......@@ -422,6 +422,28 @@ class GridTransitionMap(TransitionMap):
continue
else:
return False
# If the cell is empty but has incoming connections we return false
if binTrans < 1:
connected = 0
for iDirOut in np.arange(4):
gdRC = gDir2dRC[iDirOut] # row,col increment
gPos2 = grcPos + gdRC # next cell in that direction
# Check the adjacent cell is within bounds
# if not, then ignore it for the count of incoming connections
if np.any(gPos2 < 0):
continue
if np.any(gPos2 >= grcMax):
continue
# Get the transitions out of gPos2, using iDirOut as the inbound direction
# if there are no available transitions, ie (0,0,0,0), then rcPos is invalid
for orientation in range(4):
connected += self.get_transition((gPos2[0], gPos2[1], orientation), mirror(iDirOut))
if connected > 0:
return False
return True
......
......@@ -252,7 +252,6 @@ class RailEnv(Environment):
rc_pos = (r, c)
check = self.rail.cell_neighbours_valid(rc_pos, True)
if not check:
self.rail.fix_transitions(rc_pos)
warnings.warn("Invalid grid at {} -> {}".format(rc_pos, check))
if replace_agents:
......
......@@ -707,6 +707,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
connection_points[trainstation_node][corner_node_idx],
(station_x, station_y))
if len(connection) != 0:
if (connection_points[trainstation_node][corner_node_idx],
trainstation_node) in boarder_connections:
boarder_connections.remove(
......@@ -717,6 +718,7 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
if len(train_stations[trainstation_node]) > 0:
train_stations[trainstation_node].pop(-1)
else:
built_num_trainstation += 1
# Adjust the number of agents if you could not build enough trainstations
if num_agents > built_num_trainstation:
......@@ -749,14 +751,25 @@ def sparse_rail_generator(num_cities=5, num_intersections=4, num_trainstations=2
boarder_connections.remove(tbd)
print(boarder_connections)
# Fix all nodes with illegal transition maps
flat_trainstation_list = [item for sublist in train_stations for item in sublist]
for cell_to_fix in flat_trainstation_list:
grid_map.fix_transitions(cell_to_fix)
empty_to_fix = []
rails_to_fix = []
for r in range(height):
for c in range(width):
rc_pos = (r, c)
check = grid_map.cell_neighbours_valid(rc_pos, True)
if not check:
if grid_map.grid[rc_pos] == 0:
empty_to_fix.append(rc_pos)
else:
rails_to_fix.append(rc_pos)
flat_list = [item for sublist in connection_points for item in sublist]
# Fix empty cells first to avoid cutting the network
for cell in empty_to_fix:
grid_map.fix_transitions(cell)
for cell_to_fix in flat_list:
grid_map.fix_transitions(cell_to_fix)
# Fix all other cells
for cell in rails_to_fix:
grid_map.fix_transitions(cell)
# Generate start and target node directory for all agents.
# Assure that start and target are not in the same node
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment