From 2df4a43402e5f4691fb777ab7306d6d4091d85be Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Wed, 2 Oct 2019 22:48:41 -0400
Subject: [PATCH] fixed bug in tree observation.

---
 examples/flatland_2_0_example.py |  4 ++--
 flatland/envs/observations.py    | 11 ++++-------
 2 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 6453ccf7..40b41591 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -1,6 +1,6 @@
 import numpy as np
 
-from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
+from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import sparse_rail_generator
@@ -39,7 +39,7 @@ env = RailEnv(width=40,
               schedule_generator=sparse_schedule_generator(speed_ration_map),
               number_of_agents=20,
               stochastic_data=stochastic_data,  # Malfunction data generator
-              obs_builder_object=GlobalObsForRailEnv(),
+              obs_builder_object=TreeObservation,
               remove_agents_at_target=True
               )
 
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 7b7aacaf..1edf943b 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -174,7 +174,6 @@ class TreeObsForRailEnv(ObservationBuilder):
             else:
                 self.location_has_agent_direction[(agent.position, agent.direction)] = 1
 
-
         self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents}
         self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in
                                                self.env.agents}
@@ -271,9 +270,9 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                     malfunctioning_agent = self.location_has_agent_malfunction[position]
 
-                if (agent.position, agent.direction) in self.location_has_agent_direction:
+                if (position, direction) in self.location_has_agent_direction:
                     # Cummulate the number of agents on branch with same direction
-                    other_agent_same_direction += self.location_has_agent_direction[(agent.position, agent.direction)]
+                    other_agent_same_direction += self.location_has_agent_direction[(position, direction)]
 
                     # Check fractional speed of agents
                     current_fractional_speed = self.location_has_agent_speed[position]
@@ -284,13 +283,11 @@ class TreeObsForRailEnv(ObservationBuilder):
                     # TODO: This does not work as expected yet
                     other_agent_opposite_direction += self.location_has_agent[position] - \
                                                       self.location_has_agent_direction[
-                                                          (agent.position, agent.direction)]
+                                                          (position, direction)]
+
                 else:
                     # If no agent in the same direction was found all agents in that position are other direction
                     other_agent_opposite_direction += self.location_has_agent[position]
-                    print("went in here")
-                if other_agent_opposite_direction > 0:
-                    print("Other agents", other_agent_opposite_direction)
 
                 # Check number of possible transitions for agent and total number of transitions in cell (type)
             cell_transitions = self.env.rail.get_transitions(*position, direction)
-- 
GitLab