From faa4ace663a6c2ef6c8491c0d2c8a07ca242105c Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 4 Jun 2019 15:07:35 +0200 Subject: [PATCH] cleanup and integration tests --- flatland/envs/predictions.py | 4 +- flatland/envs/rail_env.py | 8 +- flatland/utils/graphics_pil.py | 4 +- setup.py | 13 ++-- tests/test_integration_test.py | 138 +++++++++++++++++++++++++++++++++ 5 files changed, 152 insertions(+), 15 deletions(-) create mode 100644 tests/test_integration_test.py diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py index 95c1a984..0420ab70 100644 --- a/flatland/envs/predictions.py +++ b/flatland/envs/predictions.py @@ -54,8 +54,8 @@ class DummyPredictorForRailEnv(PredictionBuilder): for index in range(1, self.max_depth): action_done = False for action in action_priorities: - cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self.env._check_action_on_agent(action, - agent) + cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ + self.env._check_action_on_agent(action, agent) if all([new_cell_isValid, transition_isValid]): # move and change direction to face the new_direction that was # performed diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 82d694cc..da389d09 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -100,7 +100,6 @@ class RailEnv(Environment): if self.prediction_builder: self.prediction_builder._set_env(self) - self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? @@ -219,8 +218,8 @@ class RailEnv(Environment): return if action > 0: - cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self._check_action_on_agent(action, - agent) + cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \ + self._check_action_on_agent(action, agent) if all([new_cell_isValid, transition_isValid, cell_isFree]): # move and change direction to face the new_direction that was # performed @@ -302,8 +301,7 @@ class RailEnv(Environment): def predict(self): if not self.prediction_builder: return {} - return self.prediction_builder.get() - + return self.prediction_builder.get() def check_action(self, agent, action): transition_isValid = None diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py index b738e5e1..ba322383 100644 --- a/flatland/utils/graphics_pil.py +++ b/flatland/utils/graphics_pil.py @@ -395,8 +395,8 @@ class PILSVG(PILGL): if isSelected: svgBG = self.pilFromSvgFile("./svg/Selected_Target.svg") - self.clear_layer(3,0) - self.drawImageRC(svgBG,(row,col),layer=3) + self.clear_layer(3, 0) + self.drawImageRC(svgBG, (row, col), layer=3) def recolorImage(self, pil, a3BaseColor, ltColors): rgbaImg = array(pil) diff --git a/setup.py b/setup.py index 7ed4339e..ce7b232f 100644 --- a/setup.py +++ b/setup.py @@ -27,21 +27,22 @@ if os.name == 'nt': is64bit = p[0] == '64bit' if sys.version[0:3] == '3.5': if is64bit: - url = 'https://download.lfd.uci.edu/pythonlibs/q5gtlas7/pycairo-1.18.0-cp35-cp35m-win_amd64.whl' + + url = 'https://download.lfd.uci.edu/pythonlibs/t4jqbe6o/pycairo-1.18.0-cp35-cp35m-win_amd64.whl' else: - url = 'https://download.lfd.uci.edu/pythonlibs/q5gtlas7/pycairo-1.18.0-cp35-cp35m-win32.whl' + url = 'https://download.lfd.uci.edu/pythonlibs/t4jqbe6o/pycairo-1.18.0-cp35-cp35m-win32.whl' if sys.version[0:3] == '3.6': if is64bit: - url = 'https://download.lfd.uci.edu/pythonlibs/q5gtlas7/pycairo-1.18.0-cp36-cp36m-win_amd64.whl' + url = 'https://download.lfd.uci.edu/pythonlibs/t4jqbe6o/pycairo-1.18.0-cp36-cp36m-win_amd64.whl' else: - url = 'https://download.lfd.uci.edu/pythonlibs/q5gtlas7/pycairo-1.18.0-cp36-cp36m-win32.whl' + url = 'https://download.lfd.uci.edu/pythonlibs/t4jqbe6o/pycairo-1.18.0-cp36-cp36m-win32.whl' if sys.version[0:3] == '3.7': if is64bit: - url = 'https://download.lfd.uci.edu/pythonlibs/q5gtlas7/pycairo-1.18.0-cp37-cp37m-win_amd64.whl' + url = 'https://download.lfd.uci.edu/pythonlibs/t4jqbe6o/pycairo-1.18.0-cp37-cp37m-win_amd64.whl' else: - url = 'https://download.lfd.uci.edu/pythonlibs/q5gtlas7/pycairo-1.18.0-cp37-cp37m-win32.whl' + url = 'https://download.lfd.uci.edu/pythonlibs/t4jqbe6o/pycairo-1.18.0-cp37-cp37m-win32.whl' try: import pycairo diff --git a/tests/test_integration_test.py b/tests/test_integration_test.py new file mode 100644 index 00000000..8b6db60f --- /dev/null +++ b/tests/test_integration_test.py @@ -0,0 +1,138 @@ +import os +import random +import time + +import numpy as np + +from flatland.envs.generators import complex_rail_generator +from flatland.envs.generators import random_rail_generator +from flatland.envs.rail_env import RailEnv +from flatland.utils.rendertools import RenderTool + +# ensure that every demo run behave constantly equal +random.seed(1) +np.random.seed(1) + + +class Scenario_Generator: + @staticmethod + def generate_random_scenario(number_of_agents=3): + # Example generate a rail given a manual specification, + # a map of tuples (cell_type, rotation) + transition_probability = [15, # empty cell - Case 0 + 5, # Case 1 - straight + 5, # Case 2 - simple switch + 1, # Case 3 - diamond crossing + 1, # Case 4 - single slip + 1, # Case 5 - double slip + 1, # Case 6 - symmetrical + 0, # Case 7 - dead end + 1, # Case 1b (8) - simple turn right + 1, # Case 1c (9) - simple turn left + 1] # Case 2b (10) - simple switch mirrored + + # Example generate a random rail + + env = RailEnv(width=20, + height=20, + rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability), + number_of_agents=number_of_agents) + + return env + + @staticmethod + def generate_complex_scenario(number_of_agents=3): + env = RailEnv(width=15, + height=15, + rail_generator=complex_rail_generator(nr_start_goal=6, nr_extra=30, min_dist=10, + max_dist=99999, seed=0), + number_of_agents=number_of_agents) + + return env + + @staticmethod + def load_scenario(filename, number_of_agents=3): + env = RailEnv(width=2 * (1 + number_of_agents), + height=1 + number_of_agents) + + """ + env = RailEnv(width=20, + height=20, + rail_generator=rail_from_list_of_saved_GridTransitionMap_generator( + [filename]), + number_of_agents=number_of_agents) + """ + if os.path.exists(filename): + print("load file: ", filename) + env.load(filename) + env.reset(False, False) + else: + print("File does not exist:", filename, " Working directory: ", os.getcwd()) + + return env + + +class Demo: + + def __init__(self, env): + self.env = env + self.create_renderer() + self.action_size = 4 + self.max_frame_rate = 60 + self.record_frames = None + + def set_record_frames(self, record_frames): + self.record_frames = record_frames + + def create_renderer(self): + self.renderer = RenderTool(self.env, gl="PILSVG") + handle = self.env.get_agent_handles() + return handle + + def set_max_framerate(self, max_frame_rate): + self.max_frame_rate = max_frame_rate + + def run_demo(self, max_nbr_of_steps=30): + action_dict = dict() + + # Reset environment + _ = self.env.reset(False, False) + + time.sleep(0.0001) # to satisfy lint... + + for step in range(max_nbr_of_steps): + + # Action + for iAgent in range(self.env.get_num_agents()): + # allways walk straight forward + action = 2 + + # update the actions + action_dict.update({iAgent: action}) + + # environment step (apply the actions to all agents) + next_obs, all_rewards, done, _ = self.env.step(action_dict) + + # render + self.renderer.renderEnv(show=True, show_observations=False) + + if done['__all__']: + break + + if self.record_frames is not None: + self.renderer.gl.saveImage(self.record_frames.format(step)) + + self.renderer.close_window() + + +def test_temp_pk1(): + demo_001 = Demo(Scenario_Generator.load_scenario('./env-data/railway/temp.pkl')) + demo_001.run_demo(10) + # TODO test assertions + + +def test_flatland_001_pkl(): + demo_001 = Demo(Scenario_Generator.load_scenario('./env-data/railway/example_flatland_001.pkl')) + demo_001.set_record_frames('./rendering/frame_{:04d}.bmp') + demo_001.run_demo(60) + # TODO test assertions -- GitLab