Commit d2a71e19 authored by u214892's avatar u214892
Browse files

#62 first steps unit test coverage

parent 0db89d48
Pipeline #1157 canceled with stage
......@@ -264,7 +264,7 @@ class GridTransitionMap(TransitionMap):
"""
np.save(filename, self.grid)
def load_transition_map(self, package, resource, override_gridsize=True):
def load_transition_map(self, package, resource):
"""
Load the transitions grid from `filename' (npy format).
The load function only updates the transitions grid, and possibly width and height, but the object has to be
......@@ -289,29 +289,10 @@ class GridTransitionMap(TransitionMap):
new_height = new_grid.shape[0]
new_width = new_grid.shape[1]
if override_gridsize:
self.width = new_width
self.height = new_height
self.grid = new_grid
else:
if new_grid.dtype == np.uint16:
self.grid = np.zeros((self.height, self.width), dtype=np.uint16)
elif new_grid.dtype == np.uint64:
self.grid = np.zeros((self.height, self.width), dtype=np.uint64)
self.grid[0:min(self.height, new_height),
0:min(self.width, new_width)] = new_grid[0:min(self.height, new_height),
0:min(self.width, new_width)]
def is_cell_valid(self, rcPos):
cell_transition = self.grid[tuple(rcPos)]
if not self.transitions.is_valid(cell_transition):
return False
else:
return True
def cell_neighbours_valid(self, rcPos, check_this_cell=False):
"""
Check validity of cell at rcPos = tuple(row, column)
......@@ -364,9 +345,6 @@ class GridTransitionMap(TransitionMap):
return True
def cell_repr(self, rcPos):
return self.transitions.repr(self.get_transitions(rcPos))
# TODO: GIACOMO: is it better to provide those methods with lists of cell_ids
# (most general implementation) or to make Grid-class specific methods for
# slicing over the 3 dimensions? I'd say both perhaps.
......
from flatland.envs.agent_utils import EnvAgentStatic
from flatland.envs.rail_env import RailEnv
def test_load_env():
env = RailEnv(10, 10)
env.load_resource('env_data.tests', 'test-10x10.mpk')
agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False)
env.add_agent_static(agent_static)
assert env.get_num_agents() == 1
from flatland.core.transition_map import GridTransitionMap
from flatland.core.transitions import Grid4Transitions, Grid8Transitions, Grid4TransitionsEnum
def test_grid4_set_transitions():
grid4_map = GridTransitionMap(2, 2, Grid4Transitions([]))
grid4_map.set_transition((0, 0), Grid4TransitionsEnum.EAST, 1)
actual_transitions = grid4_map.get_transitions((0,0))
assert False
def test_grid8_set_transitions():
grid8_map = GridTransitionMap(2, 2, Grid8Transitions([]))
......@@ -4,7 +4,7 @@ import numpy as np
from flatland.core.transition_map import GridTransitionMap
from flatland.core.transitions import Grid4Transitions
from flatland.envs.agent_utils import EnvAgent
from flatland.envs.agent_utils import EnvAgent, EnvAgentStatic
from flatland.envs.generators import complex_rail_generator
from flatland.envs.generators import rail_from_GridTransitionMap_generator
from flatland.envs.observations import GlobalObsForRailEnv
......@@ -12,6 +12,13 @@ from flatland.envs.rail_env import RailEnv
"""Tests for `flatland` package."""
def test_load_env():
env = RailEnv(10, 10)
env.load_resource('env_data.tests', 'test-10x10.mpk')
agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False)
env.add_agent_static(agent_static)
assert env.get_num_agents() == 1
def test_save_load():
env = RailEnv(width=10, height=10,
......
......@@ -79,7 +79,7 @@ def main():
if len(sys.argv) == 2 and sys.argv[1] == "save":
test_render_env(save_new_images=True)
else:
print("Run 'python test_rendertools.py save' to regenerate images")
print("Run 'python test_flatland_utils_rendertools.py save' to regenerate images")
test_render_env()
......
from examples.play_model import main
def test_main():
main(render=True, n_steps=20, n_trials=2, sGL="PIL")
main(render=True, n_steps=20, n_trials=2, sGL="PILSVG")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment