Skip to content
Snippets Groups Projects
Commit d2a71e19 authored by u214892's avatar u214892
Browse files

#62 first steps unit test coverage

parent 0db89d48
No related branches found
No related tags found
No related merge requests found
...@@ -264,7 +264,7 @@ class GridTransitionMap(TransitionMap): ...@@ -264,7 +264,7 @@ class GridTransitionMap(TransitionMap):
""" """
np.save(filename, self.grid) np.save(filename, self.grid)
def load_transition_map(self, package, resource, override_gridsize=True): def load_transition_map(self, package, resource):
""" """
Load the transitions grid from `filename' (npy format). Load the transitions grid from `filename' (npy format).
The load function only updates the transitions grid, and possibly width and height, but the object has to be The load function only updates the transitions grid, and possibly width and height, but the object has to be
...@@ -289,28 +289,9 @@ class GridTransitionMap(TransitionMap): ...@@ -289,28 +289,9 @@ class GridTransitionMap(TransitionMap):
new_height = new_grid.shape[0] new_height = new_grid.shape[0]
new_width = new_grid.shape[1] new_width = new_grid.shape[1]
if override_gridsize: self.width = new_width
self.width = new_width self.height = new_height
self.height = new_height self.grid = new_grid
self.grid = new_grid
else:
if new_grid.dtype == np.uint16:
self.grid = np.zeros((self.height, self.width), dtype=np.uint16)
elif new_grid.dtype == np.uint64:
self.grid = np.zeros((self.height, self.width), dtype=np.uint64)
self.grid[0:min(self.height, new_height),
0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height),
0:min(self.width, new_width)]
def is_cell_valid(self, rcPos):
cell_transition = self.grid[tuple(rcPos)]
if not self.transitions.is_valid(cell_transition):
return False
else:
return True
def cell_neighbours_valid(self, rcPos, check_this_cell=False): def cell_neighbours_valid(self, rcPos, check_this_cell=False):
""" """
...@@ -364,9 +345,6 @@ class GridTransitionMap(TransitionMap): ...@@ -364,9 +345,6 @@ class GridTransitionMap(TransitionMap):
return True return True
def cell_repr(self, rcPos):
return self.transitions.repr(self.get_transitions(rcPos))
# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids # TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
# (most general implementation) or to make Grid-class specific methods for # (most general implementation) or to make Grid-class specific methods for
# slicing over the 3 dimensions? I'd say both perhaps. # slicing over the 3 dimensions? I'd say both perhaps.
......
from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.rail_env import RailEnv
def test_load_env():
env = RailEnv(10, 10)
env.load_resource('env_data.tests', 'test-10x10.mpk')
agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False)
env.add_agent_static(agent_static)
assert env.get_num_agents() == 1
from flatland.core.transition_map import GridTransitionMap
from flatland.core.transitions import Grid4Transitions, Grid8Transitions, Grid4TransitionsEnum
def test_grid4_set_transitions():
grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
grid4_map.set_transition((0, 0), Grid4TransitionsEnum.EAST, 1)
actual_transitions = grid4_map.get_transitions((0,0))
assert False
def test_grid8_set_transitions():
grid8_map = GridTransitionMap(2, 2, Grid8Transitions([]))
File moved
...@@ -4,7 +4,7 @@ import numpy as np ...@@ -4,7 +4,7 @@ import numpy as np
from flatland.core.transition_map import GridTransitionMap from flatland.core.transition_map import GridTransitionMap
from flatland.core.transitions import Grid4Transitions from flatland.core.transitions import Grid4Transitions
from flatland.envs.agent_utils import EnvAgent from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
from flatland.envs.generators import complex_rail_generator from flatland.envs.generators import complex_rail_generator
from flatland.envs.generators import rail_from_GridTransitionMap_generator from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.observations import GlobalObsForRailEnv
...@@ -12,6 +12,13 @@ from flatland.envs.rail_env import RailEnv ...@@ -12,6 +12,13 @@ from flatland.envs.rail_env import RailEnv
"""Tests for `flatland` package.""" """Tests for `flatland` package."""
def test_load_env():
env = RailEnv(10, 10)
env.load_resource('env_data.tests', 'test-10x10.mpk')
agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False)
env.add_agent_static(agent_static)
assert env.get_num_agents() == 1
def test_save_load(): def test_save_load():
env = RailEnv(width=10, height=10, env = RailEnv(width=10, height=10,
......
...@@ -79,7 +79,7 @@ def main(): ...@@ -79,7 +79,7 @@ def main():
if len(sys.argv) == 2 and sys.argv[1] == "save": if len(sys.argv) == 2 and sys.argv[1] == "save":
test_render_env(save_new_images=True) test_render_env(save_new_images=True)
else: else:
print("Run 'python test_rendertools.py save' to regenerate images") print("Run 'python test_flatland_utils_rendertools.py save' to regenerate images")
test_render_env() test_render_env()
......
from examples.play_model import main
def test_main():
main(render=True, n_steps=20, n_trials=2, sGL="PIL")
main(render=True, n_steps=20, n_trials=2, sGL="PILSVG")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment