From df18b26889675d5bb86286519105db27c0252ed4 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 17 Sep 2019 15:10:18 +0200 Subject: [PATCH] #178 post-rebase fix --- examples/custom_observation_example_03_ObservePredictions.py | 3 ++- flatland/core/grid/rail_env_grid.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/custom_observation_example_03_ObservePredictions.py b/examples/custom_observation_example_03_ObservePredictions.py index 00a3d625..48a4084e 100644 --- a/examples/custom_observation_example_03_ObservePredictions.py +++ b/examples/custom_observation_example_03_ObservePredictions.py @@ -11,6 +11,7 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import complex_rail_generator from flatland.envs.schedule_generators import complex_schedule_generator +from flatland.utils.ordered_set import OrderedSet from flatland.utils.rendertools import RenderTool random.seed(100) @@ -82,7 +83,7 @@ class ObservePredictions(TreeObsForRailEnv): # We are going to track what cells where considered while building the obervation and make them accesible # For rendering - visited = set() + visited = OrderedSet() for _idx in range(10): # Check if any of the other prediction overlap with agents own predictions x_coord = self.predictions[handle][_idx][1] diff --git a/flatland/core/grid/rail_env_grid.py b/flatland/core/grid/rail_env_grid.py index 680e9453..db09fbd5 100644 --- a/flatland/core/grid/rail_env_grid.py +++ b/flatland/core/grid/rail_env_grid.py @@ -1,4 +1,5 @@ from flatland.core.grid.grid4 import Grid4Transitions +from flatland.utils.ordered_set import OrderedSet class RailEnvTransitions(Grid4Transitions): @@ -44,7 +45,7 @@ class RailEnvTransitions(Grid4Transitions): ) # create this to make validation faster - self.transitions_all = set() + self.transitions_all = OrderedSet() for index, trans in enumerate(self.transitions): self.transitions_all.add(trans) if index in (2, 4, 6, 7, 8, 9, 10): -- GitLab