From 314dda88599461046448c778c61cd1f4154b87f8 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Fri, 5 Jul 2019 13:46:53 +0200
Subject: [PATCH] #67 multiple agents

---
 flatland/envs/observations.py           |  2 +-
 tests/test_flatland_envs_predictions.py | 11 +++++------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py
index 230b0b7..ad5be7c 100644
--- a/flatland/envs/observations.py
+++ b/flatland/envs/observations.py
@@ -230,7 +230,7 @@ class TreeObsForRailEnv(ObservationBuilder):
                 (possible future use: number of other agents in the same direction in this branch)
             0 = no agent present same direction
 
-        #9: agent in the opposite drection
+        #9: agent in the opposite direction
             n = number of agents present other direction than myself (so conflict)
                 (possible future use: number of other agents in other direction in this branch, ie. number of conflicts)
             0 = no agent present other direction than myself
diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py
index a334bd1..2926d21 100644
--- a/tests/test_flatland_envs_predictions.py
+++ b/tests/test_flatland_envs_predictions.py
@@ -265,18 +265,17 @@ def test_shortest_path_predictor_conflicts(rendering=False):
     pp.pprint(tree_0)
 
     # check the expectations
-    # TODO check with Erik, this should be symmetric, should it not?
-    expected_conflicts_0 = [('F', 'R'), ('F', 'L')]
-    expected_conflicts_1 = [('F'), ('F', 'L')]
+    expected_conflicts_0 = [('F','R')]
+    expected_conflicts_1 = [('F','L')]
     _check_expected_conflicts(expected_conflicts_0, obs_builder, tree_0, "agent[0]: ")
     _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ")
 
 
 def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''):
-    assert (tree_0[''][7] > 0) == (() in expected_conflicts), "{}[]".format(prompt)
+    assert (tree_0[''][8] > 0) == (() in expected_conflicts), "{}[]".format(prompt)
     for a_1 in obs_builder.tree_explorted_actions_char:
-        conflict = tree_0[a_1][''][7]
+        conflict = tree_0[a_1][''][8]
         assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1)
         for a_2 in obs_builder.tree_explorted_actions_char:
-            conflict = tree_0[a_1][a_2][''][7]
+            conflict = tree_0[a_1][a_2][''][8]
             assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)
-- 
GitLab