From c7d695c4701d0230335cf0d1e2ff31bf62dbf2f2 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Fri, 5 Jul 2019 13:13:43 +0200
Subject: [PATCH] unit test for conflicts of multiple agents

---
 examples/simple_example_3.py             |   2 +-
 flatland/core/grid/grid4.py              |   7 ++
 flatland/envs/observations.py            | 122 +++++++-------------
 flatland/envs/rail_env.py                |  12 +-
 tests/simple_rail.py                     |  48 ++++++++
 tests/test_flatland_envs_observations.py |  49 +-------
 tests/test_flatland_envs_predictions.py  | 138 +++++++++++++----------
 7 files changed, 188 insertions(+), 190 deletions(-)
 create mode 100644 tests/simple_rail.py

diff --git a/examples/simple_example_3.py b/examples/simple_example_3.py
index 9055dd4c..853d5f5e 100644
--- a/examples/simple_example_3.py
+++ b/examples/simple_example_3.py
@@ -19,7 +19,7 @@ env = RailEnv(width=7,
 # Print the observation vector for agent 0
 obs, all_rewards, done, _ = env.step({0: 0})
 for i in range(env.get_num_agents()):
-    env.obs_builder.util_print_obs_subtree(tree=obs[i], num_features_per_node=7)
+    env.obs_builder.util_print_obs_subtree(tree=obs[i])
 
 env_renderer = RenderTool(env)
 env_renderer.renderEnv(show=True, frames=True)
diff --git a/flatland/core/grid/grid4.py b/flatland/core/grid/grid4.py
index 5c09f0ac..714123ed 100644
--- a/flatland/core/grid/grid4.py
+++ b/flatland/core/grid/grid4.py
@@ -11,6 +11,13 @@ class Grid4TransitionsEnum(IntEnum):
     SOUTH = 2
     WEST = 3
 
+    @staticmethod
+    def to_char(int: int):
+        return {0: 'N',
+                1: 'E',
+                2: 'S',
+                3: 'W'}[int]
+
 
 class Grid4Transitions(Transitions):
     """
diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 8ed455ca..e9833b0d 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -1,6 +1,7 @@
 """
 Collection of environment-specific ObservationBuilder.
 """
+import pprint
 from collections import deque
 
 import numpy as np
@@ -34,6 +35,8 @@ class TreeObsForRailEnv(ObservationBuilder):
         self.location_has_agent_direction = {}
         self.predictor = predictor
         self.agents_previous_reset = None
+        self.tree_explored_actions = [1, 2, 3, 0]
+        self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
 
     def reset(self):
         agents = self.env.agents
@@ -126,19 +129,6 @@ class TreeObsForRailEnv(ObservationBuilder):
 
                 desired_movement_from_new_cell = (neigh_direction + 2) % 4
 
-                """
-                # Is the next cell a dead-end?
-                isNextCellDeadEnd = False
-                nbits = 0
-                tmp = self.env.rail.get_transitions((new_cell[0], new_cell[1]))
-                while tmp > 0:
-                    nbits += (tmp & 1)
-                    tmp = tmp >> 1
-                if nbits == 1:
-                    # Dead-end!
-                    isNextCellDeadEnd = True
-                """
-
                 # Check all possible transitions in new_cell
                 for agent_orientation in range(4):
                     # Is a transition along movement `desired_movement_from_new_cell' to the current cell possible?
@@ -213,7 +203,7 @@ class TreeObsForRailEnv(ObservationBuilder):
             [... from 'right] +
             [... from 'back']
 
-        Finally, each node information is composed of 5 floating point values:
+        Finally, each node information is composed of 8 floating point values:
 
         #1: if own target lies on the explored branch the current distance from the agent in number of cells is stored.
 
@@ -268,7 +258,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         # Start from the current orientation, and see which transitions are available;
         # organize them as [left, forward, right, back], relative to the current orientation
         # If only one transition is possible, the tree is oriented with this transition as the forward branch.
-        # TODO: Test if this works as desired!
         orientation = agent.direction
 
         if num_transitions == 1:
@@ -282,15 +271,20 @@ class TreeObsForRailEnv(ObservationBuilder):
                 observation = observation + branch_observation
                 visited = visited.union(branch_visited)
             else:
-                num_cells_to_fill_in = 0
-                pow4 = 1
-                for i in range(self.max_depth):
-                    num_cells_to_fill_in += pow4
-                    pow4 *= 4
-                observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
+                # add cells filled with infinity if no transition is possible
+                observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth)
         self.env.dev_obs_dict[handle] = visited
         return observation
 
+    def _num_cells_to_fill_in(self, remaining_depth):
+        """Computes the length of observation vector: sum_{i=0,depth-1} 2^i * observation_dim."""
+        num_observations = 0
+        pow4 = 1
+        for i in range(remaining_depth):
+            num_observations += pow4
+            pow4 *= 4
+        return num_observations * self.observation_dim
+
     def _explore_branch(self, handle, position, direction, root_observation, tot_dist, depth):
         """
         Utility function to compute tree-based observations.
@@ -334,7 +328,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                     # Cummulate the number of agents on branch with other direction
                     other_agent_opposite_direction += 1
 
-            # Register possible conflict
+            # Register possible future conflict
             if self.predictor and num_steps < self.max_prediction_depth:
                 int_position = coordinate_to_position(self.env.width, [position])
                 if tot_dist < self.max_prediction_depth:
@@ -422,42 +416,6 @@ class TreeObsForRailEnv(ObservationBuilder):
         # #############################
         # #############################
         # Modify here to append new / different features for each visited cell!
-        """
-        other_agent_same_direction = \
-            1 if position in self.location_has_agent and self.location_has_agent_direction[position] == direction else 0
-        other_agent_opposite_direction = \
-            1 if position in self.location_has_agent and self.location_has_agent_direction[position] != direction else 0
-
-        if last_isTarget:
-            observation = [0,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           root_observation[3] + num_steps,
-                           0,
-                           other_agent_same_direction,
-                           other_agent_opposite_direction
-                           ]
-
-        elif last_isTerminal:
-            observation = [0,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           np.inf,
-                           np.inf,
-                           other_agent_same_direction,
-                           other_agent_opposite_direction
-                           ]
-        else:
-            observation = [0,
-                           other_target_encountered,
-                           other_agent_encountered,
-                           root_observation[3] + num_steps,
-                           self.distance_map[handle, position[0], position[1], direction],
-                           other_agent_same_direction,
-                           other_agent_opposite_direction
-                           ]
-        """
-
         if last_isTarget:
             observation = [own_target_encountered,
                            other_target_encountered,
@@ -522,41 +480,47 @@ class TreeObsForRailEnv(ObservationBuilder):
                 if len(branch_visited) != 0:
                     visited = visited.union(branch_visited)
             else:
-                num_cells_to_fill_in = 0
-                pow4 = 1
-                for i in range(self.max_depth - depth):
-                    num_cells_to_fill_in += pow4
-                    pow4 *= 4
-                observation = observation + ([-np.inf] * self.observation_dim) * num_cells_to_fill_in
+                # no exploring possible, add just cells with infinity
+                observation = observation + [-np.inf] * self._num_cells_to_fill_in(self.max_depth - depth)
 
         return observation, visited
 
-    def util_print_obs_subtree(self, tree, num_features_per_node=8, prompt='', current_depth=0):
+    def util_print_obs_subtree(self, tree):
         """
         Utility function to pretty-print tree observations returned by this object.
         """
-        if len(tree) < num_features_per_node:
+        pp = pprint.PrettyPrinter(indent=4)
+        pp.pprint(self.unfold_observation_tree(tree))
+
+    def unfold_observation_tree(self, tree, current_depth=0, actions_for_display=True):
+        """
+        Utility function to pretty-print tree observations returned by this object.
+        """
+        if len(tree) < self.observation_dim:
             return
 
         depth = 0
-        tmp = len(tree) / num_features_per_node - 1
+        tmp = len(tree) / self.observation_dim - 1
         pow4 = 4
         while tmp > 0:
             tmp -= pow4
             depth += 1
             pow4 *= 4
 
-        prompt_ = ['L:', 'F:', 'R:', 'B:']
-
-        print("  " * current_depth + prompt, tree[0:num_features_per_node])
-        child_size = (len(tree) - num_features_per_node) // 4
-        for children in range(4):
-            child_tree = tree[(num_features_per_node + children * child_size):
-                              (num_features_per_node + (children + 1) * child_size)]
-            self.util_print_obs_subtree(child_tree,
-                                        num_features_per_node,
-                                        prompt=prompt_[children],
-                                        current_depth=current_depth + 1)
+        unfolded = {}
+        unfolded[''] = tree[0:self.observation_dim]
+        child_size = (len(tree) - self.observation_dim) // 4
+        for child in range(4):
+            child_tree = tree[(self.observation_dim + child * child_size):
+                              (self.observation_dim + (child + 1) * child_size)]
+            observation_tree = self.unfold_observation_tree(child_tree, current_depth=current_depth + 1)
+            if observation_tree is not None:
+                if actions_for_display:
+                    label = self.tree_explorted_actions_char[child]
+                else:
+                    label = self.tree_explored_actions[child]
+                unfolded[label] = observation_tree
+        return unfolded
 
     def _set_env(self, env):
         self.env = env
@@ -725,8 +689,6 @@ class LocalObsForRailEnv(ObservationBuilder):
                 bitlist = [int(digit) for digit in bin(self.env.rail.get_transitions((i, j)))[2:]]
                 bitlist = [0] * (16 - len(bitlist)) + bitlist
                 self.rail_obs[i + self.view_radius, j + self.view_radius] = np.array(bitlist)
-                # self.rail_obs[i+self.view_radius, j+self.view_radius] = np.array(
-                #     list(f'{self.env.rail.get_transitions((i, j)):016b}')).astype(int)
 
     def get(self, handle):
         agents = self.env.agents
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 26bccf6a..b4a56a8d 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -19,12 +19,22 @@ from flatland.envs.observations import TreeObsForRailEnv
 
 
 class RailEnvActions(IntEnum):
-    DO_NOTHING = 0
+    DO_NOTHING = 0  # implies change of direction in a dead-end!
     MOVE_LEFT = 1
     MOVE_FORWARD = 2
     MOVE_RIGHT = 3
     STOP_MOVING = 4
 
+    @staticmethod
+    def to_char(a: int):
+        return {
+            0: 'B',
+            1: 'L',
+            2: 'F',
+            3: 'R',
+            4: 'S',
+        }[a]
+
 
 class RailEnv(Environment):
     """
diff --git a/tests/simple_rail.py b/tests/simple_rail.py
new file mode 100644
index 00000000..894864ac
--- /dev/null
+++ b/tests/simple_rail.py
@@ -0,0 +1,48 @@
+import numpy as np
+
+from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.core.transition_map import GridTransitionMap
+
+
+def make_simple_rail():
+    # We instantiate a very simple rail network on a 7x10 grid:
+    #        |
+    #        |
+    #        |
+    # _ _ _ /_\ _ _  _  _ _ _
+    #               \ /
+    #                |
+    #                |
+    #                |
+    cells = [int('0000000000000000', 2),  # empty cell - Case 0
+             int('1000000000100000', 2),  # Case 1 - straight
+             int('1001001000100000', 2),  # Case 2 - simple switch
+             int('1000010000100001', 2),  # Case 3 - diamond drossing
+             int('1001011000100001', 2),  # Case 4 - single slip switch
+             int('1100110000110011', 2),  # Case 5 - double slip switch
+             int('0101001000000010', 2),  # Case 6 - symmetrical switch
+             int('0010000000000000', 2)]  # Case 7 - dead end
+    transitions = Grid4Transitions([])
+    empty = cells[0]
+    dead_end_from_south = cells[7]
+    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
+    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
+    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
+    vertical_straight = cells[1]
+    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
+    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
+    double_switch_north_horizontal_straight = transitions.rotate_transition(
+        double_switch_south_horizontal_straight, 180)
+    rail_map = np.array(
+        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
+        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
+        [[dead_end_from_east] + [horizontal_straight] * 2 +
+         [double_switch_north_horizontal_straight] +
+         [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
+         [horizontal_straight] * 2 + [dead_end_from_west]] +
+        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
+        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
+    rail = GridTransitionMap(width=rail_map.shape[1],
+                             height=rail_map.shape[0], transitions=transitions)
+    rail.grid = rail_map
+    return rail, rail_map
diff --git a/tests/test_flatland_envs_observations.py b/tests/test_flatland_envs_observations.py
index ce224736..5ee5b4a6 100644
--- a/tests/test_flatland_envs_observations.py
+++ b/tests/test_flatland_envs_observations.py
@@ -3,62 +3,17 @@
 
 import numpy as np
 
-from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
 from flatland.envs.generators import rail_from_GridTransitionMap_generator
 from flatland.envs.observations import GlobalObsForRailEnv
 from flatland.envs.rail_env import RailEnv
+from simple_rail import make_simple_rail
 
 """Tests for `flatland` package."""
 
 
 def test_global_obs():
-    # We instantiate a very simple rail network on a 7x10 grid:
-    #        |
-    #        |
-    #        |
-    # _ _ _ /_\ _ _  _  _ _ _
-    #               \ /
-    #                |
-    #                |
-    #                |
+    rail, rail_map = make_simple_rail()
 
-    cells = [int('0000000000000000', 2),  # empty cell - Case 0
-             int('1000000000100000', 2),  # Case 1 - straight
-             int('1001001000100000', 2),  # Case 2 - simple switch
-             int('1000010000100001', 2),  # Case 3 - diamond drossing
-             int('1001011000100001', 2),  # Case 4 - single slip switch
-             int('1100110000110011', 2),  # Case 5 - double slip switch
-             int('0101001000000010', 2),  # Case 6 - symmetrical switch
-             int('0010000000000000', 2)]  # Case 7 - dead end
-
-    transitions = Grid4Transitions([])
-    empty = cells[0]
-
-    dead_end_from_south = cells[7]
-    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
-    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
-    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
-
-    vertical_straight = cells[1]
-    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
-
-    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
-    double_switch_north_horizontal_straight = transitions.rotate_transition(
-        double_switch_south_horizontal_straight, 180)
-
-    rail_map = np.array(
-        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
-        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
-        [[dead_end_from_east] + [horizontal_straight] * 2 +
-         [double_switch_north_horizontal_straight] +
-         [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
-         [horizontal_straight] * 2 + [dead_end_from_west]] +
-        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
-        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
-
-    rail = GridTransitionMap(width=rail_map.shape[1],
-                             height=rail_map.shape[0], transitions=transitions)
-    rail.grid = rail_map
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_GridTransitionMap_generator(rail),
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index 16850672..1bf564ed 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -1,64 +1,21 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
+import pprint
 
 import numpy as np
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
-from flatland.core.transition_map import GridTransitionMap, Grid4Transitions
 from flatland.envs.generators import rail_from_GridTransitionMap_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import DummyPredictorForRailEnv, ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_env import RailEnvActions
 from flatland.utils.rendertools import RenderTool
+from simple_rail import make_simple_rail
 
 """Test predictions for `flatland` package."""
 
 
-def make_simple_rail():
-    # We instantiate a very simple rail network on a 7x10 grid:
-    #        |
-    #        |
-    #        |
-    # _ _ _ /_\ _ _  _  _ _ _
-    #               \ /
-    #                |
-    #                |
-    #                |
-    cells = [int('0000000000000000', 2),  # empty cell - Case 0
-             int('1000000000100000', 2),  # Case 1 - straight
-             int('1001001000100000', 2),  # Case 2 - simple switch
-             int('1000010000100001', 2),  # Case 3 - diamond drossing
-             int('1001011000100001', 2),  # Case 4 - single slip switch
-             int('1100110000110011', 2),  # Case 5 - double slip switch
-             int('0101001000000010', 2),  # Case 6 - symmetrical switch
-             int('0010000000000000', 2)]  # Case 7 - dead end
-    transitions = Grid4Transitions([])
-    empty = cells[0]
-    dead_end_from_south = cells[7]
-    dead_end_from_west = transitions.rotate_transition(dead_end_from_south, 90)
-    dead_end_from_north = transitions.rotate_transition(dead_end_from_south, 180)
-    dead_end_from_east = transitions.rotate_transition(dead_end_from_south, 270)
-    vertical_straight = cells[1]
-    horizontal_straight = transitions.rotate_transition(vertical_straight, 90)
-    double_switch_south_horizontal_straight = horizontal_straight + cells[6]
-    double_switch_north_horizontal_straight = transitions.rotate_transition(
-        double_switch_south_horizontal_straight, 180)
-    rail_map = np.array(
-        [[empty] * 3 + [dead_end_from_south] + [empty] * 6] +
-        [[empty] * 3 + [vertical_straight] + [empty] * 6] * 2 +
-        [[dead_end_from_east] + [horizontal_straight] * 2 +
-         [double_switch_north_horizontal_straight] +
-         [horizontal_straight] * 2 + [double_switch_south_horizontal_straight] +
-         [horizontal_straight] * 2 + [dead_end_from_west]] +
-        [[empty] * 6 + [vertical_straight] + [empty] * 3] * 2 +
-        [[empty] * 6 + [dead_end_from_north] + [empty] * 3], dtype=np.uint16)
-    rail = GridTransitionMap(width=rail_map.shape[1],
-                             height=rail_map.shape[0], transitions=transitions)
-    rail.grid = rail_map
-    return rail, rail_map
-
-
 def test_dummy_predictor(rendering=False):
     rail, rail_map = make_simple_rail()
 
@@ -68,12 +25,16 @@ def test_dummy_predictor(rendering=False):
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)),
                   )
+    # reset to initialize agents_static
     env.reset()
 
     # set initial position and direction for testing...
-    env.agents[0].position = (5, 6)
-    env.agents[0].direction = 0
-    env.agents[0].target = (3, 0)
+    env.agents_static[0].position = (5, 6)
+    env.agents_static[0].direction = 0
+    env.agents_static[0].target = (3, 0)
+
+    # reset to set agents from agents_static
+    env.reset(False, False)
 
     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
@@ -154,41 +115,39 @@ def test_shortest_path_predictor(rendering=False):
                   number_of_agents=1,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
+
+    # reset to initialize agents_static
     env.reset()
 
-    agent = env.agents[0]
+    # set the initial position
+    agent = env.agents_static[0]
     agent.position = (5, 6)  # south dead-end
     agent.direction = 0  # north
     agent.target = (3, 9)  # east dead-end
-
     agent.moving = True
 
+    # reset to set agents from agents_static
+    env.reset(False, False)
+
     if rendering:
         renderer = RenderTool(env, gl="PILSVG")
         renderer.renderEnv(show=True, show_observations=False)
         input("Continue?")
 
-    agent = env.agents[0]
-    assert agent.position == (5, 6)
-    assert agent.direction == 0
-    assert agent.target == (3, 9)
-    assert agent.moving
-
-    env.obs_builder._compute_distance_map()
-
+    # compute the observations and predictions
     distance_map = env.obs_builder.distance_map
-    assert distance_map[agent.handle, agent.position[0], agent.position[
+    assert distance_map[0, agent.position[0], agent.position[
         1], agent.direction] == 5.0, "found {} instead of {}".format(
         distance_map[agent.handle, agent.position[0], agent.position[1], agent.direction], 5.0)
 
-    # test assertions
-    env.obs_builder.get_many()
+    # extract the data
     predictions = env.obs_builder.predictions
     positions = np.array(list(map(lambda prediction: [*prediction[1:3]], predictions[0])))
     directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0])))
     time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0])))
     actions = np.array(list(map(lambda prediction: [prediction[4]], predictions[0])))
 
+    # test if data meets expectations
     expected_positions = [
         [5, 6],
         [4, 6],
@@ -292,3 +251,60 @@ def test_shortest_path_predictor(rendering=False):
         "time_offsets {}, expected {}".format(time_offsets, expected_time_offsets)
     assert np.array_equal(actions, expected_actions), \
         "actions {}, expected {}".format(actions, expected_actions)
+
+
+def test_shortest_path_predictor_conflicts(rendering=False):
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  number_of_agents=2,
+                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
+                  )
+    # initialize agents_static
+    env.reset()
+
+    # set the initial position
+    agent = env.agents_static[0]
+    agent.position = (5, 6)  # south dead-end
+    agent.direction = 0  # north
+    agent.target = (3, 9)  # east dead-end
+    agent.moving = True
+
+    agent = env.agents_static[1]
+    agent.position = (3, 8)  # east dead-end
+    agent.direction = 3  # west
+    agent.target = (6, 6)  # south dead-end
+    agent.moving = True
+
+    # reset to set agents from agents_static
+    observations = env.reset(False, False)
+
+    if rendering:
+        renderer = RenderTool(env, gl="PILSVG")
+        renderer.renderEnv(show=True, show_observations=False)
+        input("Continue?")
+
+    # get the trees to test
+    obs_builder: TreeObsForRailEnv = env.obs_builder
+    pp = pprint.PrettyPrinter(indent=4)
+    tree_0 = obs_builder.unfold_observation_tree(observations[0])
+    tree_1 = obs_builder.unfold_observation_tree(observations[1])
+    pp.pprint(tree_0)
+
+    # check the expectations
+    # TODO check with Erik, this should be symmetric, should it not?
+    expected_conflicts_0 = [('F', 'R'), ('F', 'L')]
+    expected_conflicts_1 = [('F'), ('F', 'L')]
+    _check_expected_conflicts(expected_conflicts_0, obs_builder, tree_0, "agent[0]: ")
+    _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ")
+
+
+def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''):
+    assert (tree_0[''][7] > 0) == (() in expected_conflicts), "{}[]".format(prompt)
+    for a_1 in obs_builder.tree_explorted_actions_char:
+        conflict = tree_0[a_1][''][7]
+        assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
+        for a_2 in obs_builder.tree_explorted_actions_char:
+            conflict = tree_0[a_1][a_2][''][7]
+            assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
-- 
GitLab