diff --git a/examples/introduction_flatland_2_1.py b/examples/introduction_flatland_2_1.py index fa48acb23c9bf7e5f7c57df54abc733bc75ed2e7..3ec2d0d4cf2c63bd916b74d06c21a73d589c8c60 100644 --- a/examples/introduction_flatland_2_1.py +++ b/examples/introduction_flatland_2_1.py @@ -1,4 +1,5 @@ import numpy as np +import os # In Flatland you can use custom observation builders and predicitors # Observation builders generate the observation needed by the controller @@ -84,7 +85,7 @@ env = RailEnv(width=width, env.reset() # Initiate the renderer -env_renderer = RenderTool(env, gl="PILSVG", +env_renderer = RenderTool(env, agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, show_debug=False, screen_height=600, # Adjust these parameters to fit your resolution @@ -243,6 +244,8 @@ score = 0 # Run episode frame_step = 0 +os.makedirs("tmp/frames", exist_ok=True) + for step in range(500): # Chose an action for each agent in the environment for a in range(env.get_num_agents()): @@ -255,7 +258,7 @@ for step in range(500): next_obs, all_rewards, done, _ = env.step(action_dict) env_renderer.render_env(show=True, show_observations=False, show_predictions=False) - env_renderer.gl.save_image('./misc/Fames2/flatland_frame_{:04d}.png'.format(step)) + env_renderer.gl.save_image('tmp/frames/flatland_frame_{:04d}.png'.format(step)) frame_step += 1 # Update replay buffer and train agent for a in range(env.get_num_agents()): diff --git a/flatland/envs/persistence.py b/flatland/envs/persistence.py index e77e97aa10405b9b9eb0119a24d463a3402103f5..31339183f851155d446238bb50a051cba2e9ec36 100644 --- a/flatland/envs/persistence.py +++ b/flatland/envs/persistence.py @@ -115,9 +115,12 @@ class RailEnvPersister(object): # TODO: inefficient - each one of these generators loads the complete env file. env = rail_env.RailEnv(width=1, height=1, - rail_generator=rail_gen.rail_from_file(filename), - schedule_generator=sched_gen.schedule_from_file(filename), - malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename), + rail_generator=rail_gen.rail_from_file(filename, + load_from_package=load_from_package), + schedule_generator=sched_gen.schedule_from_file(filename, + load_from_package=load_from_package), + malfunction_generator_and_process_data=mal_gen.malfunction_from_file(filename, + load_from_package=load_from_package), obs_builder_object=DummyObservationBuilder(), record_steps=True) diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index 0e980219855f5859940507be7f160b7fdfc05990..196f4a382ac0b92f975384f9aeaa3d4127d6e93a 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -8,7 +8,7 @@ from numpy import array from recordtype import recordtype from flatland.utils.graphics_pil import PILGL, PILSVG -from flatland.utils.flask_util import simple_flask_server + # TODO: suggested renaming to RailEnvRenderTool, as it will only work with RailEnv! @@ -23,7 +23,7 @@ class AgentRenderVariant(IntEnum): class RenderTool(object): """ RenderTool is a facade to a renderer, either local or browser """ - def __init__(self, env, gl="BROWSER", jupyter=False, + def __init__(self, env, gl="PGL", jupyter=False, agent_render_variant=AgentRenderVariant.ONE_STEP_BEHIND, show_debug=False, clear_debug_text=True, screen_width=800, screen_height=600, host="localhost", port=None): @@ -45,6 +45,7 @@ class RenderTool(object): self.gl = self.renderer.gl elif gl == "BROWSER": + from flatland.utils.flask_util import simple_flask_server self.renderer = RenderBrowser(env, host=host, port=port) else: print("[", gl, "] not found, switch to PILSVG or BROWSER") diff --git a/tests/test_flatland_envs_rail_env.py b/tests/test_flatland_envs_rail_env.py index 07cf778925864e2f1f871bf4621d9c4a65bca220..9fce4c0b154d0433b2b88559cf433ec721d61461 100644 --- a/tests/test_flatland_envs_rail_env.py +++ b/tests/test_flatland_envs_rail_env.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- import numpy as np +import os from flatland.core.grid.rail_env_grid import RailEnvTransitions from flatland.core.transition_map import GridTransitionMap @@ -21,8 +22,8 @@ def test_load_env(): #env = RailEnv(10, 10) #env.reset() # env.load_resource('env_data.tests', 'test-10x10.mpk') - #env, env_dict = RailEnvPersister.load_resource("env_data.tests", "test-10x10.mpk") - env, env_dict = RailEnvPersister.load_new("./env_data/tests/test-10x10.mpk") + env, env_dict = RailEnvPersister.load_resource("env_data.tests", "test-10x10.mpk") + #env, env_dict = RailEnvPersister.load_new("./env_data/tests/test-10x10.mpk") agent_static = EnvAgent((0, 0), 2, (5, 5), False) env.add_agent(agent_static) @@ -41,12 +42,13 @@ def test_save_load(): agent_2_dir = env.agents[1].direction agent_2_tar = env.agents[1].target - env.save("test_save_2.pkl") - RailEnvPersister.save(env, "test_save.pkl") - + os.makedirs("tmp", exist_ok=True) + + RailEnvPersister.save(env, "tmp/test_save.pkl") + env.save("tmp/test_save_2.pkl") #env.load("test_save.dat") - env, env_dict = RailEnvPersister.load_new("test_save.pkl") + env, env_dict = RailEnvPersister.load_new("tmp/test_save.pkl") assert (env.width == 10) assert (env.height == 10) assert (len(env.agents) == 2) @@ -58,6 +60,29 @@ def test_save_load(): assert (agent_2_tar == env.agents[1].target) +def test_save_load_mpk(): + env = RailEnv(width=10, height=10, + rail_generator=complex_rail_generator(nr_start_goal=2, nr_extra=5, min_dist=6, seed=1), + schedule_generator=complex_schedule_generator(), number_of_agents=2) + env.reset() + + os.makedirs("tmp", exist_ok=True) + + RailEnvPersister.save(env, "tmp/test_save.mpk") + + #env.load("test_save.dat") + env2, env_dict = RailEnvPersister.load_new("tmp/test_save.mpk") + assert (env.width == env2.width) + assert (env.height == env2.height) + assert (len(env2.agents) == len(env.agents)) + + for agent1, agent2 in zip(env.agents, env2.agents): + assert(agent1.position == agent2.position) + assert(agent1.direction == agent2.direction) + assert(agent1.target == agent2.target) + + + def test_rail_environment_single_agent(): # We instantiate the following map on a 3x3 grid # _ _