From d2a71e196ff6065d0b7bc7201514968f3f700ca3 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Wed, 19 Jun 2019 08:41:41 +0200
Subject: [PATCH] #62 first steps unit test coverage

---
 flatland/core/transition_map.py               | 30 +++----------------
 tests/test_env_edit.py                        | 11 -------
 tests/test_flatland_core_transition_map.py    | 13 ++++++++
 ...s.py => test_flatland_core_transitions.py} |  0
 ...ents.py => test_flatland_envs_rail_env.py} |  9 +++++-
 ....py => test_flatland_utils_rendertools.py} |  2 +-
 tests/test_player.py                          |  6 ----
 7 files changed, 26 insertions(+), 45 deletions(-)
 delete mode 100644 tests/test_env_edit.py
 create mode 100644 tests/test_flatland_core_transition_map.py
 rename tests/{test_transitions.py => test_flatland_core_transitions.py} (100%)
 rename tests/{test_environments.py => test_flatland_envs_rail_env.py} (96%)
 rename tests/{test_rendertools.py => test_flatland_utils_rendertools.py} (96%)
 delete mode 100644 tests/test_player.py

diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py
index 43b9a72a..6c9bde42 100644
--- a/flatland/core/transition_map.py
+++ b/flatland/core/transition_map.py
@@ -264,7 +264,7 @@ class GridTransitionMap(TransitionMap):
         """
         np.save(filename, self.grid)
 
-    def load_transition_map(self, package, resource, override_gridsize=True):
+    def load_transition_map(self, package, resource):
         """
         Load the transitions grid from `filename' (npy format).
         The load function only updates the transitions grid, and possibly width and height, but the object has to be
@@ -289,28 +289,9 @@ class GridTransitionMap(TransitionMap):
         new_height = new_grid.shape[0]
         new_width = new_grid.shape[1]
 
-        if override_gridsize:
-            self.width = new_width
-            self.height = new_height
-            self.grid = new_grid
-
-        else:
-            if new_grid.dtype == np.uint16:
-                self.grid = np.zeros((self.height, self.width), dtype=np.uint16)
-            elif new_grid.dtype == np.uint64:
-                self.grid = np.zeros((self.height, self.width), dtype=np.uint64)
-
-            self.grid[0:min(self.height, new_height),
-            0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height),
-                                            0:min(self.width, new_width)]
-
-    def is_cell_valid(self, rcPos):
-        cell_transition = self.grid[tuple(rcPos)]
-
-        if not self.transitions.is_valid(cell_transition):
-            return False
-        else:
-            return True
+        self.width = new_width
+        self.height = new_height
+        self.grid = new_grid
 
     def cell_neighbours_valid(self, rcPos, check_this_cell=False):
         """
@@ -364,9 +345,6 @@ class GridTransitionMap(TransitionMap):
 
         return True
 
-    def cell_repr(self, rcPos):
-        return self.transitions.repr(self.get_transitions(rcPos))
-
 # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
 # (most general implementation) or to make Grid-class specific methods for
 # slicing over the 3 dimensions?  I'd say both perhaps.
diff --git a/tests/test_env_edit.py b/tests/test_env_edit.py
deleted file mode 100644
index f0d86292..00000000
--- a/tests/test_env_edit.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from flatland.envs.agent_utils import EnvAgentStatic
-from flatland.envs.rail_env import RailEnv
-
-
-def test_load_env():
-    env = RailEnv(10, 10)
-    env.load_resource('env_data.tests', 'test-10x10.mpk')
-
-    agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False)
-    env.add_agent_static(agent_static)
-    assert env.get_num_agents() == 1
diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py
new file mode 100644
index 00000000..cf1a9206
--- /dev/null
+++ b/tests/test_flatland_core_transition_map.py
@@ -0,0 +1,13 @@
+from flatland.core.transition_map import GridTransitionMap
+from flatland.core.transitions import Grid4Transitions, Grid8Transitions, Grid4TransitionsEnum
+
+
+def test_grid4_set_transitions():
+    grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
+    grid4_map.set_transition((0, 0), Grid4TransitionsEnum.EAST, 1)
+    actual_transitions  = grid4_map.get_transitions((0,0))
+    assert False
+
+
+def test_grid8_set_transitions():
+    grid8_map = GridTransitionMap(2, 2, Grid8Transitions([]))
diff --git a/tests/test_transitions.py b/tests/test_flatland_core_transitions.py
similarity index 100%
rename from tests/test_transitions.py
rename to tests/test_flatland_core_transitions.py
diff --git a/tests/test_environments.py b/tests/test_flatland_envs_rail_env.py
similarity index 96%
rename from tests/test_environments.py
rename to tests/test_flatland_envs_rail_env.py
index 11f0acba..e8811fce 100644
--- a/tests/test_environments.py
+++ b/tests/test_flatland_envs_rail_env.py
@@ -4,7 +4,7 @@ import numpy as np
 
 from flatland.core.transition_map import GridTransitionMap
 from flatland.core.transitions import Grid4Transitions
-from flatland.envs.agent_utils import EnvAgent
+from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
 from flatland.envs.generators import complex_rail_generator
 from flatland.envs.generators import rail_from_GridTransitionMap_generator
 from flatland.envs.observations import GlobalObsForRailEnv
@@ -12,6 +12,13 @@ from flatland.envs.rail_env import RailEnv
 
 """Tests for `flatland` package."""
 
+def test_load_env():
+    env = RailEnv(10, 10)
+    env.load_resource('env_data.tests', 'test-10x10.mpk')
+
+    agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False)
+    env.add_agent_static(agent_static)
+    assert env.get_num_agents() == 1
 
 def test_save_load():
     env = RailEnv(width=10, height=10,
diff --git a/tests/test_rendertools.py b/tests/test_flatland_utils_rendertools.py
similarity index 96%
rename from tests/test_rendertools.py
rename to tests/test_flatland_utils_rendertools.py
index 14edfee7..ff7cbd01 100644
--- a/tests/test_rendertools.py
+++ b/tests/test_flatland_utils_rendertools.py
@@ -79,7 +79,7 @@ def main():
     if len(sys.argv) == 2 and sys.argv[1] == "save":
         test_render_env(save_new_images=True)
     else:
-        print("Run 'python test_rendertools.py save' to regenerate images")
+        print("Run 'python test_flatland_utils_rendertools.py save' to regenerate images")
         test_render_env()
 
 
diff --git a/tests/test_player.py b/tests/test_player.py
deleted file mode 100644
index 757fc90d..00000000
--- a/tests/test_player.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from examples.play_model import main
-
-
-def test_main():
-    main(render=True, n_steps=20, n_trials=2, sGL="PIL")
-    main(render=True, n_steps=20, n_trials=2, sGL="PILSVG")
-- 
GitLab