Skip to content
Snippets Groups Projects
Forked from Flatland / Flatland
1167 commits behind the upstream repository.
test_flatland_utils_rendertools.py 1.84 KiB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Tests for `flatland` package.
"""

import sys

import numpy as np
from importlib_resources import path

import flatland.utils.rendertools as rt
import images.test
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import empty_rail_generator


def checkFrozenImage(oRT, sFileImage, resave=False):
    sDirRoot = "."
    sDirImages = sDirRoot + "/images/"

    img_test = oRT.get_image()

    if resave:
        np.savez_compressed(sDirImages + sFileImage, img=img_test)
        return

    with path(images, sFileImage) as file_in:
        np.load(file_in)

    # TODO fails!
    #  assert (img_test.shape == img_expected.shape) \ #  noqa: E800
    #  assert ((np.sum(np.square(img_test - img_expected)) / img_expected.size / 256) < 1e-3), \ #  noqa: E800
    #      "Image {} does not match".format(sFileImage) \ #  noqa: E800


def test_render_env(save_new_images=False):
    np.random.seed(100)
    oEnv = RailEnv(width=10, height=10,
                   rail_generator=empty_rail_generator(),
                   number_of_agents=0,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2)
                   )
    oEnv.rail.load_transition_map('env_data.tests', "test1.npy")
    oRT = rt.RenderTool(oEnv, gl="PILSVG")
    oRT.render_env(show=False)
    checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)

    oRT = rt.RenderTool(oEnv, gl="PIL")
    oRT.render_env()
    checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images)
    
def main():
    if len(sys.argv) == 2 and sys.argv[1] == "save":
        test_render_env(save_new_images=True)
    else:
        print("Run 'python test_flatland_utils_rendertools.py save' to regenerate images")
        test_render_env()


if __name__ == "__main__":
    main()