Skip to content
Snippets Groups Projects
Commit b7020a0e authored by Erik Nygren's avatar Erik Nygren
Browse files

added pkl_load functionality to the rail_env.

It now sets all the info from pickle file before running all the prediction and observations.
Also added a test for file load
UPDATE: Edited the test
parent de7c1afd
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import numpy as np
from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv
...@@ -9,6 +11,7 @@ from tests.simple_rail import make_simple_rail ...@@ -9,6 +11,7 @@ from tests.simple_rail import make_simple_rail
def test_load_pkl(): def test_load_pkl():
file_name = "test_pkl.pkl"
rail, rail_map = make_simple_rail() rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1], env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0], height=rail_map.shape[0],
...@@ -16,10 +19,10 @@ def test_load_pkl(): ...@@ -16,10 +19,10 @@ def test_load_pkl():
number_of_agents=3, number_of_agents=3,
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
) )
env.save("test_pkl.pkl") env.save(file_name)
# initialize agents_static # initialize agents_static
obs_0 = env.reset(False, False) rails_initial = env.rail.grid
file_name = "test_pkl.pkl" agents_initial = env.agents
env = RailEnv(width=1, env = RailEnv(width=1,
height=1, height=1,
...@@ -28,6 +31,10 @@ def test_load_pkl(): ...@@ -28,6 +31,10 @@ def test_load_pkl():
obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
file_name=file_name file_name=file_name
) )
obs_1 = env.reset(False, False) rails_loaded = env.rail.grid
assert obs_0 == obs_1 agents_loaded = env.agents
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
return return
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment