Skip to content
Snippets Groups Projects
Commit 98906167 authored by spiglerg
Browse files

Fix (about 90% confident) for a serious bug in the tree-search observation; it now seems to be working correctly

parent e99a59af
No related branches found
No related tags found
No related merge requests found
...@@ -6,8 +6,8 @@ from flatland.envs.rail_env import * ...@@ -6,8 +6,8 @@ from flatland.envs.rail_env import *
from flatland.core.env_observation_builder import TreeObsForRailEnv from flatland.core.env_observation_builder import TreeObsForRailEnv
from flatland.utils.rendertools import * from flatland.utils.rendertools import *
random.seed(1) random.seed(0)
np.random.seed(1) np.random.seed(0)
""" """
transition_probability = [1.0, # empty cell - Case 0 transition_probability = [1.0, # empty cell - Case 0
...@@ -27,7 +27,7 @@ transition_probability = [1.0, # empty cell - Case 0 ...@@ -27,7 +27,7 @@ transition_probability = [1.0, # empty cell - Case 0
0.1, # Case 5 - double slip 0.1, # Case 5 - double slip
0.2, # Case 6 - symmetrical 0.2, # Case 6 - symmetrical
0.01] # Case 7 - dead end 0.01] # Case 7 - dead end
"""
# Example generate a random rail # Example generate a random rail
env = RailEnv(width=20, env = RailEnv(width=20,
height=20, height=20,
...@@ -38,7 +38,6 @@ env.reset() ...@@ -38,7 +38,6 @@ env.reset()
env_renderer = RenderTool(env) env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True) env_renderer.renderEnv(show=True)
# Example generate a rail given a manual specification, # Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation) # a map of tuples (cell_type, rotation)
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)], specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
...@@ -51,19 +50,25 @@ env = RailEnv(width=6, ...@@ -51,19 +50,25 @@ env = RailEnv(width=6,
obs_builder_object=TreeObsForRailEnv(max_depth=2)) obs_builder_object=TreeObsForRailEnv(max_depth=2))
handle = env.get_agent_handles() handle = env.get_agent_handles()
env.agents_position[0] = [1, 4] env.agents_position[0] = [1, 4]
env.agents_target[0] = [1, 1] env.agents_target[0] = [1, 1]
env.agents_direction[0] = 1 env.agents_direction[0] = 1
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too! # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset() env.obs_builder.reset()
"""
env = RailEnv(width=7,
height=7,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=1)
# TODO: delete next line # TODO: delete next line
#for i in range(4): for i in range(4):
# print(env.obs_builder.distance_map[0, :, :, i]) print(env.obs_builder.distance_map[0, :, :, i])
obs, all_rewards, done, _ = env.step({0:0}) obs, all_rewards, done, _ = env.step({0:0})
env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) for i in range(env.number_of_agents):
env.obs_builder.util_print_obs_subtree(tree=obs[i], num_elements_per_node=5)
env_renderer = RenderTool(env) env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True) env_renderer.renderEnv(show=True)
......
...@@ -103,7 +103,6 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -103,7 +103,6 @@ class TreeObsForRailEnv(ObservationBuilder):
node = nodes_queue.popleft() node = nodes_queue.popleft()
node_id = (node[0], node[1], node[2]) node_id = (node[0], node[1], node[2])
if node_id not in visited: if node_id not in visited:
visited.add(node_id) visited.add(node_id)
...@@ -126,58 +125,50 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -126,58 +125,50 @@ class TreeObsForRailEnv(ObservationBuilder):
""" """
neighbors = [] neighbors = []
for direction in range(4): possible_directions = [0, 1, 2, 3]
new_cell = self._new_position(position, (direction+2) % 4) if enforce_target_direction >= 0:
# The agent must land into the current cell with orientation `enforce_target_direction'.
# This is only possible if the agent has arrived from the cell in the opposite direction!
possible_directions = [(enforce_target_direction+2) % 4]
for neigh_direction in possible_directions:
new_cell = self._new_position(position, neigh_direction)
if new_cell[0] >= 0 and new_cell[0] < self.env.height and \ if new_cell[0] >= 0 and new_cell[0] < self.env.height and \
new_cell[1] >= 0 and new_cell[1] < self.env.width: new_cell[1] >= 0 and new_cell[1] < self.env.width:
# Check if the two cells are connected by a valid transition desired_movement_from_new_cell = (neigh_direction+2) % 4
transitionValid = False
for orientation in range(4): """
moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation)) # Is the next cell a dead-end?
if moves[direction]: isNextCellDeadEnd = False
transitionValid = True nbits = 0
break tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
while tmp > 0:
if not transitionValid: nbits += (tmp & 1)
continue tmp = tmp >> 1
if nbits == 1:
# Check if a transition in direction node[2] is possible if an agent lands in the current # Dead-end!
# cell with orientation `direction'; this only applies to cells that are not dead-ends! isNextCellDeadEnd = True
directionMatch = True """
if enforce_target_direction >= 0:
directionMatch = self.env.rail.get_transition((new_cell[0], new_cell[1], direction), # Check all possible transitions in new_cell
enforce_target_direction) for agent_orientation in range(4):
# Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
# If transition is found to invalid, check if perhaps it is a dead-end, in which case the isValid = self.env.rail.get_transition((new_cell[0], new_cell[1], agent_orientation),
# direction of movement is rotated 180 degrees (moving forward turns the agents and makes desired_movement_from_new_cell)
# it step in the previous cell)
if not directionMatch: if isValid:
# If cell is a dead-end, append previous node with reversed """
# orientation! # TODO: check that it works with deadends! -- still bugged!
nbits = 0 movement = desired_movement_from_new_cell
tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1])) if isNextCellDeadEnd:
while tmp > 0: movement = (desired_movement_from_new_cell+2) % 4
nbits += (tmp & 1) """
tmp = tmp >> 1 new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation],
if nbits == 1: current_distance+1)
# Dead-end! neighbors.append((new_cell[0], new_cell[1], agent_orientation, new_distance))
# Check if transition is possible in new_cell with orientation self.distance_map[target_nr, new_cell[0], new_cell[1], agent_orientation] = new_distance
# (direction+2)%4 in direction `direction'
directionMatch = directionMatch or \
self.env.rail.get_transition((new_cell[0], new_cell[1], (direction+2) % 4),
direction)
if transitionValid and directionMatch:
# Append all possible orientations in new_cell that allow a transition to direction!
for orientation in range(4):
moves = self.env.rail.get_transitions((new_cell[0], new_cell[1], orientation))
if moves[direction]:
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], orientation],
current_distance+1)
neighbors.append((new_cell[0], new_cell[1], orientation, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1], orientation] = new_distance
return neighbors return neighbors
...@@ -309,16 +300,24 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -309,16 +300,24 @@ class TreeObsForRailEnv(ObservationBuilder):
exploring = False exploring = False
if num_transitions == 1: if num_transitions == 1:
# Check if dead-end, or if we can go forward along direction # Check if dead-end, or if we can go forward along direction
if cell_transitions[direction]: nbits = 0
position = self._new_position(position, direction) tmp = self.env.rail.get_transitions((position[0], position[1]))
while tmp > 0:
nbits += (tmp & 1)
tmp = tmp >> 1
if nbits == 1:
# Dead-end!
last_isDeadEnd = True
if not last_isDeadEnd:
# Keep walking through the tree along `direction' # Keep walking through the tree along `direction'
exploring = True exploring = True
else: for i in range(4):
# If a dead-end is reached, pick that as node. Also, no further branching is possible. if cell_transitions[i]:
last_isDeadEnd = True position = self._new_position(position, i)
break direction = i
break
elif num_transitions > 0: elif num_transitions > 0:
# Switch detected # Switch detected
...@@ -352,8 +351,6 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -352,8 +351,6 @@ class TreeObsForRailEnv(ObservationBuilder):
0, 0,
self.distance_map[handle, position[0], position[1], direction]] self.distance_map[handle, position[0], position[1], direction]]
# TODO:
# ############################# # #############################
# ############################# # #############################
......
...@@ -218,7 +218,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -218,7 +218,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
rot = 90 rot = 90
rail[row][col] = t_utils.rotate_transition( rail[row][col] = t_utils.rotate_transition(
int('0000000000100000', 2), rot) int('0010000000000000', 2), rot)
num_insertions += 1 num_insertions += 1
break break
...@@ -299,7 +299,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -299,7 +299,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
max_bit = max_bit | (neigh_trans_from_direction & 1) max_bit = max_bit | (neigh_trans_from_direction & 1)
if max_bit: if max_bit:
rail[r][0] = t_utils.rotate_transition( rail[r][0] = t_utils.rotate_transition(
int('0000000000100000', 2), 270) int('0010000000000000', 2), 270)
else: else:
rail[r][0] = int('0000000000000000', 2) rail[r][0] = int('0000000000000000', 2)
...@@ -312,7 +312,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -312,7 +312,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
& (2**4-1) & (2**4-1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 2)) max_bit = max_bit | (neigh_trans_from_direction & (1 << 2))
if max_bit: if max_bit:
rail[r][-1] = t_utils.rotate_transition(int('0000000000100000', 2), rail[r][-1] = t_utils.rotate_transition(int('0010000000000000', 2),
90) 90)
else: else:
rail[r][-1] = int('0000000000000000', 2) rail[r][-1] = int('0000000000000000', 2)
...@@ -327,7 +327,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -327,7 +327,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
& (2**4-1) & (2**4-1)
max_bit = max_bit | (neigh_trans_from_direction & (1 << 3)) max_bit = max_bit | (neigh_trans_from_direction & (1 << 3))
if max_bit: if max_bit:
rail[0][c] = int('0000000000100000', 2) rail[0][c] = int('0010000000000000', 2)
else: else:
rail[0][c] = int('0000000000000000', 2) rail[0][c] = int('0000000000000000', 2)
...@@ -341,7 +341,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -341,7 +341,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
max_bit = max_bit | (neigh_trans_from_direction & (1 << 1)) max_bit = max_bit | (neigh_trans_from_direction & (1 << 1))
if max_bit: if max_bit:
rail[-1][c] = t_utils.rotate_transition( rail[-1][c] = t_utils.rotate_transition(
int('0000000000100000', 2), 180) int('0010000000000000', 2), 180)
else: else:
rail[-1][c] = int('0000000000000000', 2) rail[-1][c] = int('0000000000000000', 2)
...@@ -352,6 +352,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8): ...@@ -352,6 +352,7 @@ def random_rail_generator(cell_type_relative_proportion=[1.0]*8):
rail[r][c] = int('0000000000000000', 2) rail[r][c] = int('0000000000000000', 2)
tmp_rail = np.asarray(rail, dtype=np.uint16) tmp_rail = np.asarray(rail, dtype=np.uint16)
return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils) return_rail = GridTransitionMap(width=width, height=height, transitions=t_utils)
return_rail.grid = tmp_rail return_rail.grid = tmp_rail
return return_rail return return_rail
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment