From a712961050c185022444d676eb8c38bd4131d9f6 Mon Sep 17 00:00:00 2001
From: spiglerg <spiglerg@gmail.com>
Date: Tue, 23 Apr 2019 12:24:29 +0200
Subject: [PATCH] important fixes to treesearch

---
 examples/temporary_example.py            | 29 +++++++++++++++++++++---
 flatland/core/env_observation_builder.py | 26 ++++++++++++++++-----
 2 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/examples/temporary_example.py b/examples/temporary_example.py
index 02c282c..67d0fe6 100644
--- a/examples/temporary_example.py
+++ b/examples/temporary_example.py
@@ -17,7 +17,7 @@ transition_probability = [1.0,  # empty cell - Case 0
                           0.5,  # Case 5 - double slip
                           0.2,  # Case 6 - symmetrical
                           0.0]  # Case 7 - dead end
-
+"""
 # Example generate a random rail
 env = RailEnv(width=20,
               height=20,
@@ -33,7 +33,7 @@ env.reset()
 
 env_renderer = RenderTool(env)
 env_renderer.renderEnv(show=True)
-
+"""
 """
 # Example generate a rail given a manual specification,
 # a map of tuples (cell_type, rotation)
@@ -52,12 +52,35 @@ env.agents_target[0] = [1, 1]
 env.agents_direction[0] = 1
 # TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
 env.obs_builder.reset()
+"""
+
+
+specs = [[(0, 0), (0, 0), (0, 0), (0, 0), (7, 0), (0, 0)],
+         [(7, 270), (1, 90), (1, 90), (2, 270), (2, 0), (0, 0)],
+         [(0, 0), (0, 0), (0, 0), (2, 180), (2, 90), (7, 90)],
+         [(0, 0), (0, 0), (0, 0), (7, 180), (0, 0), (0, 0)]]
+
+env = RailEnv(width=6,
+              height=4,
+              rail_generator=rail_from_manual_specifications_generator(specs),
+              number_of_agents=1,
+              obs_builder_object=TreeObsForRailEnv(max_depth=2))
+
+handle = env.get_agent_handles()
+env.agents_position[0] = [1, 3]
+env.agents_target[0] = [1, 1]
+env.agents_direction[0] = 1
+# TODO: watch out: if these variables are overridden, the obs_builder object has to be reset, too!
+env.obs_builder.reset()
+
+
+
 """
 env = RailEnv(width=7,
               height=7,
               rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
               number_of_agents=2)
-
+"""
 # Print the distance map of each cell to the target of the first agent
 # for i in range(4):
 #     print(env.obs_builder.distance_map[0, :, :, i])
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index b0730d8..fd71bb9 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -169,10 +169,10 @@ class TreeObsForRailEnv(ObservationBuilder):
                             (new_cell[0], new_cell[1], (direction+2) % 4), direction)
 
                 if transitionValid and directionMatch:
-                    new_distance = min(self.distance_map[target_nr,
-                                                         new_cell[0], new_cell[1]], current_distance+1)
+                    new_distance = min(self.distance_map[target_nr, new_cell[0], new_cell[1], direction],
+                                       current_distance+1)
                     neighbors.append((new_cell[0], new_cell[1], direction, new_distance))
-                    self.distance_map[target_nr, new_cell[0], new_cell[1]] = new_distance
+                    self.distance_map[target_nr, new_cell[0], new_cell[1], direction] = new_distance
 
         possible_directions = [0, 1, 2, 3]
         if enforce_target_direction >= 0:
@@ -319,13 +319,15 @@ class TreeObsForRailEnv(ObservationBuilder):
         exploring = True
         last_isSwitch = False
         last_isDeadEnd = False
-        # TODO: last_isTerminal = False  # wrong cell encountered
+        last_isTerminal = False  # wrong cell encountered OR cycle encountered;  either way, we don't want the agent
+                                 # to land here
         last_isTarget = False
 
+        visited = set()  # (position[0], position[1], direction) tuples already walked; filled by the loop below
+
         other_agent_encountered = False
         other_target_encountered = False
         while exploring:
-
             # #############################
             # #############################
             # Modify here to compute any useful data required to build the end node's features. This code is called
@@ -340,6 +342,10 @@ class TreeObsForRailEnv(ObservationBuilder):
             # #############################
             # #############################
 
+            if (position[0], position[1], direction) in visited:
+                last_isTerminal = True
+                break
+
             # If the target node is encountered, pick that as node. Also, no further branching is possible.
             if position[0] == self.env.agents_target[handle][0] and position[1] == self.env.agents_target[handle][1]:
                 last_isTarget = True
@@ -380,9 +386,11 @@ class TreeObsForRailEnv(ObservationBuilder):
 
             elif num_transitions == 0:
                 # Wrong cell type, but let's cover it and treat it as a dead-end, just in case
-                # TODO: last_isTerminal = True
+                last_isTerminal = True
                 break
 
+            visited.add((position[0], position[1], direction))
+
         # `position' is either a terminal node or a switch
 
         observation = []
@@ -398,6 +406,12 @@ class TreeObsForRailEnv(ObservationBuilder):
                            0,
                            0]
 
+        elif last_isTerminal:
+            observation = [0,
+                           1 if other_target_encountered else 0,
+                           1 if other_agent_encountered else 0,
+                           0,
+                           np.inf]
         else:
             observation = [0,
                            1 if other_target_encountered else 0,
-- 
GitLab