From 9b5a1655021665788fc6c7fb279281cf961910ed Mon Sep 17 00:00:00 2001 From: spiglerg <spiglerg@gmail.com> Date: Fri, 19 Apr 2019 14:14:50 +0200 Subject: [PATCH] fixed dead-ends in tree search --- examples/temporary_example.py | 6 +++--- flatland/core/env_observation_builder.py | 15 +++++++++++++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/examples/temporary_example.py b/examples/temporary_example.py index 96acb30f..9d6046ec 100644 --- a/examples/temporary_example.py +++ b/examples/temporary_example.py @@ -9,14 +9,14 @@ from flatland.utils.rendertools import * random.seed(1) np.random.seed(1) - +""" # Example generate a random rail env = RailEnv(width=20, height=20, rail_generator=random_rail_generator, number_of_agents=10) env.reset() env_renderer = RenderTool(env) env_renderer.renderEnv(show=True) - +""" # Example generate a rail given a manual specification, # a map of tuples (cell_type, rotation) @@ -27,7 +27,7 @@ env = RailEnv(width=6, height=2, rail_generator=rail_from_manual_specifications_generator(specs), number_of_agents=1, - obs_builder_object=TreeObsForRailEnv(max_depth=1)) + obs_builder_object=TreeObsForRailEnv(max_depth=2)) handle = env.get_agent_handles() diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py index 34a5bb53..e2e74b54 100644 --- a/flatland/core/env_observation_builder.py +++ b/flatland/core/env_observation_builder.py @@ -263,6 +263,7 @@ class TreeObsForRailEnv(ObservationBuilder): # We treat dead-ends as nodes, instead of going back, to avoid loops exploring = True last_isSwitch = False + last_isDeadEnd = False # TODO: last_isTerminal = False # dead-end # TODO: last_isTarget = False while exploring: @@ -301,7 +302,7 @@ class TreeObsForRailEnv(ObservationBuilder): else: # If a dead-end is reached, pick that as node. Also, no further branching is possible. - # TODO: last_isTerminal = True + last_isDeadEnd = True break elif num_transitions > 0: @@ -331,7 +332,17 @@ class TreeObsForRailEnv(ObservationBuilder): # Start from the current orientation, and see which transitions are available; # organize them as [left, forward, right, back], relative to the current orientation for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]: - if last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), branch_direction): + if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction), + (branch_direction+2) % 4): + # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes + # it back + new_cell = self._new_position(position, (branch_direction+2)%4) + + branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2)%4, depth+1) + observation = observation + branch_observation + + elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), + branch_direction): new_cell = self._new_position(position, branch_direction) branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1) -- GitLab