diff --git a/images/basic-env.png b/images/basic-env.png new file mode 100644 index 0000000000000000000000000000000000000000..4799468af4b0d188874e3f0145103a1cb738f6dd Binary files /dev/null and b/images/basic-env.png differ diff --git a/images/test/basic-env.png b/images/test/basic-env.png new file mode 100644 index 0000000000000000000000000000000000000000..4799468af4b0d188874e3f0145103a1cb738f6dd Binary files /dev/null and b/images/test/basic-env.png differ diff --git a/requirements_dev.txt b/requirements_dev.txt index 70d99c53aead6b2ff46250a662c255bacda70c10..4cb4edd4f4dd0f41a32b434e4e8ca13d5d6199c8 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -12,3 +12,7 @@ pytest-runner==4.2 sphinx-rtd-theme==0.4.3 numpy==1.16.2 +recordtype==1.3 +xarray==0.11.3 +matplotlib==3.0.2 + diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py new file mode 100644 index 0000000000000000000000000000000000000000..3153434eb96ac6af388bec8d1611bf1fa8d61f89 --- /dev/null +++ b/tests/test_rendertools.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from flatland.core.env import RailEnv +#from flatland.core.transitions import GridTransitions +import numpy as np +import random +import os + +from recordtype import recordtype + +import numpy as np +from numpy import array +import xarray as xr +import matplotlib.pyplot as plt + +from flatland.core.transitions import RailEnvTransitions +#import flatland.core.env +from flatland.utils import rail_env_generator +from flatland.core.env import RailEnv +import flatland.utils.rendertools as rt + + + + +"""Tests for `flatland` package.""" + + + +def checkFrozenImage(sFileImage): + sTmpFileImage = "images/test/" + sFileImage + + if os.path.exists(sTmpFileImage): + os.remove(sTmpFileImage) + + plt.savefig(sTmpFileImage) + + bytesFrozenImage = None + for sDir in [ "images/", "images/test/" ]: + sfPath = sDir + sFileImage + with open(sfPath, "rb") as fIn: + bytesImage = fIn.read() + if bytesFrozenImage == None: + bytesFrozenImage = bytesImage + else: + assert(bytesFrozenImage == bytesImage) + + +def test_render_env(): + random.seed(100) + oRail = rail_env_generator.generate_random_rail(10,10) + type(oRail), len(oRail) + oEnv = RailEnv(oRail, number_of_agents=2) + oEnv.reset() + oRT = rt.RenderTool(oEnv) + plt.figure(figsize=(10,10)) + oRT.renderEnv() + + checkFrozenImage("basic-env.png") + + +