From df18b26889675d5bb86286519105db27c0252ed4 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Tue, 17 Sep 2019 15:10:18 +0200
Subject: [PATCH] #178 post-rebase fix

---
 examples/custom_observation_example_03_ObservePredictions.py | 3 ++-
 flatland/core/grid/rail_env_grid.py                          | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py
index 00a3d625..48a4084e 100644
--- a/examples/custom_observation_example_03_ObservePredictions.py
+++ b/examples/custom_observation_example_03_ObservePredictions.py
@@ -11,6 +11,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from flatland.envs.rail_generators import complex_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator
+from flatland.utils.ordered_set import OrderedSet
 from flatland.utils.rendertools import RenderTool
 
 random.seed(100)
@@ -82,7 +83,7 @@ class ObservePredictions(TreeObsForRailEnv):
         # We are going to track what cells where considered while building the obervation and make them accesible
         # For rendering
 
-        visited = set()
+        visited = OrderedSet()
         for _idx in range(10):
             # Check if any of the other prediction overlap with agents own predictions
             x_coord = self.predictions[handle][_idx][1]
diff --git a/flatland/core/grid/rail_env_grid.py b/flatland/core/grid/rail_env_grid.py
index 680e9453..db09fbd5 100644
--- a/flatland/core/grid/rail_env_grid.py
+++ b/flatland/core/grid/rail_env_grid.py
@@ -1,4 +1,5 @@
 from flatland.core.grid.grid4 import Grid4Transitions
+from flatland.utils.ordered_set import OrderedSet
 
 
 class RailEnvTransitions(Grid4Transitions):
@@ -44,7 +45,7 @@ class RailEnvTransitions(Grid4Transitions):
         )
 
         # create this to make validation faster
-        self.transitions_all = set()
+        self.transitions_all = OrderedSet()
         for index, trans in enumerate(self.transitions):
             self.transitions_all.add(trans)
             if index in (2, 4, 6, 7, 8, 9, 10):
-- 
GitLab