From aa91c6ffd02d5101341d60b1fc451f1c6d6226d3 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 12 Jun 2019 11:35:51 +0200 Subject: [PATCH] make env-data a module named env_data to access it with importlib_resource --- {env-data => env_data}/__init__.py | 0 {env-data => env_data}/railway/__init__.py | 0 .../railway/complex_scene.pkl | Bin .../railway/example_flatland_000.pkl | Bin .../railway/example_flatland_001.pkl | Bin .../railway/example_network_000.pkl | Bin .../railway/example_network_001.pkl | Bin .../railway/example_network_002.pkl | Bin .../railway/example_network_003.pkl | Bin {env-data => env_data}/tests/__init__.py | 0 {env-data => env_data}/tests/test-10x10.mpk | Bin {env-data => env_data}/tests/test1.npy | Bin examples/demo.py | 10 +---- examples/simple_example_2.py | 2 +- flatland/core/transition_map.py | 12 +++-- flatland/envs/generators.py | 41 ------------------ images/__init__.py | 0 tests/test_env_edit.py | 2 +- tests/test_rendertools.py | 12 ++--- 19 files changed, 17 insertions(+), 62 deletions(-) rename {env-data => env_data}/__init__.py (100%) rename {env-data => env_data}/railway/__init__.py (100%) rename {env-data => env_data}/railway/complex_scene.pkl (100%) rename {env-data => env_data}/railway/example_flatland_000.pkl (100%) rename {env-data => env_data}/railway/example_flatland_001.pkl (100%) rename {env-data => env_data}/railway/example_network_000.pkl (100%) rename {env-data => env_data}/railway/example_network_001.pkl (100%) rename {env-data => env_data}/railway/example_network_002.pkl (100%) rename {env-data => env_data}/railway/example_network_003.pkl (100%) rename {env-data => env_data}/tests/__init__.py (100%) rename {env-data => env_data}/tests/test-10x10.mpk (100%) rename {env-data => env_data}/tests/test1.npy (100%) create mode 100644 images/__init__.py diff --git a/env-data/__init__.py b/env_data/__init__.py similarity index 100% rename from env-data/__init__.py rename to env_data/__init__.py diff --git a/env-data/railway/__init__.py b/env_data/railway/__init__.py similarity index 100% rename from env-data/railway/__init__.py rename to env_data/railway/__init__.py diff --git a/env-data/railway/complex_scene.pkl b/env_data/railway/complex_scene.pkl similarity index 100% rename from env-data/railway/complex_scene.pkl rename to env_data/railway/complex_scene.pkl diff --git a/env-data/railway/example_flatland_000.pkl b/env_data/railway/example_flatland_000.pkl similarity index 100% rename from env-data/railway/example_flatland_000.pkl rename to env_data/railway/example_flatland_000.pkl diff --git a/env-data/railway/example_flatland_001.pkl b/env_data/railway/example_flatland_001.pkl similarity index 100% rename from env-data/railway/example_flatland_001.pkl rename to env_data/railway/example_flatland_001.pkl diff --git a/env-data/railway/example_network_000.pkl b/env_data/railway/example_network_000.pkl similarity index 100% rename from env-data/railway/example_network_000.pkl rename to env_data/railway/example_network_000.pkl diff --git a/env-data/railway/example_network_001.pkl b/env_data/railway/example_network_001.pkl similarity index 100% rename from env-data/railway/example_network_001.pkl rename to env_data/railway/example_network_001.pkl diff --git a/env-data/railway/example_network_002.pkl b/env_data/railway/example_network_002.pkl similarity index 100% rename from env-data/railway/example_network_002.pkl rename to env_data/railway/example_network_002.pkl diff --git a/env-data/railway/example_network_003.pkl b/env_data/railway/example_network_003.pkl similarity index 100% rename from env-data/railway/example_network_003.pkl rename to env_data/railway/example_network_003.pkl diff --git a/env-data/tests/__init__.py b/env_data/tests/__init__.py similarity index 100% rename from env-data/tests/__init__.py rename to env_data/tests/__init__.py diff --git a/env-data/tests/test-10x10.mpk b/env_data/tests/test-10x10.mpk similarity index 100% rename from env-data/tests/test-10x10.mpk rename to env_data/tests/test-10x10.mpk diff --git a/env-data/tests/test1.npy b/env_data/tests/test1.npy similarity index 100% rename from env-data/tests/test1.npy rename to env_data/tests/test1.npy diff --git a/examples/demo.py b/examples/demo.py index a1ddbf6..ae2c872 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -53,17 +53,9 @@ class Scenario_Generator: return env @staticmethod - def load_scenario(resource, package='env-data.railway', number_of_agents=3): + def load_scenario(resource, package='env_data.railway', number_of_agents=3): env = RailEnv(width=2 * (1 + number_of_agents), height=1 + number_of_agents) - - """ - env = RailEnv(width=20, - height=20, - rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( - [filename, - number_of_agents=number_of_agents) - """ env.load_resource(package, resource) env.reset(False, False) diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py index 05290f1..1d2c1e6 100644 --- a/examples/simple_example_2.py +++ b/examples/simple_example_2.py @@ -2,7 +2,7 @@ import random import numpy as np -from flatland.envs.generators import random_rail_generator # , rail_from_list_of_saved_GridTransitionMap_generator +from flatland.envs.generators import random_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 271ced5..43b9a72 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -3,6 +3,7 @@ TransitionMap and derived classes. """ import numpy as np +from importlib_resources import path from numpy import array from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions @@ -263,7 +264,7 @@ class GridTransitionMap(TransitionMap): """ np.save(filename, self.grid) - def load_transition_map(self, filename, override_gridsize=True): + def load_transition_map(self, package, resource, override_gridsize=True): """ Load the transitions grid from `filename' (npy format). The load function only updates the transitions grid, and possibly width and height, but the object has to be @@ -271,8 +272,10 @@ class GridTransitionMap(TransitionMap): Parameters ---------- - filename : string - Name of the file from which to load the transitions grid. + package : string + Name of the package from which to load the transitions grid. + resource : string + Name of the file from which to load the transitions grid within the package. override_gridsize : bool If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if @@ -280,7 +283,8 @@ class GridTransitionMap(TransitionMap): (height,width) ) """ - new_grid = np.load(filename) + with path(package, resource) as file_in: + new_grid = np.load(file_in) new_height = new_grid.shape[0] new_width = new_grid.shape[1] diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index c4cf908..791a07a 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -214,47 +214,6 @@ def rail_from_GridTransitionMap_generator(rail_map): return generator -def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): - """ - Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset. - - Parameters - ------- - list_of_filenames : list - List of filenames with the saved grids to load. - - Returns - ------- - function - Generator function that always returns the given `rail_map' object. - """ - - def generator(width, height, num_agents, num_resets=0): - t_utils = RailEnvTransitions() - rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils) - rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False) - - if rail_map.grid.dtype == np.uint64: - rail_map.transitions = Grid8Transitions() - - agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( - rail_map, - num_agents) - - return rail_map, agents_position, agents_direction, agents_target - - return generator - - -""" -def generate_rail_from_list_of_manual_specifications(list_of_specifications) - def generator(width, height, num_resets=0): - return generate_rail_from_manual_specifications(list_of_specifications) - - return generator -""" - - def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): """ Dummy random level generator: diff --git a/images/__init__.py b/images/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_env_edit.py b/tests/test_env_edit.py index ec9dc07..f0d8629 100644 --- a/tests/test_env_edit.py +++ b/tests/test_env_edit.py @@ -4,7 +4,7 @@ from flatland.envs.rail_env import RailEnv def test_load_env(): env = RailEnv(10, 10) - env.load_resource('env-data.tests', 'test-10x10.mpk') + env.load_resource('env_data.tests', 'test-10x10.mpk') agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False) env.add_agent_static(agent_static) diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index de8985b..cacc2a4 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -8,11 +8,14 @@ import sys import matplotlib.pyplot as plt import numpy as np +from importlib_resources import path import flatland.utils.rendertools as rt +import images.test from flatland.envs.generators import empty_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv +import env_data.tests def checkFrozenImage(oRT, sFileImage, resave=False): @@ -25,10 +28,8 @@ def checkFrozenImage(oRT, sFileImage, resave=False): np.savez_compressed(sDirImages + sFileImage, img=img_test) return - # this is now just for convenience - the file is not read back - np.savez_compressed(sDirImages + "test/" + sFileImage, img=img_test) - - np.load(sDirImages + sFileImage) + with path(images.test, sFileImage) as file_in: + np.load(file_in) # TODO fails! # assert (img_test.shape == img_expected.shape) \ # noqa: E800 @@ -43,8 +44,7 @@ def test_render_env(save_new_images=False): number_of_agents=0, obs_builder_object=TreeObsForRailEnv(max_depth=2) ) - sfTestEnv = "env-data/tests/test1.npy" - oEnv.rail.load_transition_map(sfTestEnv) + oEnv.rail.load_transition_map('env_data.tests', "test1.npy") oRT = rt.RenderTool(oEnv, gl="PILSVG") oRT.renderEnv(show=False) -- GitLab