diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 00a3d6252ce6d93754c3bfd9c629ad78d45f348d..48a4084ea5395f4406a7884e03e3d84a110fc37e 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -11,6 +11,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator +from flatland.utils.ordered_set import OrderedSet from flatland.utils.rendertools import RenderTool random.seed(100) @@ -82,7 +83,7 @@ class ObservePredictions(TreeObsForRailEnv): # We are going to track what cells where considered while building the obervation and make them accesible # For rendering - visited = set() + visited = OrderedSet() for _idx in range(10): # Check if any of the other prediction overlap with agents own predictions x_coord = self.predictions[handle][_idx][1] diff --git a/flatland/core/grid/rail_env_grid.py b/flatland/core/grid/rail_env_grid.py index 680e945316ab3a4876bd36fa8e6b001ea346cd26..db09fbd57b18d203c956742d4711973c986ca452 100644 --- a/flatland/core/grid/rail_env_grid.py +++ b/flatland/core/grid/rail_env_grid.py @@ -1,4 +1,5 @@ from flatland.core.grid.grid4 import Grid4Transitions +from flatland.utils.ordered_set import OrderedSet class RailEnvTransitions(Grid4Transitions): @@ -44,7 +45,7 @@ class RailEnvTransitions(Grid4Transitions): ) # create this to make validation faster - self.transitions_all = set() + self.transitions_all = OrderedSet() for index, trans in enumerate(self.transitions): self.transitions_all.add(trans) if index in (2, 4, 6, 7, 8, 9, 10):