Skip to content
Snippets Groups Projects
Commit b5b2ca8e authored by hagrid67's avatar hagrid67
Browse files

updated test_renderenv to use npz format.

Added a test for the PIL graphics layer.
parent c4848fb0
No related branches found
No related tags found
No related merge requests found
......@@ -58,7 +58,6 @@ class View(object):
def init_canvas(self):
# update the rendertool with the env
self.new_env()
#plt.figure(figsize=(10, 10))
self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False)
img = self.oRT.getImage()
plt.clf()
......
......@@ -93,6 +93,7 @@ class MPLGL(GraphicsLayer):
color = tuple([iRGBA / 255 for iRGBA in color])
return color
class RenderTool(object):
Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"])
......
File added
File added
......@@ -7,6 +7,7 @@ Tests for `flatland` package.
from flatland.envs.rail_env import RailEnv, random_rail_generator
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
......@@ -14,27 +15,28 @@ import flatland.utils.rendertools as rt
from flatland.core.env_observation_builder import TreeObsForRailEnv
def checkFrozenImage(sFileImage):
def checkFrozenImage(oRT, sFileImage, resave=False):
sDirRoot = "."
sTmpFileImage = sDirRoot + "/images/test/" + sFileImage
sDirImages = sDirRoot + "/images/"
if os.path.exists(sTmpFileImage):
os.remove(sTmpFileImage)
img_test = oRT.getImage()
plt.savefig(sTmpFileImage)
if resave:
np.savez_compressed(sDirImages + sFileImage, img=img_test)
return
bytesFrozenImage = None
for sDir in ["/images/", "/images/test/"]:
sfPath = sDirRoot + sDir + sFileImage
bytesImage = plt.imread(sfPath)
if bytesFrozenImage is None:
bytesFrozenImage = bytesImage
else:
assert (bytesFrozenImage.shape == bytesImage.shape)
assert ((np.sum(np.square(bytesFrozenImage - bytesImage)) / bytesFrozenImage.size) < 1e-3)
# this is now just for convenience - the file is not read back
np.savez_compressed(sDirImages + "test/" + sFileImage, img=img_test)
image_store = np.load(sDirImages + sFileImage)
img_expected = image_store["img"]
def test_render_env():
assert (img_test.shape == img_expected.shape)
assert ((np.sum(np.square(img_test - img_expected)) / img_expected.size / 256) < 1e-3), \
"Image {} does not match".format(sFileImage)
def test_render_env(save_new_images=False):
# random.seed(100)
np.random.seed(100)
oEnv = RailEnv(width=10, height=10,
......@@ -45,15 +47,14 @@ def test_render_env():
)
sfTestEnv = "env-data/tests/test1.npy"
oEnv.rail.load_transition_map(sfTestEnv)
oEnv.reset()
oRT = rt.RenderTool(oEnv)
plt.figure(figsize=(10, 10))
oRT.renderEnv()
checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)
checkFrozenImage("basic-env.png")
plt.figure(figsize=(10, 10))
oRT = rt.RenderTool(oEnv, gl="PIL")
oRT.renderEnv()
checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images)
# disable the tree / observation tests until env-agent save/load is available
if False:
......@@ -75,3 +76,14 @@ def test_render_env():
oRT.plotPath(visitDest)
checkFrozenImage("env-path.png")
def main():
if len(sys.argv) == 2 and sys.argv[1] == "save":
test_render_env(save_new_images=True)
else:
print("Run 'python test_rendertools.py save' to regenerate images")
if __name__ == "__main__":
main()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment