From 1e1a128291e1c60a2363304af0b4929378c0d7aa Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 3 May 2019 11:32:45 +0200
Subject: [PATCH] fixed tree observation error. testing observation. replaced a
 few for loops with numpy functions

---
 examples/training_navigation.py          |  5 +++--
 flatland/core/env_observation_builder.py | 26 ++++++++----------------
 flatland/envs/rail_env.py                |  3 +++
 3 files changed, 15 insertions(+), 19 deletions(-)

diff --git a/examples/training_navigation.py b/examples/training_navigation.py
index 9fc83242..1111e0bb 100644
--- a/examples/training_navigation.py
+++ b/examples/training_navigation.py
@@ -1,4 +1,5 @@
 from flatland.envs.rail_env import *
+from flatland.envs.generators import *
 from flatland.core.env_observation_builder import TreeObsForRailEnv
 from flatland.utils.rendertools import *
 from flatland.baselines.dueling_double_dqn import Agent
@@ -54,9 +55,9 @@ scores = []
 dones_list = []
 action_prob = [0] * 4
 agent = Agent(state_size, action_size, "FC", 0)
-agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth'))
+#agent.qnetwork_local.load_state_dict(torch.load('../flatland/baselines/Nets/avoid_checkpoint14900.pth'))
 
-demo = True
+demo = False
 
 
 def max_lt(seq, val):
diff --git a/flatland/core/env_observation_builder.py b/flatland/core/env_observation_builder.py
index 6c27afdc..edd9cae7 100644
--- a/flatland/core/env_observation_builder.py
+++ b/flatland/core/env_observation_builder.py
@@ -236,6 +236,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         position = self.env.agents_position[handle]
         orientation = self.env.agents_direction[handle]
+        possible_transitions = self.env.rail.get_transitions((position[0], position[1], orientation))
 
         # Root node - current position
         observation = [0, 0, 0, 0, self.distance_map[handle, position[0], position[1], orientation]]
@@ -245,7 +246,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         # organize them as [left, forward, right, back], relative to the current orientation
         # TODO: Adjust this to the novel movement dynamics --> Only Forward present when one transition is possible.
         for branch_direction in [(orientation + 4 + i) % 4 for i in range(-1, 3)]:
-            if self.env.rail.get_transition((position[0], position[1], orientation), branch_direction):
+            if possible_transitions[branch_direction]:
                 new_cell = self._new_position(position, branch_direction)
 
                 branch_observation = self._explore_branch(handle, new_cell, branch_direction, root_observation, 1)
@@ -308,11 +309,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 break
 
             cell_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
-            num_transitions = 0
-            for i in range(4):
-                if cell_transitions[i]:
-                    num_transitions += 1
-
+            num_transitions = np.count_nonzero(cell_transitions)
             exploring = False
             if num_transitions == 1:
                 # Check if dead-end, or if we can go forward along direction
@@ -328,13 +325,9 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if not last_isDeadEnd:
                     # Keep walking through the tree along `direction'
                     exploring = True
-                    # TODO: Remove below calculation, this is computed already above and could be reused
-                    for i in range(4):
-                        if cell_transitions[i]:
-                            position = self._new_position(position, i)
-                            direction = i
-                            num_steps += 1
-                            break
+                    direction = np.argmax(cell_transitions)
+                    position = self._new_position(position, direction)
+                    num_steps += 1
 
             elif num_transitions > 0:
                 # Switch detected
@@ -383,13 +376,14 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
+        # Get the possible transitions
+        possible_transitions = self.env.rail.get_transitions((position[0], position[1], direction))
         for branch_direction in [(direction + 4 + i) % 4 for i in range(-1, 3)]:
             if last_isDeadEnd and self.env.rail.get_transition((position[0], position[1], direction),
                                                                (branch_direction + 2) % 4):
                 # Swap forward and back in case of dead-end, so that an agent can learn that going forward takes
                 # it back
                 new_cell = self._new_position(position, (branch_direction + 2) % 4)
-
                 branch_observation = self._explore_branch(handle,
                                                           new_cell,
                                                           (branch_direction + 2) % 4,
@@ -397,10 +391,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                           depth + 1)
                 observation = observation + branch_observation
 
-            elif last_isSwitch and self.env.rail.get_transition((position[0], position[1], direction),
-                                                                (branch_direction + 2) % 4):
+            elif last_isSwitch and possible_transitions[branch_direction]:
                 new_cell = self._new_position(position, branch_direction)
-
                 branch_observation = self._explore_branch(handle,
                                                           new_cell,
                                                           branch_direction,
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 779ec058..9d67f830 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -323,6 +323,7 @@ class RailEnv(Environment):
                     nbits += (tmp & 1)
                     tmp = tmp >> 1
                 movement = direction
+                #print(nbits,np.sum(possible_transitions))
                 if action == 1:
                     movement = direction - 1
                     if nbits <= 2 or np.sum(possible_transitions) <= 1:
@@ -360,12 +361,14 @@ class RailEnv(Environment):
                             direction = reverse_direction
                             movement = reverse_direction
                             is_deadend = True
+
                     if np.sum(possible_transitions) == 1:
                         # Checking for curves
                         curv_dir = np.argmax(possible_transitions)
                         # valid_transition = self.rail.get_transition(
                         #    (pos[0], pos[1], direction),
                         #    movement)
+
                         movement = curv_dir
                 new_position = self._new_position(pos, movement)
                 # Is it a legal move?  1) transition allows the movement in the
-- 
GitLab