From 5d186dc59797823578c4332f9d930983179ca8e6 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Wed, 2 Oct 2019 20:14:30 -0400
Subject: [PATCH] still buggy with treeeobservation

---
 examples/flatland_2_0_example.py | 10 +++++-----
 flatland/envs/observations.py    | 31 +++++++++++++++++++++++--------
 2 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index 32b9f611..40b41591 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -28,13 +28,13 @@ speed_ration_map = {1.: 0.25,  # Fast passenger train
                     1. / 3.: 0.25,  # Slow commuter train
                     1. / 4.: 0.25}  # Slow freight train
 
-env = RailEnv(width=100,
-              height=20,
-              rail_generator=sparse_rail_generator(num_cities=2,  # Number of cities in map (where train stations are)
+env = RailEnv(width=40,
+              height=40,
+              rail_generator=sparse_rail_generator(num_cities=8,  # Number of cities in map (where train stations are)
                                                    seed=1,  # Random seed
-                                                   grid_mode=True,
+                                                   grid_mode=False,
                                                    max_inter_city_rails=2,
-                                                   max_tracks_in_city=8,
+                                                   max_tracks_in_city=4,
                                                    ),
               schedule_generator=sparse_schedule_generator(speed_ration_map),
               number_of_agents=20,
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index d97c476f..7b7aacaf 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -161,13 +161,20 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         # Update local lookup table for all agents' positions
         self.location_has_agent = dict()
+        self.location_has_agent_direction = dict()
         for agent in self.env.agents:
             if tuple(agent.position) in self.location_has_agent:
                 self.location_has_agent[tuple(agent.position)] = self.location_has_agent[tuple(agent.position)] + 1
             else:
                 self.location_has_agent[tuple(agent.position)] = 1
-        # TODO: Update this to handle number of agents at same location
-        self.location_has_agent_direction = {tuple(agent.position): agent.direction for agent in self.env.agents}
+
+            if (agent.position, agent.direction) in self.location_has_agent_direction:
+                self.location_has_agent_direction[(agent.position, agent.direction)] = \
+                self.location_has_agent_direction[(agent.position, agent.direction)] + 1
+            else:
+                self.location_has_agent_direction[(agent.position, agent.direction)] = 1
+
+
         self.location_has_agent_speed = {tuple(agent.position): agent.speed_data['speed'] for agent in self.env.agents}
         self.location_has_agent_malfunction = {tuple(agent.position): agent.malfunction_data['malfunction'] for agent in
                                                self.env.agents}
@@ -264,20 +271,28 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if self.location_has_agent_malfunction[position] > malfunctioning_agent:
                     malfunctioning_agent = self.location_has_agent_malfunction[position]
 
-                if self.location_has_agent_direction[position] == direction:
+                if (agent.position, agent.direction) in self.location_has_agent_direction:
                     # Cummulate the number of agents on branch with same direction
-                    other_agent_same_direction += 1
+                    other_agent_same_direction += self.location_has_agent_direction[(agent.position, agent.direction)]
 
                     # Check fractional speed of agents
                     current_fractional_speed = self.location_has_agent_speed[position]
                     if current_fractional_speed < min_fractional_speed:
                         min_fractional_speed = current_fractional_speed
 
-                if self.location_has_agent_direction[position] != direction:
-                    # Cummulate the number of agents on branch with other direction
-                    other_agent_opposite_direction += 1
+                    # Other direction agents
+                    # TODO: This does not work as expected yet
+                    other_agent_opposite_direction += self.location_has_agent[position] - \
+                                                      self.location_has_agent_direction[
+                                                          (agent.position, agent.direction)]
+                else:
+                    # If no agent in the same direction was found all agents in that position are other direction
+                    other_agent_opposite_direction += self.location_has_agent[position]
+                    print("went in here")
+                if other_agent_opposite_direction > 0:
+                    print("Other agents", other_agent_opposite_direction)
 
-            # Check number of possible transitions for agent and total number of transitions in cell (type)
+                # Check number of possible transitions for agent and total number of transitions in cell (type)
             cell_transitions = self.env.rail.get_transitions(*position, direction)
             transition_bit = bin(self.env.rail.get_full_transitions(*position))
             total_transitions = transition_bit.count("1")
-- 
GitLab