diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index ee75df2834898042f12df6065c220ea51358c0bb..ff21046bf22c59235a9d8c76656c9f368b1e799a 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1,9 +1,11 @@ +import msgpack import numpy as np from flatland.core.grid.grid4_utils import get_direction, mirror from flatland.core.grid.grid_utils import distance_on_rail from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap +from flatland.envs.agent_utils import EnvAgentStatic from flatland.envs.grid4_generators_utils import connect_rail from flatland.envs.grid4_generators_utils import get_rnd_agents_pos_tgt_dir_on_rail @@ -195,7 +197,7 @@ def rail_from_manual_specifications_generator(rail_spec): return generator -def rail_from_data(input_data): +def rail_from_data(filename): """ Utility to load pickle file @@ -210,19 +212,20 @@ def rail_from_data(input_data): the matrix of correct 16-bit bitmaps for each rail_spec_of_cell. """ - def generator(): - data = msgpack.unpackb(msg_data, use_list=False) - self.rail.grid = np.array(input_data[b"grid"]) - rail = GridTransitionMap(width=width, height=height, transitions=rail_env_transitions) + def generator(width, height, num_agents, num_resets): + rail_env_transitions = RailEnvTransitions() + with open(filename, "rb") as file_in: + load_data = file_in.read() + data = msgpack.unpackb(load_data, use_list=False) + grid = np.array(data[b"grid"]) + rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) + rail.grid = grid # agents are always reset as not moving - self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] - self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] + agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] # setup with loaded data - self.height, self.width = self.rail.grid.shape - self.rail.height = self.height - self.rail.width = self.width - self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) - + agents_position = [a.position for a in agents_static] + agents_direction = [a.direction for a in agents_static] + agents_target = [a.target for a in agents_static] return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 8fc70ce37390ed2bde29808f97de5522d877cdf2..b7d1ff26126f6546db42c6e8d5556f038008e70b 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -60,6 +60,7 @@ class TreeObsForRailEnv(ObservationBuilder): def _compute_distance_map(self): agents = self.env.agents nb_agents = len(agents) + print(nb_agents) self.distance_map = np.inf * np.ones(shape=(nb_agents, self.env.height, self.env.width, @@ -75,7 +76,6 @@ class TreeObsForRailEnv(ObservationBuilder): orientation within it) to each agent's target cell. """ # Returns max distance to target, from the farthest away node, while filling in distance_map - self.distance_map[target_nr, position[0], position[1], :] = 0 # Fill in the (up to) 4 neighboring nodes diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 03879eb4cbf1072dace4dafacd36c826e6334ab2..34762c0bb1d4de1cb82ee82e139113977b7e7a7d 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -80,7 +80,6 @@ class RailEnv(Environment): rail_generator=random_rail_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), - file_name=None ): """ Environment init. @@ -133,10 +132,6 @@ class RailEnv(Environment): self.agents = [None] * number_of_agents # live agents self.agents_static = [None] * number_of_agents # static agent information self.num_resets = 0 - if file_name: - self.loaded_file = file_name - else: - self.loaded_file = None self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? @@ -177,13 +172,11 @@ class RailEnv(Environment): if regen_rail or self.rail is None: self.rail = tRailAgents[0] + self.height, self.width = self.rail.grid.shape if replace_agents: self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5]) - if self.loaded_file: - self.load(self.loaded_file) - self.restart_agents() for i_agent in range(self.get_num_agents()): diff --git a/tests/test_file_load.py b/tests/test_file_load.py index af5644f3ee81d72641449e7184c078830f65bdc8..2b929b174cca8666579e7a5c1eccce05a1a19176 100644 --- a/tests/test_file_load.py +++ b/tests/test_file_load.py @@ -3,7 +3,7 @@ import numpy as np -from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator +from flatland.envs.generators import rail_from_GridTransitionMap_generator, rail_from_data from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -26,10 +26,9 @@ def test_load_pkl(): env = RailEnv(width=1, height=1, - rail_generator=empty_rail_generator(), + rail_generator=rail_from_data(file_name), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), - file_name=file_name ) rails_loaded = env.rail.grid agents_loaded = env.agents