From 725b98de64ae9abf10e9ca01844bf3ea16de3e23 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 13 Jun 2019 18:38:44 +0200
Subject: [PATCH] fixed bugs in tree observation

---
 examples/training_example.py  |  1 -
 flatland/envs/observations.py | 26 ++++++++++++++++++--------
 2 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/examples/training_example.py b/examples/training_example.py
index ad222cb..c785804 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -81,7 +81,6 @@ for trials in range(1, n_trials + 1):
         # reward and whether their are done
         next_obs, all_rewards, done, _ = env.step(action_dict)
         TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
-        print(len(next_obs[0]))
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
             agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index bab7e63..4fa0e09 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -212,11 +212,11 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         Finally, each node information is composed of 5 floating point values:
 
-        #1:
+        #1: 1 if own target lies on the explored branch
 
-        #2: 1 if a target of another agent is detected between the previous node and the current one.
+        #2: distance toa target of another agent is detected between the previous node and the current one.
 
-        #3: 1 if another agent is detected between the previous node and the current one.
+        #3: distance to another agent is detected between the previous node and the current one.
 
         #4: distance of agent to the current branch node
 
@@ -255,6 +255,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         agent = self.env.agents[handle]  # TODO: handle being treated as index
         possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
         num_transitions = np.count_nonzero(possible_transitions)
+
         # Root node - current position
         observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0]
 
@@ -304,7 +305,8 @@ class TreeObsForRailEnv(ObservationBuilder):
         last_isTarget = False
 
         visited = set()
-
+        agent = self.env.agents[handle]
+        own_target_encountered = np.inf
         other_agent_encountered = np.inf
         other_target_encountered = np.inf
         other_agent_same_direction = 0
@@ -334,9 +336,14 @@ class TreeObsForRailEnv(ObservationBuilder):
                                                                                    handle):
                     potential_conflict = 1
 
-            if position in self.location_has_target:
+            if position in self.location_has_target and position != agent.target:
                 if num_steps < other_target_encountered:
                     other_target_encountered = num_steps
+
+            if position == agent.target:
+                if num_steps < own_target_encountered:
+                    own_target_encountered = num_steps
+
             # #############################
             # #############################
 
@@ -428,7 +435,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         """
 
         if last_isTarget:
-            observation = [0,
+            observation = [own_target_encountered,
                            other_target_encountered,
                            other_agent_encountered,
                            root_observation[3] + num_steps,
@@ -439,7 +446,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                            ]
 
         elif last_isTerminal:
-            observation = [0,
+            observation = [own_target_encountered,
                            other_target_encountered,
                            other_agent_encountered,
                            np.inf,
@@ -449,7 +456,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                            potential_conflict
                            ]
         else:
-            observation = [0,
+            observation = [own_target_encountered,
                            other_target_encountered,
                            other_agent_encountered,
                            root_observation[3] + num_steps,
@@ -549,8 +556,11 @@ class TreeObsForRailEnv(ObservationBuilder):
             pow4 *= 4
         child_size = (len(tree) - num_features_per_node) // 4
         tree_data = tree[:4].tolist()
+        # print("data",tree_data)
         distance_data = [tree[4]]
+        #print("distance",distance_data)
         agent_data = tree[5:num_features_per_node].tolist()
+        #print("agent_data",agent_data)
         for children in range(4):
             child_tree = tree[(num_features_per_node + children * child_size):
                               (num_features_per_node + (children + 1) * child_size)]
-- 
GitLab