diff --git a/AUTHORS.rst b/AUTHORS.rst index 39a8017090e88e8a6b52119c3cbdc1f10e09ce41..f7ab2e089c2315142abbdbe52cfe17492218d341 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -9,7 +9,11 @@ Development * G Spigler <giacomo.spigler@gmail.com> -* A Egli <adrian.egli@sbb.ch> +* A Egli <adrian.egli@sbb.ch> + +* E Nygren <erik.nygren@sbb.ch> + +* Ch. Eichenberger <christian.markus.eichenberger@sbb.ch> * Mattias Ljungström diff --git a/MANIFEST.in b/MANIFEST.in index 60c8cd7a4045fd3a36f8eab6809897fd09e67c42..ca50ea340f7f443d230f1f473a34331525fdbef1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,11 +4,12 @@ include HISTORY.rst include LICENSE include README.rst include requirements_dev.txt +include requirements_continuous_integration.txt graft svg -graft env-data +graft env_data recursive-include tests * diff --git a/docs/FAQ.rst b/docs/FAQ.rst index b7055321a710e3b3f47cbf5084f18eee083cf291..909fa2c1a6442168b41f89efa81de477301fb6ed 100644 --- a/docs/FAQ.rst +++ b/docs/FAQ.rst @@ -17,3 +17,19 @@ Frequently Asked Questions (FAQs) export LC_ALL=en_US.utf-8 export LANG=en_US.utf-8 + +- We use `importlib-resources`_ to read from local files. + Sample usages: + + .. code-block:: python + from importlib_resources import path + + with path(package, resource) as file_in: + new_grid = np.load(file_in) + + .. code-block:: python + from importlib_resources import read_binary + load_data = read_binary(package, resource) + self.set_full_state_msg(load_data) + + .. _importlib-resources: https://importlib-resources.readthedocs.io/en/latest/ diff --git a/env_data/__init__.py b/env_data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/env_data/railway/__init__.py b/env_data/railway/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/env-data/railway/complex_scene.pkl b/env_data/railway/complex_scene.pkl similarity index 100% rename from env-data/railway/complex_scene.pkl rename to env_data/railway/complex_scene.pkl diff --git a/env-data/railway/example_flatland_000.pkl b/env_data/railway/example_flatland_000.pkl similarity index 100% rename from env-data/railway/example_flatland_000.pkl rename to env_data/railway/example_flatland_000.pkl diff --git a/env-data/railway/example_flatland_001.pkl b/env_data/railway/example_flatland_001.pkl similarity index 100% rename from env-data/railway/example_flatland_001.pkl rename to env_data/railway/example_flatland_001.pkl diff --git a/env-data/railway/example_network_000.pkl b/env_data/railway/example_network_000.pkl similarity index 100% rename from env-data/railway/example_network_000.pkl rename to env_data/railway/example_network_000.pkl diff --git a/env-data/railway/example_network_001.pkl b/env_data/railway/example_network_001.pkl similarity index 100% rename from env-data/railway/example_network_001.pkl rename to env_data/railway/example_network_001.pkl diff --git a/env-data/railway/example_network_002.pkl b/env_data/railway/example_network_002.pkl similarity index 100% rename from env-data/railway/example_network_002.pkl rename to env_data/railway/example_network_002.pkl diff --git a/env-data/railway/example_network_003.pkl b/env_data/railway/example_network_003.pkl similarity index 100% rename from env-data/railway/example_network_003.pkl rename to env_data/railway/example_network_003.pkl diff --git a/env_data/tests/__init__.py b/env_data/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/env-data/tests/test-10x10.mpk b/env_data/tests/test-10x10.mpk similarity index 100% rename from env-data/tests/test-10x10.mpk rename to env_data/tests/test-10x10.mpk diff --git a/env-data/tests/test1.npy b/env_data/tests/test1.npy similarity index 100% rename from env-data/tests/test1.npy rename to env_data/tests/test1.npy diff --git a/examples/demo.py b/examples/demo.py index 933a9b1f5e00906f46e3800b2014e3d4aba60147..fb7196bdaef41ed36d2e190860678302209fae69 100644 --- a/examples/demo.py +++ b/examples/demo.py @@ -53,22 +53,11 @@ class Scenario_Generator: return env @staticmethod - def load_scenario(filename, number_of_agents=3): + def load_scenario(resource, package='env_data.railway', number_of_agents=3): env = RailEnv(width=2 * (1 + number_of_agents), height=1 + number_of_agents) - - """ - env = RailEnv(width=20, - height=20, - rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( - [filename, - number_of_agents=number_of_agents) - """ - if os.path.exists(filename): - env.load(filename) - env.reset(False, False) - else: - print("File does not exist:", filename, " Working directory: ", os.getcwd()) + env.load_resource(package, resource) + env.reset(False, False) return env @@ -125,55 +114,57 @@ class Demo: self.renderer.close_window() + @staticmethod + def run_generate_random_scenario(): + demo_000 = Demo(Scenario_Generator.generate_random_scenario()) + demo_000.run_demo() + + @staticmethod + def run_generate_complex_scenario(): + demo_001 = Demo(Scenario_Generator.generate_complex_scenario()) + demo_001.run_demo() + + @staticmethod + def run_example_network_000(): + demo_000 = Demo(Scenario_Generator.load_scenario('example_network_000.pkl')) + demo_000.run_demo() + + @staticmethod + def run_example_network_001(): + demo_001 = Demo(Scenario_Generator.load_scenario('example_network_001.pkl')) + demo_001.run_demo() + + @staticmethod + def run_example_network_002(): + demo_002 = Demo(Scenario_Generator.load_scenario('example_network_002.pkl')) + demo_002.run_demo() + + @staticmethod + def run_example_network_003(): + demo_flatland_000 = Demo(Scenario_Generator.load_scenario('example_network_003.pkl')) + demo_flatland_000.renderer.resize() + demo_flatland_000.set_max_framerate(5) + demo_flatland_000.run_demo(30) + + @staticmethod + def run_example_flatland_000(): + demo_flatland_000 = Demo(Scenario_Generator.load_scenario('example_flatland_000.pkl')) + demo_flatland_000.renderer.resize() + demo_flatland_000.run_demo(60) + + @staticmethod + def run_example_flatland_001(): + demo_flatland_000 = Demo(Scenario_Generator.load_scenario('example_flatland_001.pkl')) + demo_flatland_000.renderer.resize() + demo_flatland_000.set_record_frames(os.path.join(__file_dirname__, '..', 'rendering', 'frame_{:04d}.bmp')) + demo_flatland_000.run_demo(60) + + @staticmethod + def run_complex_scene(): + demo_001 = Demo(Scenario_Generator.load_scenario('complex_scene.pkl')) + demo_001.set_record_frames(os.path.join(__file_dirname__, '..', 'rendering', 'frame_{:04d}.bmp')) + demo_001.run_demo(360) + -if False: - demo_000 = Demo(Scenario_Generator.generate_random_scenario()) - demo_000.run_demo() - demo_000 = None - - demo_001 = Demo(Scenario_Generator.generate_complex_scenario()) - demo_001.run_demo() - demo_001 = None - - demo_000 = Demo(Scenario_Generator.load_scenario( - os.path.join(__file_dirname__, '..', 'env-data', 'railway', 'example_network_000.pkl'))) - demo_000.run_demo() - demo_000 = None - - demo_001 = Demo(Scenario_Generator.load_scenario( - os.path.join(__file_dirname__, '..', 'env-data', 'railway', 'example_network_001.pkl'))) - demo_001.run_demo() - demo_001 = None - - demo_002 = Demo(Scenario_Generator.load_scenario( - os.path.join(__file_dirname__, '..', 'env-data', 'railway', 'example_network_002.pkl'))) - demo_002.run_demo() - demo_002 = None - - demo_flatland_000 = Demo( - Scenario_Generator.load_scenario( - os.path.join(__file_dirname__, '..', 'env-data', 'railway', 'example_flatland_000.pkl'))) - demo_flatland_000.renderer.resize() - demo_flatland_000.run_demo(60) - demo_flatland_000 = None - - demo_flatland_000 = Demo( - Scenario_Generator.load_scenario( - os.path.join(__file_dirname__, '..', 'env-data', 'railway', 'example_network_003.pkl'))) - demo_flatland_000.renderer.resize() - demo_flatland_000.set_max_framerate(5) - demo_flatland_000.run_demo(30) - demo_flatland_000 = None - - demo_flatland_000 = Demo( - Scenario_Generator.load_scenario( - os.path.join(__file_dirname__, '..', 'env-data', 'railway', 'example_flatland_001.pkl'))) - demo_flatland_000.renderer.resize() - demo_flatland_000.set_record_frames(os.path.join(__file_dirname__, '..', 'rendering', 'frame_{:04d}.bmp')) - demo_flatland_000.run_demo(60) - demo_flatland_000 = None - -demo_001 = Demo(Scenario_Generator.load_scenario('./env-data/railway/complex_scene.pkl')) -demo_001.set_record_frames('./rendering/frame_{:04d}.bmp') -demo_001.run_demo(360) -demo_001 = None +if __name__ == "__main__": + Demo.run_complex_scene() diff --git a/examples/simple_example_2.py b/examples/simple_example_2.py index 05290f15d5f5ca672321e38560d9871be170610d..1d2c1e6d72be9be840459cfed1faf1dad1d261e9 100644 --- a/examples/simple_example_2.py +++ b/examples/simple_example_2.py @@ -2,7 +2,7 @@ import random import numpy as np -from flatland.envs.generators import random_rail_generator # , rail_from_list_of_saved_GridTransitionMap_generator +from flatland.envs.generators import random_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool diff --git a/flatland/core/env.py b/flatland/core/env.py index 3618d965a39b5a71fd1cf24aa81f2f876d5c6365..1bc5b6f3eba4ee4713bd3c8d6b88440006c215a5 100644 --- a/flatland/core/env.py +++ b/flatland/core/env.py @@ -84,21 +84,6 @@ class Environment: """ raise NotImplementedError() - def predict(self): - """ - Predictions step. - - Returns predictions for the agents. - The returns are dicts mapping from agent_id strings to values. - - Returns - ------- - predictions : dict - New predictions for each ready agent. - - """ - raise NotImplementedError() - def get_agent_handles(self): """ Returns a list of agents' handles to be used as keys in the step() diff --git a/flatland/core/env_prediction_builder.py b/flatland/core/env_prediction_builder.py index 9f5e4dc5033ba3a789313b47d94b033251cb8276..060dbfc38ec6b035f3264db7ef394545e54387f6 100644 --- a/flatland/core/env_prediction_builder.py +++ b/flatland/core/env_prediction_builder.py @@ -29,7 +29,7 @@ class PredictionBuilder: def get(self, handle=0): """ - Called whenever step_prediction is called on the environment. + Called whenever predict is called on the environment. Parameters ------- diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 271ced516c89a2f7e541aef5a11f10f36159ccf8..43b9a72ae7906c8a7dce58a2b58155a99ff760ae 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -3,6 +3,7 @@ TransitionMap and derived classes. """ import numpy as np +from importlib_resources import path from numpy import array from .transitions import Grid4Transitions, Grid8Transitions, RailEnvTransitions @@ -263,7 +264,7 @@ class GridTransitionMap(TransitionMap): """ np.save(filename, self.grid) - def load_transition_map(self, filename, override_gridsize=True): + def load_transition_map(self, package, resource, override_gridsize=True): """ Load the transitions grid from `filename' (npy format). The load function only updates the transitions grid, and possibly width and height, but the object has to be @@ -271,8 +272,10 @@ class GridTransitionMap(TransitionMap): Parameters ---------- - filename : string - Name of the file from which to load the transitions grid. + package : string + Name of the package from which to load the transitions grid. + resource : string + Name of the file from which to load the transitions grid within the package. override_gridsize : bool If override_gridsize=True, the width and height of the GridTransitionMap object are replaced with the size of the map loaded from `filename'. If override_gridsize=False, the transitions grid is either cropped (if @@ -280,7 +283,8 @@ class GridTransitionMap(TransitionMap): (height,width) ) """ - new_grid = np.load(filename) + with path(package, resource) as file_in: + new_grid = np.load(file_in) new_height = new_grid.shape[0] new_width = new_grid.shape[1] diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index c4cf908b7718b552d23d7415dd5f0a5a12604e7e..f644bc120d4b514f1c54e0330cfc8dc4654a4f4e 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -1,7 +1,7 @@ import numpy as np from flatland.core.transition_map import GridTransitionMap -from flatland.core.transitions import Grid8Transitions, RailEnvTransitions +from flatland.core.transitions import RailEnvTransitions from flatland.envs.env_utils import distance_on_rail, connect_rail, get_direction, mirror from flatland.envs.env_utils import get_rnd_agents_pos_tgt_dir_on_rail @@ -214,47 +214,6 @@ def rail_from_GridTransitionMap_generator(rail_map): return generator -def rail_from_list_of_saved_GridTransitionMap_generator(list_of_filenames): - """ - Utility to sequentially and cyclically return GridTransitionMap-s from a list of files, on each environment reset. - - Parameters - ------- - list_of_filenames : list - List of filenames with the saved grids to load. - - Returns - ------- - function - Generator function that always returns the given `rail_map' object. - """ - - def generator(width, height, num_agents, num_resets=0): - t_utils = RailEnvTransitions() - rail_map = GridTransitionMap(width=width, height=height, transitions=t_utils) - rail_map.load_transition_map(list_of_filenames[num_resets % len(list_of_filenames)], override_gridsize=False) - - if rail_map.grid.dtype == np.uint64: - rail_map.transitions = Grid8Transitions() - - agents_position, agents_direction, agents_target = get_rnd_agents_pos_tgt_dir_on_rail( - rail_map, - num_agents) - - return rail_map, agents_position, agents_direction, agents_target - - return generator - - -""" -def generate_rail_from_list_of_manual_specifications(list_of_specifications) - def generator(width, height, num_resets=0): - return generate_rail_from_manual_specifications(list_of_specifications) - - return generator -""" - - def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): """ Dummy random level generator: diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index a3d88d773db9edaa0777e2aee94593a0392a956c..541f8ad592d1481afb8eb6da2eb7b887aacae419 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -31,7 +31,6 @@ class TreeObsForRailEnv(ObservationBuilder): self.location_has_agent = {} self.location_has_agent_direction = {} self.predictor = predictor - self.agents_previous_reset = None def reset(self): @@ -174,9 +173,10 @@ class TreeObsForRailEnv(ObservationBuilder): in the `handles' list. """ - # TODO: @Erik this is where the predictions should be computed, storing any temporary data inside this object. + self.predictions = [] if self.predictor: - print(self.predictor.get(0)) + for a in range(len(handles)): + self.predictions.append(self.predictor.get(a)) observations = {} for h in handles: observations[h] = self.get(h) @@ -222,6 +222,8 @@ class TreeObsForRailEnv(ObservationBuilder): (possible future use: number of other agents in other direction in this branch, ie. number of conflicts) 0 = no agent present other direction than myself + #8: possible conflict detected + Missing/padding nodes are filled in with -inf (truncated). Missing values in present node are filled in with +inf (truncated). @@ -256,7 +258,6 @@ class TreeObsForRailEnv(ObservationBuilder): for branch_direction in [(orientation + i) % 4 for i in range(-1, 3)]: if possible_transitions[branch_direction]: new_cell = self._new_position(agent.position, branch_direction) - branch_observation, branch_visited = \ self._explore_branch(handle, new_cell, branch_direction, root_observation, 1) observation = observation + branch_observation diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 3338e68126ee9a198e0208e17b1986f1c9fde6c8..f7dec074e6dbf43694ddf5284675cebb10ea5b59 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -18,7 +18,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): def get(self, handle=None): """ - Called whenever step_prediction is called on the environment. + Called whenever predict is called on the environment. Parameters ------- @@ -45,10 +45,16 @@ class DummyPredictorForRailEnv(PredictionBuilder): action_priorities = [RailEnvActions.MOVE_FORWARD, RailEnvActions.MOVE_LEFT, RailEnvActions.MOVE_RIGHT] _agent_initial_position = agent.position _agent_initial_direction = agent.direction - prediction = np.zeros(shape=(self.max_depth, 5)) + prediction = np.zeros(shape=(self.max_depth + 1, 5)) prediction[0] = [0, _agent_initial_position[0], _agent_initial_position[1], _agent_initial_direction, 0] - for index in range(1, self.max_depth): + for index in range(1, self.max_depth + 1): action_done = False + # if we're at the target, stop moving... + if agent.position == agent.target: + prediction[index] = [index, agent.target[0], agent.target[1], agent.direction, + RailEnvActions.STOP_MOVING] + + continue for action in action_priorities: cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ self.env._check_action_on_agent(action, agent) @@ -61,7 +67,7 @@ class DummyPredictorForRailEnv(PredictionBuilder): action_done = True break if not action_done: - print("Cannot move further.") + raise Exception("Cannot move further. Something is wrong") prediction_dict[agent.handle] = prediction agent.position = _agent_initial_position agent.direction = _agent_initial_direction diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 5d20a5d9f38230f353b0a9616c49ede333206c49..c22e1c5120b54a170f9c59bb54c7666ca910f086 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -292,7 +292,6 @@ class RailEnv(Environment): np.equal(new_position, [agent2.position for agent2 in self.agents]).all(1)) return cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid - def check_action(self, agent, action): transition_isValid = None possible_transitions = self.rail.get_transitions((*agent.position, agent.direction)) @@ -324,7 +323,6 @@ class RailEnv(Environment): self.obs_dict = self.obs_builder.get_many(list(range(self.get_num_agents()))) return self.obs_dict - def get_full_state_msg(self): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] @@ -366,3 +364,8 @@ class RailEnv(Environment): with open(filename, "rb") as file_in: load_data = file_in.read() self.set_full_state_msg(load_data) + + def load_resource(self, package, resource): + from importlib_resources import read_binary + load_data = read_binary(package, resource) + self.set_full_state_msg(load_data) diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index d4e5c38e975fbbfd5d357b78d1cd868ac828701e..81565d62d3ab17e740a5e1b635a97fd97d01980f 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -323,7 +323,8 @@ class Controller(object): def restartAgents(self, event): self.log("Restart Agents - nAgents:", self.view.wRegenNAgents.value) if self.model.init_agents_static is not None: - self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in self.model.init_agents_static] + self.model.env.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in + self.model.init_agents_static] self.model.env.agents = None self.model.init_agents_static = None self.player = None diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index 382895317fc56095e4443b220a456dd1e92681e7..94dffa4d13c29a4e235649c85e8c5f2729754a9a 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -1,12 +1,12 @@ import io import os -import site import time import tkinter as tk import numpy as np from PIL import Image, ImageDraw, ImageTk # , ImageFont from numpy import array +from pkg_resources import resource_string as resource_bytes from flatland.utils.graphics_layer import GraphicsLayer @@ -258,18 +258,9 @@ class PILSVG(PILGL): self.lwAgents = [] self.agents_prev = [] - def pilFromSvgFile(self, sfPath): - try: - with open(sfPath, "r") as fIn: - bytesPNG = svg2png(file_obj=fIn, output_height=self.nPixCell, output_width=self.nPixCell) - except: # noqa: E722 - newList = '' - for directory in site.getsitepackages(): - x = [word for word in os.listdir(directory) if word.startswith('flatland')] - if len(x) > 0: - newList = directory + '/' + x[0] - with open(newList + '/' + sfPath, "r") as fIn: - bytesPNG = svg2png(file_obj=fIn, output_height=self.nPixCell, output_width=self.nPixCell) + def pilFromSvgFile(self, package, resource): + bytestring = resource_bytes(package, resource) + bytesPNG = svg2png(bytestring=bytestring, output_height=self.nPixCell, output_width=self.nPixCell) with io.BytesIO(bytesPNG) as fIn: pil_img = Image.open(fIn) pil_img.load() @@ -382,10 +373,7 @@ class PILSVG(PILGL): lDirs = list("NESW") - # svgBG = SVG("./svg/Background_#91D1DD.svg") - for sTrans, sFile in dDirFile.items(): - sPathSvg = "./svg/" + sFile # Translate the ascii transition description in the format "NE WS" to the # binary list of transitions as per RailEnv - NESW (in) x NESW (out) @@ -399,7 +387,7 @@ class PILSVG(PILGL): sTrans16 = "".join(lTrans16) binTrans = int(sTrans16, 2) - pilRail = self.pilFromSvgFile(sPathSvg) + pilRail = self.pilFromSvgFile('svg', sFile) if rotate: # For rotations, we also store the base image @@ -447,7 +435,7 @@ class PILSVG(PILGL): print("Illegal target rail:", row, col, format(binTrans, "#018b")[2:]) if isSelected: - svgBG = self.pilFromSvgFile("./svg/Selected_Target.svg") + svgBG = self.pilFromSvgFile("svg", "Selected_Target.svg") self.clear_layer(3, 0) self.drawImageRC(svgBG, (row, col), layer=3) @@ -470,13 +458,13 @@ class PILSVG(PILGL): # Seed initial train/zug files indexed by tuple(iDirIn, iDirOut): dDirsFile = { - (0, 0): "svg/Zug_Gleis_#0091ea.svg", - (1, 2): "svg/Zug_1_Weiche_#0091ea.svg", - (0, 3): "svg/Zug_2_Weiche_#0091ea.svg" + (0, 0): "Zug_Gleis_#0091ea.svg", + (1, 2): "Zug_1_Weiche_#0091ea.svg", + (0, 3): "Zug_2_Weiche_#0091ea.svg" } # "paint" color of the train images we load - this is the color we will change. - # a3BaseColor = self.rgb_s2i("0091ea") + # a3BaseColor = self.rgb_s2i("0091ea") \# noqa: E800 # temporary workaround for trains / agents renamed with different colour: a3BaseColor = self.rgb_s2i("d50000") @@ -485,7 +473,7 @@ class PILSVG(PILGL): for tDirs, sPathSvg in dDirsFile.items(): iDirIn, iDirOut = tDirs - pilZug = self.pilFromSvgFile(sPathSvg) + pilZug = self.pilFromSvgFile("svg", sPathSvg) # Rotate both the directions and the image and save in the dict for iDirRot in range(4): @@ -511,7 +499,7 @@ class PILSVG(PILGL): self.drawImageRC(pilZug, (row, col), layer=1) if isSelected: - svgBG = self.pilFromSvgFile("./svg/Selected_Agent.svg") + svgBG = self.pilFromSvgFile("svg", "Selected_Agent.svg") self.clear_layer(2, 0) self.drawImageRC(svgBG, (row, col), layer=2) diff --git a/images/__init__.py b/images/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/requirements_dev.txt b/requirements_dev.txt index 28ed09e6321ec0f9711d8784b97f06bfe87fb2b5..146631ced56d34341354f22df19e89760ce54788 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -4,7 +4,7 @@ twine==1.12.1 pytest==3.8.2 pytest-runner==4.2 pytest-xvfb==1.2.0 -numpy==1.16.2 +numpy==1.16.4 recordtype==1.3 xarray==0.11.3 matplotlib==3.0.2 @@ -15,3 +15,5 @@ msgpack==0.6.1 svgutils==0.3.1 screeninfo==0.3.1 pyarrow==0.13.0 +importlib-metadata==0.17 +importlib_resources==1.0.2 diff --git a/setup.py b/setup.py index d517c279cb5d31a09174b71419b3615160abc39a..39bffd173a000a1fa40018cc5b856c9a26e2e253 100644 --- a/setup.py +++ b/setup.py @@ -63,12 +63,12 @@ else: def get_all_svg_files(directory='./svg/'): ret = [] for f in os.listdir(directory): - ret.append(directory + f) + if f != '__pycache__': + ret.append(directory + f) return ret # Gather requirements from requirements_dev.txt -# TODO : We could potentially split up the test/dev dependencies later install_reqs = [] requirements_path = 'requirements_dev.txt' with open(requirements_path, 'r') as f: diff --git a/svg/__init__.py b/svg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_env_edit.py b/tests/test_env_edit.py index 0707cf3ce51c77d9918e25e5b3858a46574b321e..f0d86292ce926147017bab5e4d777019bbcfb143 100644 --- a/tests/test_env_edit.py +++ b/tests/test_env_edit.py @@ -4,7 +4,7 @@ from flatland.envs.rail_env import RailEnv def test_load_env(): env = RailEnv(10, 10) - env.load("env-data/tests/test-10x10.mpk") + env.load_resource('env_data.tests', 'test-10x10.mpk') agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False) env.add_agent_static(agent_static) diff --git a/tests/test_env_prediction_builder.py b/tests/test_env_prediction_builder.py index 35a6a27b970ce54e1cabd3cf8c80d30a34800a25..cb7d26df75017b8a531ff6552eb6c1651c80a08f 100644 --- a/tests/test_env_prediction_builder.py +++ b/tests/test_env_prediction_builder.py @@ -5,7 +5,7 @@ import numpy as np from flatland.core.transition_map import GridTransitionMap, Grid4Transitions from flatland.envs.generators import rail_from_GridTransitionMap_generator -from flatland.envs.observations import GlobalObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import DummyPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -64,8 +64,7 @@ def test_predictions(): height=rail_map.shape[0], rail_generator=rail_from_GridTransitionMap_generator(rail), number_of_agents=1, - obs_builder_object=GlobalObsForRailEnv(), - prediction_builder_object=DummyPredictorForRailEnv(max_depth=20) + obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=DummyPredictorForRailEnv(max_depth=10)), ) env.reset() @@ -73,8 +72,9 @@ def test_predictions(): # set initial position and direction for testing... env.agents[0].position = (5, 6) env.agents[0].direction = 0 + env.agents[0].target = (3., 0.) - predictions = env.predict() + predictions = env.obs_builder.predictor.get() positions = np.array(list(map(lambda prediction: [prediction[1], prediction[2]], predictions[0]))) directions = np.array(list(map(lambda prediction: [prediction[3]], predictions[0]))) time_offsets = np.array(list(map(lambda prediction: [prediction[0]], predictions[0]))) @@ -89,18 +89,11 @@ def test_predictions(): [3., 3.], [3., 2.], [3., 1.], + # at target (3,0): stay in this position from here on [3., 0.], - [3., 1.], - [3., 2.], - [3., 3.], - [3., 4.], - [3., 5.], - [3., 6.], - [3., 7.], - [3., 8.], - [3., 9.], - [3., 8.], - [3., 7.]]) + [3., 0.], + [3., 0.], + ]) expected_directions = np.array([[0.], [0.], [0.], @@ -109,18 +102,11 @@ def test_predictions(): [3.], [3.], [3.], + # at target (3,0): stay in this position from here on [3.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], [3.], - [3.]]) + [3.] + ]) expected_time_offsets = np.array([[0.], [1.], [2.], @@ -132,15 +118,7 @@ def test_predictions(): [8.], [9.], [10.], - [11.], - [12.], - [13.], - [14.], - [15.], - [16.], - [17.], - [18.], - [19.]]) + ]) expected_actions = np.array([[0.], [2.], [2.], @@ -149,18 +127,12 @@ def test_predictions(): [2.], [2.], [2.], + # reaching target by straight [2.], - [2.], - [2.], - [2.], - [2.], - [2.], - [2.], - [2.], - [2.], - [2.], - [2.], - [2.]]) + # at target: stopped moving + [4.], + [4.], + ]) assert np.array_equal(positions, expected_positions) assert np.array_equal(directions, expected_directions) assert np.array_equal(time_offsets, expected_time_offsets) diff --git a/tests/test_integration_test.py b/tests/test_integration_test.py index 8b6db60f991a523a064b8b5f73e07fbda1d768cd..808af10157298cc3630e7ba7980481b0eb17e833 100644 --- a/tests/test_integration_test.py +++ b/tests/test_integration_test.py @@ -1,138 +1,49 @@ -import os import random -import time import numpy as np -from flatland.envs.generators import complex_rail_generator -from flatland.envs.generators import random_rail_generator -from flatland.envs.rail_env import RailEnv -from flatland.utils.rendertools import RenderTool +from examples.demo import Demo # ensure that every demo run behave constantly equal random.seed(1) np.random.seed(1) -class Scenario_Generator: - @staticmethod - def generate_random_scenario(number_of_agents=3): - # Example generate a rail given a manual specification, - # a map of tuples (cell_type, rotation) - transition_probability = [15, # empty cell - Case 0 - 5, # Case 1 - straight - 5, # Case 2 - simple switch - 1, # Case 3 - diamond crossing - 1, # Case 4 - single slip - 1, # Case 5 - double slip - 1, # Case 6 - symmetrical - 0, # Case 7 - dead end - 1, # Case 1b (8) - simple turn right - 1, # Case 1c (9) - simple turn left - 1] # Case 2b (10) - simple switch mirrored - - # Example generate a random rail - - env = RailEnv(width=20, - height=20, - rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), - number_of_agents=number_of_agents) - - return env - - @staticmethod - def generate_complex_scenario(number_of_agents=3): - env = RailEnv(width=15, - height=15, - rail_generator=complex_rail_generator(nr_start_goal=6, nr_extra=30, min_dist=10, - max_dist=99999, seed=0), - number_of_agents=number_of_agents) - - return env - - @staticmethod - def load_scenario(filename, number_of_agents=3): - env = RailEnv(width=2 * (1 + number_of_agents), - height=1 + number_of_agents) - - """ - env = RailEnv(width=20, - height=20, - rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( - [filename]), - number_of_agents=number_of_agents) - """ - if os.path.exists(filename): - print("load file: ", filename) - env.load(filename) - env.reset(False, False) - else: - print("File does not exist:", filename, " Working directory: ", os.getcwd()) - - return env - - -class Demo: - - def __init__(self, env): - self.env = env - self.create_renderer() - self.action_size = 4 - self.max_frame_rate = 60 - self.record_frames = None - - def set_record_frames(self, record_frames): - self.record_frames = record_frames - - def create_renderer(self): - self.renderer = RenderTool(self.env, gl="PILSVG") - handle = self.env.get_agent_handles() - return handle - - def set_max_framerate(self, max_frame_rate): - self.max_frame_rate = max_frame_rate +def test_flatland_000(): + Demo.run_example_flatland_000() + # TODO test assertions - def run_demo(self, max_nbr_of_steps=30): - action_dict = dict() - # Reset environment - _ = self.env.reset(False, False) +def test_flatland_001(): + Demo.run_example_flatland_001() + # TODO test assertions - time.sleep(0.0001) # to satisfy lint... - for step in range(max_nbr_of_steps): +def test_network_000(): + Demo.run_example_network_000() + # TODO test assertions - # Action - for iAgent in range(self.env.get_num_agents()): - # allways walk straight forward - action = 2 - # update the actions - action_dict.update({iAgent: action}) +def test_network_001(): + Demo.run_example_network_001() + # TODO test assertions - # environment step (apply the actions to all agents) - next_obs, all_rewards, done, _ = self.env.step(action_dict) - # render - self.renderer.renderEnv(show=True, show_observations=False) - - if done['__all__']: - break +def test_network_002(): + Demo.run_example_network_002() + # TODO test assertions - if self.record_frames is not None: - self.renderer.gl.saveImage(self.record_frames.format(step)) - self.renderer.close_window() +def test_complex_scene(): + Demo.run_complex_scene() + # TODO test assertions -def test_temp_pk1(): - demo_001 = Demo(Scenario_Generator.load_scenario('./env-data/railway/temp.pkl')) - demo_001.run_demo(10) +def test_generate_complex_scenario(): + Demo.run_generate_complex_scenario() # TODO test assertions -def test_flatland_001_pkl(): - demo_001 = Demo(Scenario_Generator.load_scenario('./env-data/railway/example_flatland_001.pkl')) - demo_001.set_record_frames('./rendering/frame_{:04d}.bmp') - demo_001.run_demo(60) +def test_generate_random_scenario(): + Demo.run_generate_random_scenario() # TODO test assertions diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index de8985b678208ea3a09abd63ecab9d2419be20c4..14edfee709ce2c3a9110a5af34293d6b5c8d01f4 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -8,8 +8,10 @@ import sys import matplotlib.pyplot as plt import numpy as np +from importlib_resources import path import flatland.utils.rendertools as rt +import images.test from flatland.envs.generators import empty_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv @@ -25,10 +27,8 @@ def checkFrozenImage(oRT, sFileImage, resave=False): np.savez_compressed(sDirImages + sFileImage, img=img_test) return - # this is now just for convenience - the file is not read back - np.savez_compressed(sDirImages + "test/" + sFileImage, img=img_test) - - np.load(sDirImages + sFileImage) + with path(images, sFileImage) as file_in: + np.load(file_in) # TODO fails! # assert (img_test.shape == img_expected.shape) \ # noqa: E800 @@ -43,8 +43,7 @@ def test_render_env(save_new_images=False): number_of_agents=0, obs_builder_object=TreeObsForRailEnv(max_depth=2) ) - sfTestEnv = "env-data/tests/test1.npy" - oEnv.rail.load_transition_map(sfTestEnv) + oEnv.rail.load_transition_map('env_data.tests', "test1.npy") oRT = rt.RenderTool(oEnv, gl="PILSVG") oRT.renderEnv(show=False)