diff --git a/flatland/utils/editor.py b/flatland/utils/editor.py index 5e40701ad69c78247bc90b13c6e937b69aa72cf9..3e55a793f840eab6738877528da6ac01e61629be 100644 --- a/flatland/utils/editor.py +++ b/flatland/utils/editor.py @@ -58,7 +58,6 @@ class View(object): def init_canvas(self): # update the rendertool with the env self.new_env() - #plt.figure(figsize=(10, 10)) self.oRT.renderEnv(spacing=False, arrows=False, sRailColor="gray", show=False) img = self.oRT.getImage() plt.clf() diff --git a/flatland/utils/rendertools.py b/flatland/utils/rendertools.py index fb2751165687e3a2f4e1e5e3dd5768021a110735..316344027f05c651482fc1ea555d957b9234e3d4 100644 --- a/flatland/utils/rendertools.py +++ b/flatland/utils/rendertools.py @@ -93,6 +93,7 @@ class MPLGL(GraphicsLayer): color = tuple([iRGBA / 255 for iRGBA in color]) return color + class RenderTool(object): Visit = recordtype("Visit", ["rc", "iDir", "iDepth", "prev"]) diff --git a/images/basic-env-PIL.npz b/images/basic-env-PIL.npz new file mode 100644 index 0000000000000000000000000000000000000000..8ffaf023e1116b0c92702212ddb04c71b82f0655 Binary files /dev/null and b/images/basic-env-PIL.npz differ diff --git a/images/basic-env.npz b/images/basic-env.npz new file mode 100644 index 0000000000000000000000000000000000000000..356da5d70146b3b8081dd99c0fe5e6bd70646e53 Binary files /dev/null and b/images/basic-env.npz differ diff --git a/tests/test_rendertools.py b/tests/test_rendertools.py index 73b42c385867ba6bc93ff794ec3ecda0bf82125d..ea4ad3c46f3465c2f30aabaf8ae9cff7e2efd7f7 100644 --- a/tests/test_rendertools.py +++ b/tests/test_rendertools.py @@ -7,6 +7,7 @@ Tests for `flatland` package. from flatland.envs.rail_env import RailEnv, random_rail_generator import numpy as np import os +import sys import matplotlib.pyplot as plt @@ -14,27 +15,28 @@ import flatland.utils.rendertools as rt from flatland.core.env_observation_builder import TreeObsForRailEnv -def checkFrozenImage(sFileImage): +def checkFrozenImage(oRT, sFileImage, resave=False): sDirRoot = "." - sTmpFileImage = sDirRoot + "/images/test/" + sFileImage + sDirImages = sDirRoot + "/images/" - if os.path.exists(sTmpFileImage): - os.remove(sTmpFileImage) + img_test = oRT.getImage() - plt.savefig(sTmpFileImage) + if resave: + np.savez_compressed(sDirImages + sFileImage, img=img_test) + return - bytesFrozenImage = None - for sDir in ["/images/", "/images/test/"]: - sfPath = sDirRoot + sDir + sFileImage - bytesImage = plt.imread(sfPath) - if bytesFrozenImage is None: - bytesFrozenImage = bytesImage - else: - assert (bytesFrozenImage.shape == bytesImage.shape) - assert ((np.sum(np.square(bytesFrozenImage - bytesImage)) / bytesFrozenImage.size) < 1e-3) + # this is now just for convenience - the file is not read back + np.savez_compressed(sDirImages + "test/" + sFileImage, img=img_test) + image_store = np.load(sDirImages + sFileImage) + img_expected = image_store["img"] -def test_render_env(): + assert (img_test.shape == img_expected.shape) + assert ((np.sum(np.square(img_test - img_expected)) / img_expected.size / 256) < 1e-3), \ + "Image {} does not match".format(sFileImage) + + +def test_render_env(save_new_images=False): # random.seed(100) np.random.seed(100) oEnv = RailEnv(width=10, height=10, @@ -45,15 +47,14 @@ def test_render_env(): ) sfTestEnv = "env-data/tests/test1.npy" oEnv.rail.load_transition_map(sfTestEnv) - oEnv.reset() oRT = rt.RenderTool(oEnv) - plt.figure(figsize=(10, 10)) oRT.renderEnv() + + checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images) - checkFrozenImage("basic-env.png") - - plt.figure(figsize=(10, 10)) + oRT = rt.RenderTool(oEnv, gl="PIL") oRT.renderEnv() + checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images) # disable the tree / observation tests until env-agent save/load is available if False: @@ -75,3 +76,14 @@ def test_render_env(): oRT.plotPath(visitDest) checkFrozenImage("env-path.png") + + +def main(): + if len(sys.argv) == 2 and sys.argv[1] == "save": + test_render_env(save_new_images=True) + else: + print("Run 'python test_rendertools.py save' to regenerate images") + + +if __name__ == "__main__": + main() \ No newline at end of file