diff --git a/env_data/tests/test_001.pkl b/env_data/tests/test_001.pkl new file mode 100644 index 0000000000000000000000000000000000000000..823e594e41f26f6febac6a241f3befcec9f81cf7 Binary files /dev/null and b/env_data/tests/test_001.pkl differ diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py index 940193198a18a63f669d6a50a04cbd8f67740a32..c9c6b00375ef4577880e2b8c98c2ff9dc946a7fa 100644 --- a/flatland/envs/distance_map.py +++ b/flatland/envs/distance_map.py @@ -58,6 +58,8 @@ class DistanceMap: self.reset_was_called = True self.agents = agents self.rail = rail + self.env_height = rail.height + self.env_width = rail.width def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap): self.agents_previous_computation = self.agents diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 712d24258e042f0c81eb0afdc33ff880387cac95..6312cca3e51fa0e60d2493587ff5d699fcd355c4 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -167,7 +167,8 @@ class TreeObsForRailEnv(ObservationBuilder): # Root node - current position # Here information about the agent itself is stored - observation = [0, 0, 0, 0, 0, 0, self.env.distance_map.get()[(handle, *agent.position, agent.direction)], 0, 0, + distance_map = self.env.distance_map.get() + observation = [0, 0, 0, 0, 0, 0, distance_map[(handle, *agent.position, agent.direction)], 0, 0, agent.malfunction_data['malfunction'], agent.speed_data['speed']] visited = OrderedSet() diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..722f54bcb5332174199aab070b02308802619bb1 --- /dev/null +++ b/flatland/envs/rail_env_utils.py @@ -0,0 +1,17 @@ +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from flatland.envs.rail_generators import rail_from_file +from flatland.envs.schedule_generators import schedule_from_file + + +def load_flatland_environment_from_file(file_name, load_from_package=None): + environment = RailEnv(width=1, + height=1, + rail_generator=rail_from_file(file_name, load_from_package), + number_of_agents=1, + schedule_generator=schedule_from_file(file_name,load_from_package), + obs_builder_object=TreeObsForRailEnv( + max_depth=2, + predictor=ShortestPathPredictorForRailEnv(max_depth=10))) + return environment diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index 60c606f789f0d83f04fd5c549e155f437c977b7d..0bce90aac8faf1df4758755a8e2d82d6e20dd133 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -204,7 +204,7 @@ def rail_from_manual_specifications_generator(rail_spec): return generator -def rail_from_file(filename) -> RailGenerator: +def rail_from_file(filename, load_from_package=None) -> RailGenerator: """ Utility to load pickle file @@ -221,8 +221,12 @@ def rail_from_file(filename) -> RailGenerator: def generator(width, height, num_agents, num_resets): rail_env_transitions = RailEnvTransitions() - with open(filename, "rb") as file_in: - load_data = file_in.read() + if load_from_package is not None: + from importlib_resources import read_binary + load_data = read_binary(load_from_package, filename) + else: + with open(filename, "rb") as file_in: + load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False) grid = np.array(data[b"grid"]) diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index b3576a2bec77f75afc9331cc6c190649590a990c..7f42feeacd0ef50b56846540a9b2af9d147eafb0 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -205,7 +205,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] = return generator -def schedule_from_file(filename) -> ScheduleGenerator: +def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: """ Utility to load pickle file @@ -220,8 +220,12 @@ def schedule_from_file(filename) -> ScheduleGenerator: """ def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct: - with open(filename, "rb") as file_in: - load_data = file_in.read() + if load_from_package is not None: + from importlib_resources import read_binary + load_data = read_binary(load_from_package, filename) + else: + with open(filename, "rb") as file_in: + load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') # agents are always reset as not moving diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py index 4da1da4d23cc98ec7032530ee51820f29b637c17..930cc24fb4a9be817c14f2cca149747ac6ca370b 100644 --- a/tests/test_flatland_core_transition_map.py +++ b/tests/test_flatland_core_transition_map.py @@ -168,7 +168,6 @@ def test_get_entry_directions(): south_symmetrical_switch = cells[6] north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) - # Simple turn not in the base transitions ? south_east_turn = int('0100000000000010', 2) south_west_turn = transitions.rotate_transition(south_east_turn, 90) north_east_turn = transitions.rotate_transition(south_east_turn, 270) diff --git a/tests/test_flatland_envs_env_utils.py b/tests/test_flatland_envs_env_utils.py index cf5c8592708eef237bcf29308032df49753860bd..d0320ab7f0a42423172f9b9cd38d6b5bb66c73b9 100644 --- a/tests/test_flatland_envs_env_utils.py +++ b/tests/test_flatland_envs_env_utils.py @@ -4,6 +4,7 @@ import pytest from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.core.grid.grid4_utils import get_direction from flatland.core.grid.grid_utils import position_to_coordinate, coordinate_to_position +from flatland.envs.rail_env_utils import load_flatland_environment_from_file depth_to_test = 5 positions_to_test = [0, 5, 1, 6, 20, 30] @@ -32,3 +33,7 @@ def test_get_direction(): assert get_direction((1, 0), (0, 0)) == Grid4TransitionsEnum.NORTH with pytest.raises(Exception, match="Could not determine direction"): get_direction((0, 0), (0, 0)) + + +def test_load(): + load_flatland_environment_from_file('test_001.pkl', 'env_data.tests') diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index d5dc3ac7af4be6ebd8c5cbeaf705bb710d36d138..0114730a2ac1d0df618eea773dfcf1cd7175dee2 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -61,7 +61,6 @@ def test_rail_environment_single_agent(): vertical_line = cells[1] south_symmetrical_switch = cells[6] north_symmetrical_switch = transitions.rotate_transition(south_symmetrical_switch, 180) - # Simple turn not in the base transitions ? south_east_turn = int('0100000000000010', 2) south_west_turn = transitions.rotate_transition(south_east_turn, 90) north_east_turn = transitions.rotate_transition(south_east_turn, 270)