From 96444273bc2f3338c47b180883d47b94e6db8362 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 24 Sep 2019 11:51:29 +0200 Subject: [PATCH] bugfix load flatland --- env_data/tests/test_001.pkl | Bin 0 -> 461 bytes flatland/envs/distance_map.py | 2 ++ flatland/envs/observations.py | 3 ++- flatland/envs/rail_env_utils.py | 17 +++++++++++++++++ flatland/envs/rail_generators.py | 10 +++++++--- flatland/envs/schedule_generators.py | 10 +++++++--- tests/test_flatland_core_transition_map.py | 1 - tests/test_flatland_envs_env_utils.py | 5 +++++ tests/test_flatland_envs_rail_env.py | 1 - 9 files changed, 40 insertions(+), 9 deletions(-) create mode 100644 env_data/tests/test_001.pkl create mode 100644 flatland/envs/rail_env_utils.py diff --git a/env_data/tests/test_001.pkl b/env_data/tests/test_001.pkl new file mode 100644 index 0000000000000000000000000000000000000000..823e594e41f26f6febac6a241f3befcec9f81cf7 GIT binary patch literal 461 zcmZo(l3tXVGIbUM9L!=lD}>Bp@jT1Ih>pY<&oVN=HF+R&7#P+jrl;nW6vr2rB$i|* zPntG~k%4g%Gbcmy#)ACf%#zIfy!f=D#AFcTG|Z-@#RaLUDW~l}Km_)d6eZ?C)k0Lp z19c^*=H#SSWR@_ruFg%&Nh{3*o58RFi4$LxSdz-HAuqL}BpykaVSQc^GQXH%XKsF3 hW?p)HQfgX$Q7W?Bp=pztIGHCgF$@_8FM|a51OPozqFw+1 literal 0 HcmV?d00001 diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py index 94019319..c9c6b003 100644 --- a/flatland/envs/distance_map.py +++ b/flatland/envs/distance_map.py @@ -58,6 +58,8 @@ class DistanceMap: self.reset_was_called = True self.agents = agents self.rail = rail + self.env_height = rail.height + self.env_width = rail.width def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap): self.agents_previous_computation = self.agents diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 712d2425..6312cca3 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -167,7 +167,8 @@ class TreeObsForRailEnv(ObservationBuilder): # Root node - current position # Here information about the agent itself is stored - observation = [0, 0, 0, 0, 0, 0, self.env.distance_map.get()[(handle, *agent.position, agent.direction)], 0, 0, + distance_map = self.env.distance_map.get() + observation = [0, 0, 0, 0, 0, 0, distance_map[(handle, *agent.position, agent.direction)], 0, 0, agent.malfunction_data['malfunction'], agent.speed_data['speed']] visited = OrderedSet() diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py new file mode 100644 index 00000000..722f54bc --- /dev/null +++ b/flatland/envs/rail_env_utils.py @@ -0,0 +1,17 @@ +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_file +from flatland.envs.schedule_generators import schedule_from_file + + +def load_flatland_environment_from_file(file_name, load_from_package=None): + environment = RailEnv(width=1, + height=1, + rail_generator=rail_from_file(file_name, load_from_package), + number_of_agents=1, + schedule_generator=schedule_from_file(file_name,load_from_package), + obs_builder_object=TreeObsForRailEnv( + max_depth=2, + predictor=ShortestPathPredictorForRailEnv(max_depth=10))) + return environment diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 60c606f7..0bce90aa 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -204,7 +204,7 @@ def rail_from_manual_specifications_generator(rail_spec): return generator -def rail_from_file(filename) -> RailGenerator: +def rail_from_file(filename, load_from_package=None) -> RailGenerator: """ Utility to load pickle file @@ -221,8 +221,12 @@ def rail_from_file(filename) -> RailGenerator: def generator(width, height, num_agents, num_resets): rail_env_transitions = RailEnvTransitions() - with open(filename, "rb") as file_in: - load_data = file_in.read() + if load_from_package is not None: + from importlib_resources import read_binary + load_data = read_binary(load_from_package, filename) + else: + with open(filename, "rb") as file_in: + load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False) grid = np.array(data[b"grid"]) diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index b3576a2b..7f42feea 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -205,7 +205,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = return generator -def schedule_from_file(filename) -> ScheduleGenerator: +def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: """ Utility to load pickle file @@ -220,8 +220,12 @@ def schedule_from_file(filename) -> ScheduleGenerator: """ def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: - with open(filename, "rb") as file_in: - load_data = file_in.read() + if load_from_package is not None: + from importlib_resources import read_binary + load_data = read_binary(load_from_package, filename) + else: + with open(filename, "rb") as file_in: + load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') # agents are always reset as not moving diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index 4da1da4d..930cc24f 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -168,7 +168,6 @@ def test_get_entry_directions(): south_symmetrical_switch = cells[6] north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) - # Simple turn not in the base transitions ? south_east_turn = int('0100000000000010', 2) south_west_turn = transitions.rotate_transition(south_east_turn, 90) north_east_turn = transitions.rotate_transition(south_east_turn, 270) diff --git a/tests/test_flatland_envs_env_utils.py b/tests/test_flatland_envs_env_utils.py index cf5c8592..d0320ab7 100644 --- a/tests/test_flatland_envs_env_utils.py +++ b/tests/test_flatland_envs_env_utils.py @@ -4,6 +4,7 @@ import pytest from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_direction from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position +from flatland.envs.rail_env_utils import load_flatland_environment_from_file depth_to_test = 5 positions_to_test = [0, 5, 1, 6, 20, 30] @@ -32,3 +33,7 @@ def test_get_direction(): assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH with pytest.raises(Exception, match="Could not determine direction"): get_direction((0, 0), (0, 0)) + + +def test_load(): + load_flatland_environment_from_file('test_001.pkl', 'env_data.tests') diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index d5dc3ac7..0114730a 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -61,7 +61,6 @@ def test_rail_environment_single_agent(): vertical_line = cells[1] south_symmetrical_switch = cells[6] north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) - # Simple turn not in the base transitions ? south_east_turn = int('0100000000000010', 2) south_west_turn = transitions.rotate_transition(south_east_turn, 90) north_east_turn = transitions.rotate_transition(south_east_turn, 270) -- GitLab