From 31dbc9167cc0c95ac3fb1f3946ce7de8a8f99406 Mon Sep 17 00:00:00 2001
From: hagrid67 <jdhwatson@gmail.com>
Date: Wed, 3 Jun 2020 11:33:16 +0100
Subject: [PATCH] switched back to load_resource in rail_env test. passed
 load_from_package arg down to from_file generators.  switch default renderer
 to PGL from BROWSER. Added conditional import for flask_util in rendertools
 to hopefully avoid some startup delays and hanging tests.  Added os.makedirs
 to intro_flatland 2.1 for saved frames folder.  Removed explicit gl=PILSVG

---
 examples/introduction_flatland_2_1.py |  7 +++--
 flatland/envs/persistence.py          |  9 ++++---
 flatland/utils/rendertools.py         |  5 ++--
 tests/test_flatland_envs_rail_env.py  | 37 ++++++++++++++++++++++-----
 4 files changed, 45 insertions(+), 13 deletions(-)

diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py
index fa48acb2..3ec2d0d4 100644
--- a/examples/introduction_flatland_2_1.py
+++ b/examples/introduction_flatland_2_1.py
@@ -1,4 +1,5 @@
 import numpy as np
+import os
 
 # In Flatland you can use custom observation builders and predicitors
 # Observation builders generate the observation needed by the controller
@@ -84,7 +85,7 @@ env = RailEnv(width=width,
 env.reset()
 
 # Initiate the renderer
-env_renderer = RenderTool(env, gl="PILSVG",
+env_renderer = RenderTool(env,
                           agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
                           show_debug=False,
                           screen_height=600,  # Adjust these parameters to fit your resolution
@@ -243,6 +244,8 @@ score = 0
 # Run episode
 frame_step = 0
 
+os.makedirs("tmp/frames", exist_ok=True)
+
 for step in range(500):
     # Chose an action for each agent in the environment
     for a in range(env.get_num_agents()):
@@ -255,7 +258,7 @@ for step in range(500):
     next_obs, all_rewards, done, _ = env.step(action_dict)
 
     env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
-    env_renderer.gl.save_image('./misc/Fames2/flatland_frame_{:04d}.png'.format(step))
+    env_renderer.gl.save_image('tmp/frames/flatland_frame_{:04d}.png'.format(step))
     frame_step += 1
     # Update replay buffer and train agent
     for a in range(env.get_num_agents()):
diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py
index e77e97aa..31339183 100644
--- a/flatland/envs/persistence.py
+++ b/flatland/envs/persistence.py
@@ -115,9 +115,12 @@ class RailEnvPersister(object):
 
         # TODO: inefficient - each one of these generators loads the complete env file.
         env = rail_env.RailEnv(width=1, height=1,
-                rail_generator=rail_gen.rail_from_file(filename),
-                schedule_generator=sched_gen.schedule_from_file(filename),
-                malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename),
+                rail_generator=rail_gen.rail_from_file(filename, 
+                    load_from_package=load_from_package),
+                schedule_generator=sched_gen.schedule_from_file(filename,
+                    load_from_package=load_from_package),
+                malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename,
+                    load_from_package=load_from_package),
                 obs_builder_object=DummyObservationBuilder(),
                 record_steps=True)
 
diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py
index 0e980219..196f4a38 100644
--- a/flatland/utils/rendertools.py
+++ b/flatland/utils/rendertools.py
@@ -8,7 +8,7 @@ from numpy import array
 from recordtype import recordtype
 
 from flatland.utils.graphics_pil import PILGL, PILSVG
-from flatland.utils.flask_util import simple_flask_server
+
 
 # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv!
 
@@ -23,7 +23,7 @@ class AgentRenderVariant(IntEnum):
 class RenderTool(object):
     """ RenderTool is a facade to a renderer, either local or browser
     """
-    def __init__(self, env, gl="BROWSER", jupyter=False,
+    def __init__(self, env, gl="PGL", jupyter=False,
                  agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND,
                  show_debug=False, clear_debug_text=True, screen_width=800, screen_height=600,
                  host="localhost", port=None):
@@ -45,6 +45,7 @@ class RenderTool(object):
             self.gl = self.renderer.gl
 
         elif gl == "BROWSER":
+            from flatland.utils.flask_util import simple_flask_server
             self.renderer = RenderBrowser(env, host=host, port=port)
         else:
             print("[", gl, "] not found, switch to PILSVG or BROWSER")
diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py
index 07cf7789..9fce4c0b 100644
--- a/tests/test_flatland_envs_rail_env.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 import numpy as np
+import os
 
 from flatland.core.grid.rail_env_grid import RailEnvTransitions
 from flatland.core.transition_map import GridTransitionMap
@@ -21,8 +22,8 @@ def test_load_env():
     #env = RailEnv(10, 10)
     #env.reset()
     # env.load_resource('env_data.tests', 'test-10x10.mpk')
-    #env, env_dict = RailEnvPersister.load_resource("env_data.tests", "test-10x10.mpk")
-    env, env_dict = RailEnvPersister.load_new("./env_data/tests/test-10x10.mpk")
+    env, env_dict = RailEnvPersister.load_resource("env_data.tests", "test-10x10.mpk")
+    #env, env_dict = RailEnvPersister.load_new("./env_data/tests/test-10x10.mpk")
 
     agent_static = EnvAgent((0, 0), 2, (5, 5), False)
     env.add_agent(agent_static)
@@ -41,12 +42,13 @@ def test_save_load():
     agent_2_dir = env.agents[1].direction
     agent_2_tar = env.agents[1].target
     
-    env.save("test_save_2.pkl")
-    RailEnvPersister.save(env, "test_save.pkl")
-    
+    os.makedirs("tmp", exist_ok=True)
+
+    RailEnvPersister.save(env, "tmp/test_save.pkl")
+    env.save("tmp/test_save_2.pkl")
 
     #env.load("test_save.dat")
-    env, env_dict = RailEnvPersister.load_new("test_save.pkl")
+    env, env_dict = RailEnvPersister.load_new("tmp/test_save.pkl")
     assert (env.width == 10)
     assert (env.height == 10)
     assert (len(env.agents) == 2)
@@ -58,6 +60,29 @@ def test_save_load():
     assert (agent_2_tar == env.agents[1].target)
 
 
+def test_save_load_mpk():
+    env = RailEnv(width=10, height=10,
+                  rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1),
+                  schedule_generator=complex_schedule_generator(), number_of_agents=2)
+    env.reset()
+
+    os.makedirs("tmp", exist_ok=True)
+
+    RailEnvPersister.save(env, "tmp/test_save.mpk")
+
+    #env.load("test_save.dat")
+    env2, env_dict = RailEnvPersister.load_new("tmp/test_save.mpk")
+    assert (env.width == env2.width)
+    assert (env.height == env2.height)
+    assert (len(env2.agents) == len(env.agents))
+    
+    for agent1, agent2 in zip(env.agents, env2.agents):
+        assert(agent1.position == agent2.position)
+        assert(agent1.direction == agent2.direction)
+        assert(agent1.target == agent2.target)
+
+
+
 def test_rail_environment_single_agent():
     # We instantiate the following map on a 3x3 grid
     #  _  _
-- 
GitLab