Skip to content
Snippets Groups Projects
Commit b08c9aae authored by Christian Eichenberger's avatar Christian Eichenberger :badminton:
Browse files

Merge branch '57-access-resources-through-importlib_resources' into 'master'

Resolve "Access resources using pkg_resources and/or importlib_resources"

Closes #57

See merge request flatland/flatland!57
parents f9dc9752 8d74cf3b
No related branches found
No related tags found
No related merge requests found
Showing
with 90 additions and 115 deletions
...@@ -9,7 +9,11 @@ Development ...@@ -9,7 +9,11 @@ Development
* G Spigler <giacomo.spigler@gmail.com> * G Spigler <giacomo.spigler@gmail.com>
* A Egli <adrian.egli@sbb.ch> * A Egli <adrian.egli@sbb.ch>
* E Nygren <erik.nygren@sbb.ch>
* Ch. Eichenberger <christian.markus.eichenberger@sbb.ch>
* Mattias Ljungström * Mattias Ljungström
......
...@@ -4,11 +4,12 @@ include HISTORY.rst ...@@ -4,11 +4,12 @@ include HISTORY.rst
include LICENSE include LICENSE
include README.rst include README.rst
include requirements_dev.txt include requirements_dev.txt
include requirements_continuous_integration.txt
graft svg graft svg
graft env-data graft env_data
recursive-include tests * recursive-include tests *
......
...@@ -17,3 +17,19 @@ Frequently Asked Questions (FAQs) ...@@ -17,3 +17,19 @@ Frequently Asked Questions (FAQs)
export LC_ALL=en_US.utf-8 export LC_ALL=en_US.utf-8
export LANG=en_US.utf-8 export LANG=en_US.utf-8
- We use `importlib-resources`_ to read from local files.
Sample usages:
.. code-block:: python
from importlib_resources import path
with path(package, resource) as file_in:
new_grid = np.load(file_in)
.. code-block:: python
from importlib_resources import read_binary
load_data = read_binary(package, resource)
self.set_full_state_msg(load_data)
.. _importlib-resources: https://importlib-resources.readthedocs.io/en/latest/
File moved
File moved
...@@ -53,22 +53,11 @@ class Scenario_Generator: ...@@ -53,22 +53,11 @@ class Scenario_Generator:
return env return env
@staticmethod @staticmethod
def load_scenario(filename, number_of_agents=3): def load_scenario(resource, package='env_data.railway', number_of_agents=3):
env = RailEnv(width=2 * (1 + number_of_agents), env = RailEnv(width=2 * (1 + number_of_agents),
height=1 + number_of_agents) height=1 + number_of_agents)
env.load_resource(package, resource)
""" env.reset(False, False)
env = RailEnv(width=20,
height=20,
rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
[filename,
number_of_agents=number_of_agents)
"""
if os.path.exists(filename):
env.load(filename)
env.reset(False, False)
else:
print("File does not exist:", filename, " Working directory: ", os.getcwd())
return env return env
...@@ -125,55 +114,57 @@ class Demo: ...@@ -125,55 +114,57 @@ class Demo:
self.renderer.close_window() self.renderer.close_window()
@staticmethod
def run_generate_random_scenario():
demo_000 = Demo(Scenario_Generator.generate_random_scenario())
demo_000.run_demo()
@staticmethod
def run_generate_complex_scenario():
demo_001 = Demo(Scenario_Generator.generate_complex_scenario())
demo_001.run_demo()
@staticmethod
def run_example_network_000():
demo_000 = Demo(Scenario_Generator.load_scenario('example_network_000.pkl'))
demo_000.run_demo()
@staticmethod
def run_example_network_001():
demo_001 = Demo(Scenario_Generator.load_scenario('example_network_001.pkl'))
demo_001.run_demo()
@staticmethod
def run_example_network_002():
demo_002 = Demo(Scenario_Generator.load_scenario('example_network_002.pkl'))
demo_002.run_demo()
@staticmethod
def run_example_network_003():
demo_flatland_000 = Demo(Scenario_Generator.load_scenario('example_network_003.pkl'))
demo_flatland_000.renderer.resize()
demo_flatland_000.set_max_framerate(5)
demo_flatland_000.run_demo(30)
@staticmethod
def run_example_flatland_000():
demo_flatland_000 = Demo(Scenario_Generator.load_scenario('example_flatland_000.pkl'))
demo_flatland_000.renderer.resize()
demo_flatland_000.run_demo(60)
@staticmethod
def run_example_flatland_001():
demo_flatland_000 = Demo(Scenario_Generator.load_scenario('example_flatland_001.pkl'))
demo_flatland_000.renderer.resize()
demo_flatland_000.set_record_frames(os.path.join(__file_dirname__, '..', 'rendering', 'frame_{:04d}.bmp'))
demo_flatland_000.run_demo(60)
@staticmethod
def run_complex_scene():
demo_001 = Demo(Scenario_Generator.load_scenario('complex_scene.pkl'))
demo_001.set_record_frames(os.path.join(__file_dirname__, '..', 'rendering', 'frame_{:04d}.bmp'))
demo_001.run_demo(360)
if False: if __name__ == "__main__":
demo_000 = Demo(Scenario_Generator.generate_random_scenario()) Demo.run_complex_scene()
demo_000.run_demo()
demo_000 = None
demo_001 = Demo(Scenario_Generator.generate_complex_scenario())
demo_001.run_demo()
demo_001 = None
demo_000 = Demo(Scenario_Generator.load_scenario(
os.path.join(__file_dirname__, '..', 'env-data', 'railway', 'example_network_000.pkl')))
demo_000.run_demo()
demo_000 = None
demo_001 = Demo(Scenario_Generator.load_scenario(
os.path.join(__file_dirname__, '..', 'env-data', 'railway', 'example_network_001.pkl')))
demo_001.run_demo()
demo_001 = None
demo_002 = Demo(Scenario_Generator.load_scenario(
os.path.join(__file_dirname__, '..', 'env-data', 'railway', 'example_network_002.pkl')))
demo_002.run_demo()
demo_002 = None
demo_flatland_000 = Demo(
Scenario_Generator.load_scenario(
os.path.join(__file_dirname__, '..', 'env-data', 'railway', 'example_flatland_000.pkl')))
demo_flatland_000.renderer.resize()
demo_flatland_000.run_demo(60)
demo_flatland_000 = None
demo_flatland_000 = Demo(
Scenario_Generator.load_scenario(
os.path.join(__file_dirname__, '..', 'env-data', 'railway', 'example_network_003.pkl')))
demo_flatland_000.renderer.resize()
demo_flatland_000.set_max_framerate(5)
demo_flatland_000.run_demo(30)
demo_flatland_000 = None
demo_flatland_000 = Demo(
Scenario_Generator.load_scenario(
os.path.join(__file_dirname__, '..', 'env-data', 'railway', 'example_flatland_001.pkl')))
demo_flatland_000.renderer.resize()
demo_flatland_000.set_record_frames(os.path.join(__file_dirname__, '..', 'rendering', 'frame_{:04d}.bmp'))
demo_flatland_000.run_demo(60)
demo_flatland_000 = None
demo_001 = Demo(Scenario_Generator.load_scenario('./env-data/railway/complex_scene.pkl'))
demo_001.set_record_frames('./rendering/frame_{:04d}.bmp')
demo_001.run_demo(360)
demo_001 = None
...@@ -2,7 +2,7 @@ import random ...@@ -2,7 +2,7 @@ import random
import numpy as np import numpy as np
from flatland.envs.generators import random_rail_generator # , rail_from_list_of_saved_GridTransitionMap_generator from flatland.envs.generators import random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool from flatland.utils.rendertools import RenderTool
......
...@@ -29,7 +29,7 @@ class PredictionBuilder: ...@@ -29,7 +29,7 @@ class PredictionBuilder:
def get(self, handle=0): def get(self, handle=0):
""" """
Called whenever step_prediction is called on the environment. Called whenever predict is called on the environment.
Parameters Parameters
------- -------
......
...@@ -3,6 +3,7 @@ TransitionMap and derived classes. ...@@ -3,6 +3,7 @@ TransitionMap and derived classes.
""" """
import numpy as np import numpy as np
from importlib_resources import path
from numpy import array from numpy import array
from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions
...@@ -263,7 +264,7 @@ class GridTransitionMap(TransitionMap): ...@@ -263,7 +264,7 @@ class GridTransitionMap(TransitionMap):
""" """
np.save(filename, self.grid) np.save(filename, self.grid)
def load_transition_map(self, filename, override_gridsize=True): def load_transition_map(self, package, resource, override_gridsize=True):
""" """
Load the transitions grid from `filename' (npy format). Load the transitions grid from `filename' (npy format).
The load function only updates the transitions grid, and possibly width and height, but the object has to be The load function only updates the transitions grid, and possibly width and height, but the object has to be
...@@ -271,8 +272,10 @@ class GridTransitionMap(TransitionMap): ...@@ -271,8 +272,10 @@ class GridTransitionMap(TransitionMap):
Parameters Parameters
---------- ----------
filename : string package : string
Name of the file from which to load the transitions grid. Name of the package from which to load the transitions grid.
resource : string
Name of the file from which to load the transitions grid within the package.
override_gridsize : bool override_gridsize : bool
If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size
of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if
...@@ -280,7 +283,8 @@ class GridTransitionMap(TransitionMap): ...@@ -280,7 +283,8 @@ class GridTransitionMap(TransitionMap):
(height,width) ) (height,width) )
""" """
new_grid = np.load(filename) with path(package, resource) as file_in:
new_grid = np.load(file_in)
new_height = new_grid.shape[0] new_height = new_grid.shape[0]
new_width = new_grid.shape[1] new_width = new_grid.shape[1]
......
import numpy as np import numpy as np
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
from flatland.core.transitions import Grid8Transitions, RailEnvTransitions from flatland.core.transitions import RailEnvTransitions
from flatland.envs.env_utils import distance_on_rail, connect_rail, get_direction, mirror from flatland.envs.env_utils import distance_on_rail, connect_rail, get_direction, mirror
from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail
...@@ -214,47 +214,6 @@ def rail_from_GridTransitionMap_generator(rail_map): ...@@ -214,47 +214,6 @@ def rail_from_GridTransitionMap_generator(rail_map):
return generator return generator
def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames):
"""
Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset.
Parameters
-------
list_of_filenames : list
List of filenames with the saved grids to load.
Returns
-------
function
Generator function that always returns the given `rail_map' object.
"""
def generator(width, height, num_agents, num_resets=0):
t_utils = RailEnvTransitions()
rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils)
rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False)
if rail_map.grid.dtype == np.uint64:
rail_map.transitions = Grid8Transitions()
agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail(
rail_map,
num_agents)
return rail_map, agents_position, agents_direction, agents_target
return generator
"""
def generate_rail_from_list_of_manual_specifications(list_of_specifications)
def generator(width, height, num_resets=0):
return generate_rail_from_manual_specifications(list_of_specifications)
return generator
"""
def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): def random_rail_generator(cell_type_relative_proportion=[1.0] * 11):
""" """
Dummy random level generator: Dummy random level generator:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment