diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index c3f569ae0d927a6e71803a24d921337c65d39c29..907b4a25edc0d44ff83d7a063795b55e020362b3 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1,10 +1,12 @@ +import msgpack import numpy as np -from flatland.core.transition_map import GridTransitionMap +from flatland.core.grid.grid4_utils import get_direction, mirror +from flatland.core.grid.grid_utils import distance_on_rail from flatland.core.grid.rail_env_grid import RailEnvTransitions +from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_utils import EnvAgentStatic from flatland.envs.grid4_generators_utils import connect_rail -from flatland.core.grid.grid_utils import distance_on_rail -from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail @@ -195,6 +197,40 @@ def rail_from_manual_specifications_generator(rail_spec): return generator +def rail_from_file(filename): + """ + Utility to load pickle file + + Parameters + ------- + input_file : Pickle file generated by env.save() or editor + + Returns + ------- + function + Generator function that always returns a GridTransitionMap object with + the matrix of correct 16-bit bitmaps for each rail_spec_of_cell. + """ + + def generator(width, height, num_agents, num_resets): + rail_env_transitions = RailEnvTransitions() + with open(filename, "rb") as file_in: + load_data = file_in.read() + data = msgpack.unpackb(load_data, use_list=False) + grid = np.array(data[b"grid"]) + rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) + rail.grid = grid + # agents are always reset as not moving + agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] + # setup with loaded data + agents_position = [a.position for a in agents_static] + agents_direction = [a.direction for a in agents_static] + agents_target = [a.target for a in agents_static] + return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + + return generator + + def rail_from_GridTransitionMap_generator(rail_map): """ Utility to convert a rail given by a GridTransitionMap map with the correct diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 8fc70ce37390ed2bde29808f97de5522d877cdf2..3929b9e191615ba91fb0e13df6d8ae040b401a5a 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -75,7 +75,6 @@ class TreeObsForRailEnv(ObservationBuilder): orientation within it) to each agent's target cell. """ # Returns max distance to target, from the farthest away node, while filling in distance_map - self.distance_map[target_nr, position[0], position[1], :] = 0 # Fill in the (up to) 4 neighboring nodes diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 03879eb4cbf1072dace4dafacd36c826e6334ab2..34762c0bb1d4de1cb82ee82e139113977b7e7a7d 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -80,7 +80,6 @@ class RailEnv(Environment): rail_generator=random_rail_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), - file_name=None ): """ Environment init. @@ -133,10 +132,6 @@ class RailEnv(Environment): self.agents = [None] * number_of_agents # live agents self.agents_static = [None] * number_of_agents # static agent information self.num_resets = 0 - if file_name: - self.loaded_file = file_name - else: - self.loaded_file = None self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? @@ -177,13 +172,11 @@ class RailEnv(Environment): if regen_rail or self.rail is None: self.rail = tRailAgents[0] + self.height, self.width = self.rail.grid.shape if replace_agents: self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5]) - if self.loaded_file: - self.load(self.loaded_file) - self.restart_agents() for i_agent in range(self.get_num_agents()): diff --git a/tests/test_file_load.py b/tests/test_file_load.py index af5644f3ee81d72641449e7184c078830f65bdc8..57fa45cb29dab07b84f43c97a45043c9dfa39979 100644 --- a/tests/test_file_load.py +++ b/tests/test_file_load.py @@ -3,7 +3,7 @@ import numpy as np -from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator +from flatland.envs.generators import rail_from_GridTransitionMap_generator, rail_from_file from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -26,10 +26,9 @@ def test_load_pkl(): env = RailEnv(width=1, height=1, - rail_generator=empty_rail_generator(), + rail_generator=rail_from_file(file_name), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - file_name=file_name ) rails_loaded = env.rail.grid agents_loaded = env.agents