diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index b1f46ec638d8a865a2ce36a9715d89b452c782f2..8fc70ce37390ed2bde29808f97de5522d877cdf2 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -172,18 +172,21 @@ class TreeObsForRailEnv(ObservationBuilder): if handles is None: handles = [] if self.predictor: + self.max_prediction_depth = 0 self.predicted_pos = {} self.predicted_dir = {} self.predictions = self.predictor.get(custom_args={'distance_map': self.distance_map}) - for t in range(len(self.predictions[0])): - pos_list = [] - dir_list = [] - for a in handles: - pos_list.append(self.predictions[a][t][1:3]) - dir_list.append(self.predictions[a][t][3]) - self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) - self.predicted_dir.update({t: dir_list}) - self.max_prediction_depth = len(self.predicted_pos) + if self.predictions: + + for t in range(len(self.predictions[0])): + pos_list = [] + dir_list = [] + for a in handles: + pos_list.append(self.predictions[a][t][1:3]) + dir_list.append(self.predictions[a][t][3]) + self.predicted_pos.update({t: coordinate_to_position(self.env.width, pos_list)}) + self.predicted_dir.update({t: dir_list}) + self.max_prediction_depth = len(self.predicted_pos) observations = {} for h in handles: observations[h] = self.get(h) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 149e32bd66bb0c0e1ac96bbfe213705b4de63253..03879eb4cbf1072dace4dafacd36c826e6334ab2 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -80,6 +80,7 @@ class RailEnv(Environment): rail_generator=random_rail_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), + file_name=None ): """ Environment init. @@ -110,6 +111,7 @@ class RailEnv(Environment): obs_builder_object: ObservationBuilder object ObservationBuilder-derived object that takes builds observation vectors for each agent. + file_name: you can load a pickle file. """ self.rail_generator = rail_generator @@ -117,14 +119,10 @@ class RailEnv(Environment): self.width = width self.height = height - self.obs_builder = obs_builder_object - self.obs_builder._set_env(self) - - self.action_space = [1] - self.observation_space = self.obs_builder.observation_space # updated on resets? - self.rewards = [0] * number_of_agents self.done = False + self.obs_builder = obs_builder_object + self.obs_builder._set_env(self) self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False) @@ -135,6 +133,14 @@ class RailEnv(Environment): self.agents = [None] * number_of_agents # live agents self.agents_static = [None] * number_of_agents # static agent information self.num_resets = 0 + if file_name: + self.loaded_file = file_name + else: + self.loaded_file = None + + self.action_space = [1] + self.observation_space = self.obs_builder.observation_space # updated on resets? + self.reset() self.num_resets = 0 # yes, set it to zero again! @@ -175,6 +181,9 @@ class RailEnv(Environment): if replace_agents: self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5]) + if self.loaded_file: + self.load(self.loaded_file) + self.restart_agents() for i_agent in range(self.get_num_agents()): @@ -425,6 +434,9 @@ class RailEnv(Environment): load_data = file_in.read() self.set_full_state_msg(load_data) + def load_pkl(self, pkl_data): + self.set_full_state_msg(pkl_data) + def load_resource(self, package, resource): from importlib_resources import read_binary load_data = read_binary(package, resource) diff --git a/tests/test_file_load.py b/tests/test_file_load.py new file mode 100644 index 0000000000000000000000000000000000000000..af5644f3ee81d72641449e7184c078830f65bdc8 --- /dev/null +++ b/tests/test_file_load.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import numpy as np + +from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from tests.simple_rail import make_simple_rail + + +def test_load_pkl(): + file_name = "test_pkl.pkl" + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=3, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + env.save(file_name) + # initialize agents_static + rails_initial = env.rail.grid + agents_initial = env.agents + + env = RailEnv(width=1, + height=1, + rail_generator=empty_rail_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + file_name=file_name + ) + rails_loaded = env.rail.grid + agents_loaded = env.agents + + assert np.all(np.array_equal(rails_initial, rails_loaded)) + assert agents_initial == agents_loaded + + return