diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 230b0b7dd450d222d282d8bf507814815b923b36..ad5be7c709b0f090a3ff1d0705f73b49c283e848 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -230,7 +230,7 @@ class TreeObsForRailEnv(ObservationBuilder): (possible future use: number of other agents in the same direction in this branch) 0 = no agent present same direction - #9: agent in the opposite drection + #9: agent in the opposite direction n = number of agents present other direction than myself (so conflict) (possible future use: number of other agents in other direction in this branch, ie. number of conflicts) 0 = no agent present other direction than myself diff --git a/tests/test_flatland_envs_predictions.py b/tests/test_flatland_envs_predictions.py index a334bd1efb4734a94ec8e277eb758ea63890ecaf..2926d21d20710c3897cf046550f0977e86b14aa3 100644 --- a/tests/test_flatland_envs_predictions.py +++ b/tests/test_flatland_envs_predictions.py @@ -265,18 +265,17 @@ def test_shortest_path_predictor_conflicts(rendering=False): pp.pprint(tree_0) # check the expectations - # TODO check with Erik, this should be symmetric, should it not? - expected_conflicts_0 = [('F', 'R'), ('F', 'L')] - expected_conflicts_1 = [('F'), ('F', 'L')] + expected_conflicts_0 = [('F','R')] + expected_conflicts_1 = [('F','L')] _check_expected_conflicts(expected_conflicts_0, obs_builder, tree_0, "agent[0]: ") _check_expected_conflicts(expected_conflicts_1, obs_builder, tree_1, "agent[1]: ") def _check_expected_conflicts(expected_conflicts, obs_builder, tree_0, prompt=''): - assert (tree_0[''][7] > 0) == (() in expected_conflicts), "{}[]".format(prompt) + assert (tree_0[''][8] > 0) == (() in expected_conflicts), "{}[]".format(prompt) for a_1 in obs_builder.tree_explorted_actions_char: - conflict = tree_0[a_1][''][7] + conflict = tree_0[a_1][''][8] assert (conflict > 0) == ((a_1) in expected_conflicts), "{}[{}]".format(prompt, a_1) for a_2 in obs_builder.tree_explorted_actions_char: - conflict = tree_0[a_1][a_2][''][7] + conflict = tree_0[a_1][a_2][''][8] assert (conflict > 0) == ((a_1, a_2) in expected_conflicts), "{}[{}][{}]".format(prompt, a_1, a_2)