Commit aa91c6ff authored by u214892's avatar u214892
Browse files

make env-data a module named env_data to access it with importlib_resource

parent 16553d4e
Pipeline #1034 failed with stage
in 9 minutes and 24 seconds
......@@ -53,17 +53,9 @@ class Scenario_Generator:
return env
@staticmethod
def load_scenario(resource, package='env-data.railway', number_of_agents=3):
def load_scenario(resource, package='env_data.railway', number_of_agents=3):
env = RailEnv(width=2 * (1 + number_of_agents),
height=1 + number_of_agents)
"""
env = RailEnv(width=20,
height=20,
rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
[filename,
number_of_agents=number_of_agents)
"""
env.load_resource(package, resource)
env.reset(False, False)
......
......@@ -2,7 +2,7 @@ import random
import numpy as np
from flatland.envs.generators import random_rail_generator # , rail_from_list_of_saved_GridTransitionMap_generator
from flatland.envs.generators import random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
......
......@@ -3,6 +3,7 @@ TransitionMap and derived classes.
"""
import numpy as np
from importlib_resources import path
from numpy import array
from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions
......@@ -263,7 +264,7 @@ class GridTransitionMap(TransitionMap):
"""
np.save(filename, self.grid)
def load_transition_map(self, filename, override_gridsize=True):
def load_transition_map(self, package, resource, override_gridsize=True):
"""
Load the transitions grid from `filename' (npy format).
The load function only updates the transitions grid, and possibly width and height, but the object has to be
......@@ -271,8 +272,10 @@ class GridTransitionMap(TransitionMap):
Parameters
----------
filename : string
Name of the file from which to load the transitions grid.
package : string
Name of the package from which to load the transitions grid.
resource : string
Name of the file from which to load the transitions grid within the package.
override_gridsize : bool
If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size
of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if
......@@ -280,7 +283,8 @@ class GridTransitionMap(TransitionMap):
(height,width) )
"""
new_grid = np.load(filename)
with path(package, resource) as file_in:
new_grid = np.load(file_in)
new_height = new_grid.shape[0]
new_width = new_grid.shape[1]
......
......@@ -214,47 +214,6 @@ def rail_from_GridTransitionMap_generator(rail_map):
return generator
def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
"""
Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset.
Parameters
-------
list_of_filenames : list
List of filenames with the saved grids to load.
Returns
-------
function
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_agents, num_resets=0):
t_utils = RailEnvTransitions()
rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False)
if rail_map.grid.dtype == np.uint64:
rail_map.transitions = Grid8Transitions()
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
rail_map,
num_agents)
return rail_map, agents_position, agents_direction, agents_target
return generator
"""
def generate_rail_from_list_of_manual_specifications(list_of_specifications)
def generator(width, height, num_resets=0):
return generate_rail_from_manual_specifications(list_of_specifications)
return generator
"""
def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
"""
Dummy random level generator:
......
......@@ -4,7 +4,7 @@ from flatland.envs.rail_env import RailEnv
def test_load_env():
env = RailEnv(10, 10)
env.load_resource('env-data.tests', 'test-10x10.mpk')
env.load_resource('env_data.tests', 'test-10x10.mpk')
agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False)
env.add_agent_static(agent_static)
......
......@@ -8,11 +8,14 @@ import sys
import matplotlib.pyplot as plt
import numpy as np
from importlib_resources import path
import flatland.utils.rendertools as rt
import images.test
from flatland.envs.generators import empty_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
import env_data.tests
def checkFrozenImage(oRT, sFileImage, resave=False):
......@@ -25,10 +28,8 @@ def checkFrozenImage(oRT, sFileImage, resave=False):
np.savez_compressed(sDirImages + sFileImage, img=img_test)
return
# this is now just for convenience - the file is not read back
np.savez_compressed(sDirImages + "test/" + sFileImage, img=img_test)
np.load(sDirImages + sFileImage)
with path(images.test, sFileImage) as file_in:
np.load(file_in)
# TODO fails!
# assert (img_test.shape == img_expected.shape) \ # noqa: E800
......@@ -43,8 +44,7 @@ def test_render_env(save_new_images=False):
number_of_agents=0,
obs_builder_object=TreeObsForRailEnv(max_depth=2)
)
sfTestEnv = "env-data/tests/test1.npy"
oEnv.rail.load_transition_map(sfTestEnv)
oEnv.rail.load_transition_map('env_data.tests', "test1.npy")
oRT = rt.RenderTool(oEnv, gl="PILSVG")
oRT.renderEnv(show=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment