Skip to content
Snippets Groups Projects
Commit a7129610 authored by spiglerg's avatar spiglerg
Browse files

important fixes to treesearch

parent a4bcd315
No related branches found
No related tags found
No related merge requests found
......@@ -17,7 +17,7 @@ transition_probability = [1.0, # empty cell - Case 0
0.5, # Case 5 - double slip
0.2, # Case 6 - symmetrical
0.0] # Case 7 - dead end
"""
# Example generate a random rail
env = RailEnv(width=20,
height=20,
......@@ -33,7 +33,7 @@ env.reset()
env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True)
"""
"""
# Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation)
......@@ -52,12 +52,35 @@ env.agents_target[0] = [1, 1]
env.agents_direction[0] = 1
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset()
"""
specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
[(7, 270), (1, 90), (1, 90), (2, 270), (2, 0), (0, 0)],
[(0, 0), (0, 0), (0, 0), (2, 180), (2, 90), (7, 90)],
[(0, 0), (0, 0), (0, 0), (7, 180), (0, 0), (0, 0)]]
env = RailEnv(width=6,
height=4,
rail_generator=rail_from_manual_specifications_generator(specs),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=2))
handle = env.get_agent_handles()
env.agents_position[0] = [1, 3]
env.agents_target[0] = [1, 1]
env.agents_direction[0] = 1
# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
env.obs_builder.reset()
"""
env = RailEnv(width=7,
height=7,
rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
number_of_agents=2)
"""
# Print the distance map of each cell to the target of the first agent
# for i in range(4):
# print(env.obs_builder.distance_map[0, :, :, i])
......
......@@ -169,10 +169,10 @@ class TreeObsForRailEnv(ObservationBuilder):
(new_cell[0], new_cell[1], (direction+2) % 4), direction)
if transitionValid and directionMatch:
new_distance = min(self.distance_map[target_nr,
new_cell[0], new_cell[1]], current_distance+1)
new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], direction],
current_distance+1)
neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance
self.distance_map[target_nr, new_cell[0], new_cell[1], direction] = new_distance
possible_directions = [0, 1, 2, 3]
if enforce_target_direction >= 0:
......@@ -319,13 +319,15 @@ class TreeObsForRailEnv(ObservationBuilder):
exploring = True
last_isSwitch = False
last_isDeadEnd = False
# TODO: last_isTerminal = False # wrong cell encountered
last_isTerminal = False # wrong cell encountered OR cycle encountered; either way, we don't want the agent
# to land here
last_isTarget = False
visited = set([position[0], position[1], direction])
other_agent_encountered = False
other_target_encountered = False
while exploring:
# #############################
# #############################
# Modify here to compute any useful data required to build the end node's features. This code is called
......@@ -340,6 +342,10 @@ class TreeObsForRailEnv(ObservationBuilder):
# #############################
# #############################
if (position[0], position[1], direction) in visited:
last_isTerminal = True
break
# If the target node is encountered, pick that as node. Also, no further branching is possible.
if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
last_isTarget = True
......@@ -380,9 +386,11 @@ class TreeObsForRailEnv(ObservationBuilder):
elif num_transitions == 0:
# Wrong cell type, but let's cover it and treat it as a dead-end, just in case
# TODO: last_isTerminal = True
last_isTerminal = True
break
visited.add((position[0], position[1], direction))
# `position' is either a terminal node or a switch
observation = []
......@@ -398,6 +406,12 @@ class TreeObsForRailEnv(ObservationBuilder):
0,
0]
elif last_isTerminal:
observation = [0,
1 if other_target_encountered else 0,
1 if other_agent_encountered else 0,
0,
np.inf]
else:
observation = [0,
1 if other_target_encountered else 0,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment