From de7c1afd33df3d8fba5894e700c38da80a00a83b Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Wed, 10 Jul 2019 18:09:40 -0400 Subject: [PATCH] added pkl_load functionality to the rail_env. It now sets all the info from pickle file before running all the prediction and observations. Also added a test for file load --- flatland/envs/rail_env.py | 19 +++++++++---------- tests/test_file_load.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 10 deletions(-) create mode 100644 tests/test_file_load.py diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 029851a8..03879eb4 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -80,7 +80,7 @@ class RailEnv(Environment): rail_generator=random_rail_generator(), number_of_agents=1, obs_builder_object=TreeObsForRailEnv(max_depth=2), - pkl_load=None + file_name=None ): """ Environment init. @@ -111,7 +111,7 @@ class RailEnv(Environment): obs_builder_object: ObservationBuilder object ObservationBuilder-derived object that takes builds observation vectors for each agent. - pkl_load: you can load a pickle file. + file_name: you can load a pickle file. """ self.rail_generator = rail_generator @@ -121,6 +121,8 @@ class RailEnv(Environment): self.rewards = [0] * number_of_agents self.done = False + self.obs_builder = obs_builder_object + self.obs_builder._set_env(self) self.dones = dict.fromkeys(list(range(number_of_agents)) + ["__all__"], False) @@ -131,13 +133,10 @@ class RailEnv(Environment): self.agents = [None] * number_of_agents # live agents self.agents_static = [None] * number_of_agents # static agent information self.num_resets = 0 - if pkl_load: - self.loaded_data = pkl_load + if file_name: + self.loaded_file = file_name else: - self.loaded_data = None - - self.obs_builder = obs_builder_object - self.obs_builder._set_env(self) + self.loaded_file = None self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? @@ -182,8 +181,8 @@ class RailEnv(Environment): if replace_agents: self.agents_static = EnvAgentStatic.from_lists(*tRailAgents[1:5]) - if self.loaded_data: - self.load_pkl(self.loaded_data) + if self.loaded_file: + self.load(self.loaded_file) self.restart_agents() diff --git a/tests/test_file_load.py b/tests/test_file_load.py new file mode 100644 index 00000000..6bc32604 --- /dev/null +++ b/tests/test_file_load.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator +from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.predictions import ShortestPathPredictorForRailEnv +from flatland.envs.rail_env import RailEnv +from tests.simple_rail import make_simple_rail + + +def test_load_pkl(): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + number_of_agents=3, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + ) + env.save("test_pkl.pkl") + # initialize agents_static + obs_0 = env.reset(False, False) + file_name = "test_pkl.pkl" + + env = RailEnv(width=1, + height=1, + rail_generator=empty_rail_generator(), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), + file_name=file_name + ) + obs_1 = env.reset(False, False) + assert obs_0 == obs_1 + return -- GitLab