Skip to content
Snippets Groups Projects
Commit 7ca1ed3b authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

Fixed invalid transitions: now all nodes are valid

All starting positions are valid.
Check for feasibility of map is not yet done.
parent 0334070a
No related branches found
No related tags found
No related merge requests found
...@@ -77,7 +77,7 @@ for trials in range(1, n_trials + 1): ...@@ -77,7 +77,7 @@ for trials in range(1, n_trials + 1):
score = 0 score = 0
# Run episode # Run episode
for step in range(100): for step in range(500):
# Chose an action for each agent in the environment # Chose an action for each agent in the environment
for a in range(env.get_num_agents()): for a in range(env.get_num_agents()):
action = agent.act(obs[a]) action = agent.act(obs[a])
......
...@@ -350,4 +350,60 @@ class GridTransitionMap(TransitionMap): ...@@ -350,4 +350,60 @@ class GridTransitionMap(TransitionMap):
return True return True
def fix_neighbours(self, rcPos, check_this_cell=False):
    """
    Repair the neighbours of the cell at rcPos = tuple(row, column).

    For every outbound direction of this cell, the adjacent cell in that
    direction must offer at least one transition when entered heading that
    way.  A neighbour that offers none is patched in place with
    ``set_transition((row, col, d), mirror(d), 1)``, i.e. the transition
    from inbound direction ``d`` to the opposite direction is switched on
    (presumably turning the neighbour into a dead-end - confirm against
    ``set_transition``).

    Every outbound direction is processed, so a single call repairs all
    repairable neighbours rather than stopping at the first one.

    These are NOT checked unless check_this_cell is set
    - see transition.is_valid:
    - all transitions have the mirror transitions (N->E <=> W->S)
    - Reverse transitions (N -> S) only exist for a dead-end
    - a cell contains either no dead-ends or exactly one

    Parameters:
        rcPos: (row, column) of the cell whose neighbours are repaired.
        check_this_cell: if True, first run ``transitions.is_valid`` on the
            cell itself and return False without touching any neighbour
            when it fails.

    Returns: True if every outbound direction already led to a neighbour
    with inbound transitions (nothing was changed); False if the cell
    failed its own check, an outbound direction pointed off the grid, or
    at least one neighbour had to be repaired.
    """
    cell_transition = self.grid[tuple(rcPos)]
    if check_this_cell and not self.transitions.is_valid(cell_transition):
        return False

    dir_to_delta = self.transitions.gDir2dRC  # [[-1,0] = N, [0,1]=E, etc]
    pos = np.array(rcPos)
    grid_shape = self.grid.shape

    # 16bit integer holding all transitions in/out of this cell, split into
    # two bytes and unpacked into a 4x4 matrix of 0/1 values.
    full_trans = self.get_full_transitions(*rcPos)
    packed = np.array([full_trans >> 8, full_trans & 0xff], dtype=np.uint8)
    trans_bits = np.unpackbits(packed).reshape(4, 4)

    # A column with any bit set is a direction this cell can be left in.
    out_dirs = np.argwhere(trans_bits.any(axis=0))[:, 0]

    all_valid = True
    for out_dir in out_dirs:
        neighbour = pos + dir_to_delta[out_dir]  # next cell in that direction

        # An outbound transition pointing off the grid cannot be repaired
        # from here; remember the failure but keep fixing the other sides.
        if np.any(neighbour < 0) or np.any(neighbour >= grid_shape):
            all_valid = False
            continue

        # Transitions out of the neighbour when entered heading out_dir;
        # (0, 0, 0, 0) means the neighbour cannot be entered from this cell.
        if any(self.get_transitions(*neighbour, out_dir)):
            continue

        self.set_transition((neighbour[0], neighbour[1], out_dir), mirror(out_dir), 1)
        all_valid = False

    return all_valid
def mirror(dir):
    """Return the direction opposite to ``dir`` among the four directions 0..3."""
    # Adding half a turn (2 of 4 steps) and wrapping gives the reverse heading.
    opposite = (dir + 2) % 4
    return opposite
# TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?) # TODO: improvement override __getitem__ and __setitem__ (cell contents, not transitions?)
...@@ -841,8 +841,8 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation ...@@ -841,8 +841,8 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
to_close = True to_close = True
tries = 0 tries = 0
while to_close: while to_close:
x_tmp = 1 + np.random.randint(height - 1) x_tmp = 1 + np.random.randint(height - 2)
y_tmp = 1 + np.random.randint(width - 1) y_tmp = 1 + np.random.randint(width - 2)
to_close = False to_close = False
for node_pos in node_positions: for node_pos in node_positions:
if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist: if distance_on_rail((x_tmp, y_tmp), node_pos) < min_node_dist:
...@@ -871,8 +871,6 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation ...@@ -871,8 +871,6 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
available_nodes = available_nodes[np.argsort(node_dist)] available_nodes = available_nodes[np.argsort(node_dist)]
# Set number of neighboring nodes # Set number of neighboring nodes
# np.random.randint(1, max_neigbours)
if len(available_nodes) >= num_neighb: if len(available_nodes) >= num_neighb:
connected_neighb_idx = available_nodes[ connected_neighb_idx = available_nodes[
0:2] # np.random.choice(available_nodes, num_neighb, replace=False) 0:2] # np.random.choice(available_nodes, num_neighb, replace=False)
...@@ -887,6 +885,8 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation ...@@ -887,6 +885,8 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
node_stack.pop(0) node_stack.pop(0)
# Place train stations close to the node # Place train stations close to the node
# We currently place them uniformly distributed among all cities
train_stations = [[] for i in range(num_cities)] train_stations = [[] for i in range(num_cities)]
for station in range(num_trainstations): for station in range(num_trainstations):
...@@ -911,6 +911,11 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation ...@@ -911,6 +911,11 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
new_path = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node], new_path = connect_from_nodes(rail_trans, rail_array, node_positions[trainstation_node],
(station_x, station_y)) (station_x, station_y))
# Fix all nodes with illegal transition maps
for current_node in node_positions:
if not grid_map.cell_neighbours_valid(current_node):
grid_map.fix_neighbours(current_node)
# Generate start and target node directory for all agents. # Generate start and target node directory for all agents.
# Assure that start and target are not in the same node # Assure that start and target are not in the same node
agent_start_targets_nodes = [] agent_start_targets_nodes = []
...@@ -924,20 +929,24 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation ...@@ -924,20 +929,24 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
# Assign agents to slots # Assign agents to slots
for agent_idx in range(num_agents): for agent_idx in range(num_agents):
av_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0] avail_start_nodes = [idx for idx, val in enumerate(node_available_start) if val > 0]
av_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0] avail_target_nodes = [idx for idx, val in enumerate(node_available_target) if val > 0]
start_node = np.random.choice(av_start_nodes) start_node = np.random.choice(avail_start_nodes)
target_node = np.random.choice(av_target_nodes) target_node = np.random.choice(avail_target_nodes)
tries = 0
while target_node == start_node: while target_node == start_node:
target_node = np.random.choice(av_target_nodes) target_node = np.random.choice(avail_target_nodes)
tries += 1
# Test again with new start node if no pair is found (This code needs to be improved)
if tries > 10:
start_node = np.random.choice(avail_start_nodes)
node_available_start[start_node] -= 1 node_available_start[start_node] -= 1
node_available_target[target_node] -= 1 node_available_target[target_node] -= 1
print(node_available_target, node_available_start)
agent_start_targets_nodes.append((start_node, target_node)) agent_start_targets_nodes.append((start_node, target_node))
# Place agents and targets within available train stations # Place agents and targets within available train stations
agents_position = [] agents_position = []
agents_target = [] agents_target = []
agents_direction = [] agents_direction = []
...@@ -956,7 +965,13 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation ...@@ -956,7 +965,13 @@ def sparse_rail_generator(num_cities=100, num_intersections=10, num_trainstation
start = train_stations[current_start_node][start_station_idx] start = train_stations[current_start_node][start_station_idx]
agents_position.append((start[0], start[1])) agents_position.append((start[0], start[1]))
agents_target.append((target[0], target[1])) agents_target.append((target[0], target[1]))
agents_direction.append(0)
# Orient the agent correctly
for orientation in range(4):
transitions = grid_map.get_transitions(start[0], start[1], orientation)
if any(transitions) > 0:
agents_direction.append(orientation)
continue
agent_idx += 1 agent_idx += 1
return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
......
...@@ -23,23 +23,22 @@ def test_realistic_rail_generator(): ...@@ -23,23 +23,22 @@ def test_realistic_rail_generator():
env_renderer.close_window() env_renderer.close_window()
def test_sparse_rail_generator(): def test_sparse_rail_generator():
env = RailEnv(width=20,
env = RailEnv(width=50,
height=50, height=50,
rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map rail_generator=sparse_rail_generator(num_cities=10, # Number of cities in map
num_intersections=3, # Number of interesections in map num_intersections=3, # Number of interesections in map
num_trainstations=30, # Number of possible start/targets on map num_trainstations=10, # Number of possible start/targets on map
min_node_dist=10, # Minimal distance of nodes min_node_dist=10, # Minimal distance of nodes
node_radius=2, # Proximity of stations to city center node_radius=2, # Proximity of stations to city center
num_neighb=4, # Number of connections to other cities num_neighb=4, # Number of connections to other cities
seed=15, # Random seed seed=15, # Random seed
), ),
number_of_agents=20, number_of_agents=1,
obs_builder_object=GlobalObsForRailEnv()) obs_builder_object=GlobalObsForRailEnv())
# reset to initialize agents_static # reset to initialize agents_static
env_renderer = RenderTool(env, gl="PILSVG", ) env_renderer = RenderTool(env, gl="PILSVG", )
env_renderer.render_env(show=True, show_observations=True, show_predictions=False) env_renderer.render_env(show=True, show_observations=True, show_predictions=False)
time.sleep(2) time.sleep(20)
env_renderer.gl.save_image("flatalnd_2_0.png") env_renderer.gl.save_image("flatalnd_2_0.png")
time.sleep(100) time.sleep(1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment