Skip to content
Snippets Groups Projects
Commit 9b5a1655 authored by spiglerg's avatar spiglerg
Browse files

fixed dead-ends in tree search

parent 3fbfb8af
No related branches found
No related tags found
No related merge requests found
......@@ -9,14 +9,14 @@ from flatland.utils.rendertools import *
random.seed(1)
np.random.seed(1)
"""
# Example generate a random rail
env = RailEnv(width=20, height=20, rail_generator=random_rail_generator, number_of_agents=10)
env.reset()
env_renderer = RenderTool(env)
env_renderer.renderEnv(show=True)
"""
# Example generate a rail given a manual specification,
# a map of tuples (cell_type, rotation)
......@@ -27,7 +27,7 @@ env = RailEnv(width=6,
height=2,
rail_generator=rail_from_manual_specifications_generator(specs),
number_of_agents=1,
obs_builder_object=TreeObsForRailEnv(max_depth=1))
obs_builder_object=TreeObsForRailEnv(max_depth=2))
handle = env.get_agent_handles()
......
......@@ -263,6 +263,7 @@ class TreeObsForRailEnv(ObservationBuilder):
# We treat dead-ends as nodes, instead of going back, to avoid loops
exploring = True
last_isSwitch = False
last_isDeadEnd = False
# TODO: last_isTerminal = False # dead-end
# TODO: last_isTarget = False
while exploring:
......@@ -301,7 +302,7 @@ class TreeObsForRailEnv(ObservationBuilder):
else:
# If a dead-end is reached, pick that as node. Also, no further branching is possible.
# TODO: last_isTerminal = True
last_isDeadEnd = True
break
elif num_transitions > 0:
......@@ -331,7 +332,17 @@ class TreeObsForRailEnv(ObservationBuilder):
# Start from the current orientation, and see which transitions are available;
# organize them as [left, forward, right, back], relative to the current orientation
for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]:
if last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), branch_direction):
if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction),
(branch_direction+2) % 4):
# Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
# it back
new_cell = self._new_position(position, (branch_direction+2)%4)
branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2)%4, depth+1)
observation = observation + branch_observation
elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
branch_direction):
new_cell = self._new_position(position, branch_direction)
branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment