From b7020a0e89e741300fdab147d40f0809254e36d9 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Thu, 11 Jul 2019 09:41:30 -0400 Subject: [PATCH] added pkl_load functionality to the rail_env. It now sets all the info from pickle file before running all the prediction and observations. Also added a test for file load UPDATE: Edited the test --- tests/test_file_load.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/test_file_load.py b/tests/test_file_load.py index 6bc3260..af5644f 100644 --- a/tests/test_file_load.py +++ b/tests/test_file_load.py @@ -1,6 +1,8 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import numpy as np + from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv @@ -9,6 +11,7 @@ from tests.simple_rail import make_simple_rail def test_load_pkl(): + file_name = "test_pkl.pkl" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], @@ -16,10 +19,10 @@ def test_load_pkl(): number_of_agents=3, obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) - env.save("test_pkl.pkl") + env.save(file_name) # initialize agents_static - obs_0 = env.reset(False, False) - file_name = "test_pkl.pkl" + rails_initial = env.rail.grid + agents_initial = env.agents env = RailEnv(width=1, height=1, @@ -28,6 +31,10 @@ def test_load_pkl(): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), file_name=file_name ) - obs_1 = env.reset(False, False) - assert obs_0 == obs_1 + rails_loaded = env.rail.grid + agents_loaded = env.agents + + assert np.all(np.array_equal(rails_initial, rails_loaded)) + assert agents_initial == agents_loaded + return -- GitLab