From 9b5a1655021665788fc6c7fb279281cf961910ed Mon Sep 17 00:00:00 2001
From: spiglerg <spiglerg@gmail.com>
Date: Fri, 19 Apr 2019 14:14:50 +0200
Subject: [PATCH] fixed dead-ends in tree search

---
 examples/temporary_example.py            |  6 +++---
 flatland/core/env_observation_builder.py | 15 +++++++++++++--
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index 96acb30f..9d6046ec 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -9,14 +9,14 @@ from flatland.utils.rendertools import *
 random.seed(1)
 np.random.seed(1)
 
-
+"""
 # Example generate a random rail
 env = RailEnv(width=20, height=20, rail_generator=random_rail_generator, number_of_agents=10)
 env.reset()
 
 env_renderer = RenderTool(env)
 env_renderer.renderEnv(show=True)
-
+"""
 
 # Example generate a rail given a manual specification,
 # a map of tuples (cell_type, rotation)
@@ -27,7 +27,7 @@ env = RailEnv(width=6,
               height=2,
               rail_generator=rail_from_manual_specifications_generator(specs),
               number_of_agents=1,
-              obs_builder_object=TreeObsForRailEnv(max_depth=1))
+              obs_builder_object=TreeObsForRailEnv(max_depth=2))
 
 handle = env.get_agent_handles()
 
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 34a5bb53..e2e74b54 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -263,6 +263,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         # We treat dead-ends as nodes, instead of going back, to avoid loops
         exploring = True
         last_isSwitch = False
+        last_isDeadEnd = False
         # TODO: last_isTerminal = False  # dead-end
         # TODO: last_isTarget = False
         while exploring:
@@ -301,7 +302,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
                 else:
                     # If a dead-end is reached, pick that as node. Also, no further branching is possible.
-                    # TODO: last_isTerminal = True
+                    last_isDeadEnd = True
                     break
 
             elif num_transitions > 0:
@@ -331,7 +332,17 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         for branch_direction in [(direction+4+i) % 4 for i in range(-1, 3)]:
-            if last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction), branch_direction):
+            if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction),
+                                                               (branch_direction+2) % 4):
+                # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
+                # it back
+                new_cell = self._new_position(position, (branch_direction+2)%4)
+
+                branch_observation = self._explore_branch(handle, new_cell, (branch_direction+2)%4, depth+1)
+                observation = observation + branch_observation
+
+            elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
+                                                                branch_direction):
                 new_cell = self._new_position(position, branch_direction)
 
                 branch_observation = self._explore_branch(handle, new_cell, branch_direction, depth+1)
-- 
GitLab