From b7020a0e89e741300fdab147d40f0809254e36d9 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Thu, 11 Jul 2019 09:41:30 -0400
Subject: [PATCH] added pkl_load functionality to the rail_env. It now sets all
 the info from pickle file before running all the prediction and observations.
 Also added a test for file load UPDATE: Edited the test

---
 tests/test_file_load.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/tests/test_file_load.py b/tests/test_file_load.py
index 6bc3260..af5644f 100644
--- a/tests/test_file_load.py
+++ b/tests/test_file_load.py
@@ -1,6 +1,8 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 
+import numpy as np
+
 from flatland.envs.generators import rail_from_GridTransitionMap_generator, empty_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
@@ -9,6 +11,7 @@ from tests.simple_rail import make_simple_rail
 
 
 def test_load_pkl():
+    file_name = "test_pkl.pkl"
     rail, rail_map = make_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
@@ -16,10 +19,10 @@ def test_load_pkl():
                   number_of_agents=3,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
-    env.save("test_pkl.pkl")
+    env.save(file_name)
     # initialize agents_static
-    obs_0 = env.reset(False, False)
-    file_name = "test_pkl.pkl"
+    rails_initial = env.rail.grid
+    agents_initial = env.agents
 
     env = RailEnv(width=1,
                   height=1,
@@ -28,6 +31,10 @@ def test_load_pkl():
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   file_name=file_name
                   )
-    obs_1 = env.reset(False, False)
-    assert obs_0 == obs_1
+    rails_loaded = env.rail.grid
+    agents_loaded = env.agents
+
+    assert np.all(np.array_equal(rails_initial, rails_loaded))
+    assert agents_initial == agents_loaded
+
     return
-- 
GitLab