Skip to content
Snippets Groups Projects
Commit b7020a0e authored by Erik Nygren's avatar Erik Nygren
Browse files

added pkl_load functionality to the rail_env.

It now sets all the info from pickle file before running all the prediction and observations.
Also added a test for file load
UPDATE: Edited the test
parent de7c1afd
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
......@@ -9,6 +11,7 @@ from tests.simple_rail import make_simple_rail
def test_load_pkl():
file_name = "test_pkl.pkl"
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
......@@ -16,10 +19,10 @@ def test_load_pkl():
number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
)
env.save("test_pkl.pkl")
env.save(file_name)
# initialize agents_static
obs_0 = env.reset(False, False)
file_name = "test_pkl.pkl"
rails_initial = env.rail.grid
agents_initial = env.agents
env = RailEnv(width=1,
height=1,
......@@ -28,6 +31,10 @@ def test_load_pkl():
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
file_name=file_name
)
obs_1 = env.reset(False, False)
assert obs_0 == obs_1
rails_loaded = env.rail.grid
agents_loaded = env.agents
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
return
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment