From d2a71e196ff6065d0b7bc7201514968f3f700ca3 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Wed, 19 Jun 2019 08:41:41 +0200 Subject: [PATCH] #62 first steps unit test coverage --- flatland/core/transition_map.py | 30 +++---------------- tests/test_env_edit.py | 11 ------- tests/test_flatland_core_transition_map.py | 13 ++++++++ ...s.py => test_flatland_core_transitions.py} | 0 ...ents.py => test_flatland_envs_rail_env.py} | 9 +++++- ....py => test_flatland_utils_rendertools.py} | 2 +- tests/test_player.py | 6 ---- 7 files changed, 26 insertions(+), 45 deletions(-) delete mode 100644 tests/test_env_edit.py create mode 100644 tests/test_flatland_core_transition_map.py rename tests/{test_transitions.py => test_flatland_core_transitions.py} (100%) rename tests/{test_environments.py => test_flatland_envs_rail_env.py} (96%) rename tests/{test_rendertools.py => test_flatland_utils_rendertools.py} (96%) delete mode 100644 tests/test_player.py diff --git a/flatland/core/transition_map.py b/flatland/core/transition_map.py index 43b9a72a..6c9bde42 100644 --- a/flatland/core/transition_map.py +++ b/flatland/core/transition_map.py @@ -264,7 +264,7 @@ class GridTransitionMap(TransitionMap): """ np.save(filename, self.grid) - def load_transition_map(self, package, resource, override_gridsize=True): + def load_transition_map(self, package, resource): """ Load the transitions grid from `filename' (npy format). The load function only updates the transitions grid, and possibly width and height, but the object has to be @@ -289,28 +289,9 @@ class GridTransitionMap(TransitionMap): new_height = new_grid.shape[0] new_width = new_grid.shape[1] - if override_gridsize: - self.width = new_width - self.height = new_height - self.grid = new_grid - - else: - if new_grid.dtype == np.uint16: - self.grid = np.zeros((self.height, self.width), dtype=np.uint16) - elif new_grid.dtype == np.uint64: - self.grid = np.zeros((self.height, self.width), dtype=np.uint64) - - self.grid[0:min(self.height, new_height), - 0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height), - 0:min(self.width, new_width)] - - def is_cell_valid(self, rcPos): - cell_transition = self.grid[tuple(rcPos)] - - if not self.transitions.is_valid(cell_transition): - return False - else: - return True + self.width = new_width + self.height = new_height + self.grid = new_grid def cell_neighbours_valid(self, rcPos, check_this_cell=False): """ @@ -364,9 +345,6 @@ class GridTransitionMap(TransitionMap): return True - def cell_repr(self, rcPos): - return self.transitions.repr(self.get_transitions(rcPos)) - # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids # (most general implementation) or to make Grid-class specific methods for # slicing over the 3 dimensions? I'd say both perhaps. diff --git a/tests/test_env_edit.py b/tests/test_env_edit.py deleted file mode 100644 index f0d86292..00000000 --- a/tests/test_env_edit.py +++ /dev/null @@ -1,11 +0,0 @@ -from flatland.envs.agent_utils import EnvAgentStatic -from flatland.envs.rail_env import RailEnv - - -def test_load_env(): - env = RailEnv(10, 10) - env.load_resource('env_data.tests', 'test-10x10.mpk') - - agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False) - env.add_agent_static(agent_static) - assert env.get_num_agents() == 1 diff --git a/tests/test_flatland_core_transition_map.py b/tests/test_flatland_core_transition_map.py new file mode 100644 index 00000000..cf1a9206 --- /dev/null +++ b/tests/test_flatland_core_transition_map.py @@ -0,0 +1,13 @@ +from flatland.core.transition_map import GridTransitionMap +from flatland.core.transitions import Grid4Transitions, Grid8Transitions, Grid4TransitionsEnum + + +def test_grid4_set_transitions(): + grid4_map = GridTransitionMap(2, 2, Grid4Transitions([])) + grid4_map.set_transition((0, 0), Grid4TransitionsEnum.EAST, 1) + actual_transitions = grid4_map.get_transitions((0,0)) + assert False + + +def test_grid8_set_transitions(): + grid8_map = GridTransitionMap(2, 2, Grid8Transitions([])) diff --git a/tests/test_transitions.py b/tests/test_flatland_core_transitions.py similarity index 100% rename from tests/test_transitions.py rename to tests/test_flatland_core_transitions.py diff --git a/tests/test_environments.py b/tests/test_flatland_envs_rail_env.py similarity index 96% rename from tests/test_environments.py rename to tests/test_flatland_envs_rail_env.py index 11f0acba..e8811fce 100644 --- a/tests/test_environments.py +++ b/tests/test_flatland_envs_rail_env.py @@ -4,7 +4,7 @@ import numpy as np from flatland.core.transition_map import GridTransitionMap from flatland.core.transitions import Grid4Transitions -from flatland.envs.agent_utils import EnvAgent +from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.observations import GlobalObsForRailEnv @@ -12,6 +12,13 @@ from flatland.envs.rail_env import RailEnv """Tests for `flatland` package.""" +def test_load_env(): + env = RailEnv(10, 10) + env.load_resource('env_data.tests', 'test-10x10.mpk') + + agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False) + env.add_agent_static(agent_static) + assert env.get_num_agents() == 1 def test_save_load(): env = RailEnv(width=10, height=10, diff --git a/tests/test_rendertools.py b/tests/test_flatland_utils_rendertools.py similarity index 96% rename from tests/test_rendertools.py rename to tests/test_flatland_utils_rendertools.py index 14edfee7..ff7cbd01 100644 --- a/tests/test_rendertools.py +++ b/tests/test_flatland_utils_rendertools.py @@ -79,7 +79,7 @@ def main(): if len(sys.argv) == 2 and sys.argv[1] == "save": test_render_env(save_new_images=True) else: - print("Run 'python test_rendertools.py save' to regenerate images") + print("Run 'python test_flatland_utils_rendertools.py save' to regenerate images") test_render_env() diff --git a/tests/test_player.py b/tests/test_player.py deleted file mode 100644 index 757fc90d..00000000 --- a/tests/test_player.py +++ /dev/null @@ -1,6 +0,0 @@ -from examples.play_model import main - - -def test_main(): - main(render=True, n_steps=20, n_trials=2, sGL="PIL") - main(render=True, n_steps=20, n_trials=2, sGL="PILSVG") -- GitLab