From 20d93455bcd09aef130d5deaa3900940c4b2b84b Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 13 Jun 2019 16:32:50 +0200
Subject: [PATCH] Added potential conflict to tree observation as an 8th
 feature. ATTENTION this means that the observation space dimension has
 increased! will still check that this is handled correctly everywhere but
 looks good.

---
 examples/training_example.py  | 17 ++++++++++++-----
 flatland/envs/env_utils.py    |  8 +++++---
 flatland/envs/observations.py | 35 ++++++++++++++++++++---------------
 3 files changed, 37 insertions(+), 23 deletions(-)

diff --git a/examples/training_example.py b/examples/training_example.py
index 1342107..ad222cb 100644
--- a/examples/training_example.py
+++ b/examples/training_example.py
@@ -1,6 +1,8 @@
 import numpy as np
 
 from flatland.envs.generators import complex_rail_generator
+from flatland.envs.observations import TreeObsForRailEnv
+from flatland.envs.predictions import DummyPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 
 np.random.seed(1)
@@ -8,10 +10,13 @@ np.random.seed(1)
 # Use the complex_rail_generator to generate feasible network configurations with corresponding tasks
 # Training on simple small tasks is the best way to get familiar with the environment
 #
-env = RailEnv(width=15,
-              height=15,
-              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=10, min_dist=10, max_dist=99999, seed=0),
-              number_of_agents=5)
+
+TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv())
+env = RailEnv(width=20,
+              height=20,
+              rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0),
+              obs_builder_object=TreeObservation,
+              number_of_agents=2)
 
 
 # Import your own Agent or use RLlib to train agents on Flatland
@@ -56,6 +61,7 @@ n_trials = 5
 # Empty dictionary for all agent action
 action_dict = dict()
 print("Starting Training...")
+
 for trials in range(1, n_trials + 1):
 
     # Reset environment and get initial observations for all agents
@@ -74,7 +80,8 @@ for trials in range(1, n_trials + 1):
         # Environment step which returns the observations for all agents, their corresponding
         # reward and whether their are done
         next_obs, all_rewards, done, _ = env.step(action_dict)
-
+        TreeObservation.util_print_obs_subtree(next_obs[0], num_features_per_node=8)
+        print(len(next_obs[0]))
         # Update replay buffer and train agent
         for a in range(env.get_num_agents()):
             agent.step((obs[a], action_dict[a], all_rewards[a], next_obs[a], done[a]))
diff --git a/flatland/envs/env_utils.py b/flatland/envs/env_utils.py
index 35e6560..c9595b7 100644
--- a/flatland/envs/env_utils.py
+++ b/flatland/envs/env_utils.py
@@ -89,10 +89,12 @@ def coordinate_to_position(width, coords):
     :param coords:
     :return:
     """
-    position = []
+    position = np.empty(len(coords), dtype=int)
+    idx = 0
     for t in coords:
-        position.append((t[1] * width + t[0]))
-    return np.asarray(position).flatten()
+        position[idx] = int(t[1] * width + t[0])
+        idx += 1
+    return position
 
 
 class AStarNode():
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 7bc2f99..e0f0c7f 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -27,7 +27,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         for i in range(self.max_depth + 1):
             size += pow4
             pow4 *= 4
-        self.observation_dim = 7
+        self.observation_dim = 8
         self.observation_space = [size * self.observation_dim]
         self.location_has_agent = {}
         self.location_has_agent_direction = {}
@@ -187,10 +187,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                     dir_list.append(self.predictions[a][t][3])
                 self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)})
                 self.predicted_dir.update({t: dir_list})
-
-            pred_pos = np.concatenate([[x[:, 1:3]] for x in list(self.predictions.values())], axis=0)
-            pred_pos = list(map(list, zip(*pred_pos)))
-
+            self.max_prediction_depth = len(self.predicted_pos)
         observations = {}
         for h in handles:
             observations[h] = self.get(h)
@@ -256,7 +253,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         possible_transitions = self.env.rail.get_transitions((*agent.position, agent.direction))
         num_transitions = np.count_nonzero(possible_transitions)
         # Root node - current position
-        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0]
+        observation = [0, 0, 0, 0, self.distance_map[(handle, *agent.position, agent.direction)], 0, 0, 0]
 
         root_observation = observation[:]
         visited = set()
@@ -309,7 +306,7 @@ class TreeObsForRailEnv(ObservationBuilder):
         other_target_encountered = np.inf
         other_agent_same_direction = 0
         other_agent_opposite_direction = 0
-
+        potential_conflict = 0
         num_steps = 1
         while exploring:
             # #############################
@@ -329,6 +326,10 @@ class TreeObsForRailEnv(ObservationBuilder):
                     other_agent_opposite_direction += 1
 
             # Register possible conflict
+            if self.predictor and num_steps < self.max_prediction_depth:
+                if coordinate_to_position(self.env.width, [position]) in np.delete(self.predicted_pos[num_steps],
+                                                                                   handle):
+                    potential_conflict = 1
 
             if position in self.location_has_target:
                 if num_steps < other_target_encountered:
@@ -430,7 +431,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                            root_observation[3] + num_steps,
                            0,
                            other_agent_same_direction,
-                           other_agent_opposite_direction
+                           other_agent_opposite_direction,
+                           potential_conflict
                            ]
 
         elif last_isTerminal:
@@ -440,7 +442,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                            np.inf,
                            np.inf,
                            other_agent_same_direction,
-                           other_agent_opposite_direction
+                           other_agent_opposite_direction,
+                           potential_conflict
                            ]
         else:
             observation = [0,
@@ -449,7 +452,8 @@ class TreeObsForRailEnv(ObservationBuilder):
                            root_observation[3] + num_steps,
                            self.distance_map[handle, position[0], position[1], direction],
                            other_agent_same_direction,
-                           other_agent_opposite_direction
+                           other_agent_opposite_direction,
+                           potential_conflict
                            ]
         # #############################
         # #############################
@@ -493,7 +497,7 @@ class TreeObsForRailEnv(ObservationBuilder):
 
         return observation, visited
 
-    def util_print_obs_subtree(self, tree, num_features_per_node=5, prompt='', current_depth=0):
+    def util_print_obs_subtree(self, tree, num_features_per_node=8, prompt='', current_depth=0):
         """
         Utility function to pretty-print tree observations returned by this object.
         """
@@ -520,7 +524,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                                         prompt=prompt_[children],
                                         current_depth=current_depth + 1)
 
-    def split_tree(self, tree, num_features_per_node=7, current_depth=0):
+    def split_tree(self, tree, num_features_per_node=8, current_depth=0):
         """
 
         :param tree:
@@ -541,9 +545,10 @@ class TreeObsForRailEnv(ObservationBuilder):
             depth += 1
             pow4 *= 4
         child_size = (len(tree) - num_features_per_node) // 4
-        tree_data = tree[0:num_features_per_node - 3].tolist()
-        distance_data = [tree[num_features_per_node - 3]]
-        agent_data = tree[-2:].tolist()
+        tree_data = tree[0:4].tolist()
+        distance_data = [tree[4]]
+        agent_data = tree[-3:].tolist()
+
         for children in range(4):
             child_tree = tree[(num_features_per_node + children * child_size):
                               (num_features_per_node + (children + 1) * child_size)]
-- 
GitLab