From faa4ace663a6c2ef6c8491c0d2c8a07ca242105c Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Tue, 4 Jun 2019 15:07:35 +0200
Subject: [PATCH] cleanup and integration tests

---
 flatland/envs/predictions.py   |   4 +-
 flatland/envs/rail_env.py      |   8 +-
 flatland/utils/graphics_pil.py |   4 +-
 setup.py                       |  13 ++--
 tests/test_integration_test.py | 138 +++++++++++++++++++++++++++++++++
 5 files changed, 152 insertions(+), 15 deletions(-)
 create mode 100644 tests/test_integration_test.py

diff --git a/flatland/envs/predictions.py b/flatland/envs/predictions.py
index 95c1a98..0420ab7 100644
--- a/flatland/envs/predictions.py
+++ b/flatland/envs/predictions.py
@@ -54,8 +54,8 @@ class DummyPredictorForRailEnv(PredictionBuilder):
             for index in range(1, self.max_depth):
                 action_done = False
                 for action in action_priorities:
-                    cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self.env._check_action_on_agent(action,
-                                                                                                                                     agent)
+                    cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
+                        self.env._check_action_on_agent(action, agent)
                     if all([new_cell_isValid, transition_isValid]):
                         # move and change direction to face the new_direction that was
                         # performed
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 82d694c..da389d0 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -100,7 +100,6 @@ class RailEnv(Environment):
         if self.prediction_builder:
             self.prediction_builder._set_env(self)
 
-
         self.action_space = [1]
         self.observation_space = self.obs_builder.observation_space  # updated on resets?
 
@@ -219,8 +218,8 @@ class RailEnv(Environment):
                 return
 
             if action > 0:
-                cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = self._check_action_on_agent(action,
-                                                                                                                             agent)
+                cell_isFree, new_cell_isValid, new_direction, new_position, transition_isValid = \
+                    self._check_action_on_agent(action, agent)
                 if all([new_cell_isValid, transition_isValid, cell_isFree]):
                     # move and change direction to face the new_direction that was
                     # performed
@@ -302,8 +301,7 @@ class RailEnv(Environment):
     def predict(self):
         if not self.prediction_builder:
             return {}
-        return  self.prediction_builder.get()
-
+        return self.prediction_builder.get()
 
     def check_action(self, agent, action):
         transition_isValid = None
diff --git a/flatland/utils/graphics_pil.py b/flatland/utils/graphics_pil.py
index b738e5e..ba32238 100644
--- a/flatland/utils/graphics_pil.py
+++ b/flatland/utils/graphics_pil.py
@@ -395,8 +395,8 @@ class PILSVG(PILGL):
 
             if isSelected:
                 svgBG = self.pilFromSvgFile("./svg/Selected_Target.svg")
-                self.clear_layer(3,0)
-                self.drawImageRC(svgBG,(row,col),layer=3)
+                self.clear_layer(3, 0)
+                self.drawImageRC(svgBG, (row, col), layer=3)
 
     def recolorImage(self, pil, a3BaseColor, ltColors):
         rgbaImg = array(pil)
diff --git a/setup.py b/setup.py
index 7ed4339..ce7b232 100644
--- a/setup.py
+++ b/setup.py
@@ -27,21 +27,22 @@ if os.name == 'nt':
     is64bit = p[0] == '64bit'
     if sys.version[0:3] == '3.5':
         if is64bit:
-            url = 'https://download.lfd.uci.edu/pythonlibs/q5gtlas7/pycairo-1.18.0-cp35-cp35m-win_amd64.whl'
+
+            url = 'https://download.lfd.uci.edu/pythonlibs/t4jqbe6o/pycairo-1.18.0-cp35-cp35m-win_amd64.whl'
         else:
-            url = 'https://download.lfd.uci.edu/pythonlibs/q5gtlas7/pycairo-1.18.0-cp35-cp35m-win32.whl'
+            url = 'https://download.lfd.uci.edu/pythonlibs/t4jqbe6o/pycairo-1.18.0-cp35-cp35m-win32.whl'
 
     if sys.version[0:3] == '3.6':
         if is64bit:
-            url = 'https://download.lfd.uci.edu/pythonlibs/q5gtlas7/pycairo-1.18.0-cp36-cp36m-win_amd64.whl'
+            url = 'https://download.lfd.uci.edu/pythonlibs/t4jqbe6o/pycairo-1.18.0-cp36-cp36m-win_amd64.whl'
         else:
-            url = 'https://download.lfd.uci.edu/pythonlibs/q5gtlas7/pycairo-1.18.0-cp36-cp36m-win32.whl'
+            url = 'https://download.lfd.uci.edu/pythonlibs/t4jqbe6o/pycairo-1.18.0-cp36-cp36m-win32.whl'
 
     if sys.version[0:3] == '3.7':
         if is64bit:
-            url = 'https://download.lfd.uci.edu/pythonlibs/q5gtlas7/pycairo-1.18.0-cp37-cp37m-win_amd64.whl'
+            url = 'https://download.lfd.uci.edu/pythonlibs/t4jqbe6o/pycairo-1.18.0-cp37-cp37m-win_amd64.whl'
         else:
-            url = 'https://download.lfd.uci.edu/pythonlibs/q5gtlas7/pycairo-1.18.0-cp37-cp37m-win32.whl'
+            url = 'https://download.lfd.uci.edu/pythonlibs/t4jqbe6o/pycairo-1.18.0-cp37-cp37m-win32.whl'
 
     try:
         import pycairo
diff --git a/tests/test_integration_test.py b/tests/test_integration_test.py
new file mode 100644
index 0000000..8b6db60
--- /dev/null
+++ b/tests/test_integration_test.py
@@ -0,0 +1,138 @@
+import os
+import random
+import time
+
+import numpy as np
+
+from flatland.envs.generators import complex_rail_generator
+from flatland.envs.generators import random_rail_generator
+from flatland.envs.rail_env import RailEnv
+from flatland.utils.rendertools import RenderTool
+
+# ensure that every demo run behave constantly equal
+random.seed(1)
+np.random.seed(1)
+
+
+class Scenario_Generator:
+    @staticmethod
+    def generate_random_scenario(number_of_agents=3):
+        # Example generate a rail given a manual specification,
+        # a map of tuples (cell_type, rotation)
+        transition_probability = [15,  # empty cell - Case 0
+                                  5,  # Case 1 - straight
+                                  5,  # Case 2 - simple switch
+                                  1,  # Case 3 - diamond crossing
+                                  1,  # Case 4 - single slip
+                                  1,  # Case 5 - double slip
+                                  1,  # Case 6 - symmetrical
+                                  0,  # Case 7 - dead end
+                                  1,  # Case 1b (8)  - simple turn right
+                                  1,  # Case 1c (9)  - simple turn left
+                                  1]  # Case 2b (10) - simple switch mirrored
+
+        # Example generate a random rail
+
+        env = RailEnv(width=20,
+                      height=20,
+                      rail_generator=random_rail_generator(cell_type_relative_proportion=transition_probability),
+                      number_of_agents=number_of_agents)
+
+        return env
+
+    @staticmethod
+    def generate_complex_scenario(number_of_agents=3):
+        env = RailEnv(width=15,
+                      height=15,
+                      rail_generator=complex_rail_generator(nr_start_goal=6, nr_extra=30, min_dist=10,
+                                                            max_dist=99999, seed=0),
+                      number_of_agents=number_of_agents)
+
+        return env
+
+    @staticmethod
+    def load_scenario(filename, number_of_agents=3):
+        env = RailEnv(width=2 * (1 + number_of_agents),
+                      height=1 + number_of_agents)
+
+        """
+        env = RailEnv(width=20,
+                      height=20,
+                      rail_generator=rail_from_list_of_saved_GridTransitionMap_generator(
+                          [filename]),
+                      number_of_agents=number_of_agents)
+        """
+        if os.path.exists(filename):
+            print("load file: ", filename)
+            env.load(filename)
+            env.reset(False, False)
+        else:
+            print("File does not exist:", filename, " Working directory: ", os.getcwd())
+
+        return env
+
+
+class Demo:
+
+    def __init__(self, env):
+        self.env = env
+        self.create_renderer()
+        self.action_size = 4
+        self.max_frame_rate = 60
+        self.record_frames = None
+
+    def set_record_frames(self, record_frames):
+        self.record_frames = record_frames
+
+    def create_renderer(self):
+        self.renderer = RenderTool(self.env, gl="PILSVG")
+        handle = self.env.get_agent_handles()
+        return handle
+
+    def set_max_framerate(self, max_frame_rate):
+        self.max_frame_rate = max_frame_rate
+
+    def run_demo(self, max_nbr_of_steps=30):
+        action_dict = dict()
+
+        # Reset environment
+        _ = self.env.reset(False, False)
+
+        time.sleep(0.0001)  # to satisfy lint...
+
+        for step in range(max_nbr_of_steps):
+
+            # Action
+            for iAgent in range(self.env.get_num_agents()):
+                # allways walk straight forward
+                action = 2
+
+                # update the actions
+                action_dict.update({iAgent: action})
+
+            # environment step (apply the actions to all agents)
+            next_obs, all_rewards, done, _ = self.env.step(action_dict)
+
+            # render
+            self.renderer.renderEnv(show=True, show_observations=False)
+
+            if done['__all__']:
+                break
+
+            if self.record_frames is not None:
+                self.renderer.gl.saveImage(self.record_frames.format(step))
+
+        self.renderer.close_window()
+
+
+def test_temp_pk1():
+    demo_001 = Demo(Scenario_Generator.load_scenario('./env-data/railway/temp.pkl'))
+    demo_001.run_demo(10)
+    # TODO test assertions
+
+
+def test_flatland_001_pkl():
+    demo_001 = Demo(Scenario_Generator.load_scenario('./env-data/railway/example_flatland_001.pkl'))
+    demo_001.set_record_frames('./rendering/frame_{:04d}.bmp')
+    demo_001.run_demo(60)
+    # TODO test assertions
-- 
GitLab