#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Tests for `flatland` package.
"""

import os
import sys

import matplotlib.pyplot as plt
import numpy as np

import flatland.utils.rendertools as rt
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv, random_rail_generator


def checkFrozenImage(oRT, sFileImage, resave=False):
    """Compare the renderer's current image against a stored reference image.

    The current image is obtained from ``oRT.getImage()``. Reference images
    live in ``./images/`` as compressed ``.npz`` archives holding a single
    array under the key ``"img"``.

    Args:
        oRT: Object exposing ``getImage()``; the returned value is used as a
            numpy array (``.shape`` and arithmetic are applied to it).
        sFileImage: File name of the reference archive inside ``./images/``.
        resave: When True, overwrite the reference archive with the current
            image and return without comparing anything.

    Raises:
        AssertionError: If the shapes differ, or if the mean squared
            difference divided by 256 is not below ``1e-3``.
    """
    sDirRoot = "."
    sDirImages = sDirRoot + "/images/"
    img_test = oRT.getImage()

    if resave:
        # Create the target directory so regenerating works on a fresh checkout.
        os.makedirs(sDirImages, exist_ok=True)
        np.savez_compressed(sDirImages + sFileImage, img=img_test)
        return

    # this is now just for convenience - the file is not read back
    sDirTest = sDirImages + "test/"
    os.makedirs(sDirTest, exist_ok=True)
    np.savez_compressed(sDirTest + sFileImage, img=img_test)

    # NpzFile keeps the archive open until closed; the array is fully read
    # when indexed, so it stays usable after the ``with`` block.
    with np.load(sDirImages + sFileImage) as image_store:
        img_expected = image_store["img"]

    assert img_test.shape == img_expected.shape, \
        "Image {} has shape {}, expected {}".format(
            sFileImage, img_test.shape, img_expected.shape)

    # Compare in float64: subtracting/squaring unsigned integer arrays wraps
    # around modulo 256 and can make real differences look like zero.
    diff = (np.asarray(img_test, dtype=np.float64)
            - np.asarray(img_expected, dtype=np.float64))
    error = np.sum(np.square(diff)) / img_expected.size / 256
    assert error < 1e-3, \
        "Image {} does not match".format(sFileImage)


def test_render_env(save_new_images=False):
    """Render a stored rail layout and check it against the frozen images.

    A 10x10 ``RailEnv`` with no agents is created, its transition map is
    replaced by the one stored in ``env-data/tests/test1.npy``, and the
    environment is rendered twice: once with ``RenderTool``'s default ``gl``
    setting and once with ``gl="PIL"``. Each rendering is passed to
    ``checkFrozenImage``.

    Args:
        save_new_images: Forwarded as ``resave`` to ``checkFrozenImage``;
            when True the reference images are regenerated instead of
            compared.
    """
    oEnv = RailEnv(width=10, height=10,
                   rail_generator=random_rail_generator(),
                   number_of_agents=0,
                   # obs_builder_object=GlobalObsForRailEnv())
                   obs_builder_object=TreeObsForRailEnv(max_depth=2)
                   )
    sfTestEnv = "env-data/tests/test1.npy"
    oEnv.rail.load_transition_map(sfTestEnv)

    oRT = rt.RenderTool(oEnv)
    oRT.renderEnv()
    checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)

    oRT = rt.RenderTool(oEnv, gl="PIL")
    oRT.renderEnv()
    checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images)

    # disable the tree / observation tests until env-agent save/load is available
    if False:
        lVisits = oRT.getTreeFromRail(
            oEnv.agents_position[0],
            oEnv.agents_direction[0],
            nDepth=17, bPlot=True)

        # checkFrozenImage needs the renderer as its first argument, and the
        # archive name must end in ".npz" so the file written by
        # np.savez_compressed is the same one np.load reads back.
        # NOTE(review): assumes oRT.getImage() reflects the tree plot - confirm
        # before re-enabling this branch.
        checkFrozenImage(oRT, "env-tree-spatial.npz", resave=save_new_images)

        plt.figure(figsize=(8, 8))
        xyTarg = oRT.env.agents_target[0]
        visitDest = oRT.plotTree(lVisits, xyTarg)

        checkFrozenImage(oRT, "env-tree-graph.npz", resave=save_new_images)

        plt.figure(figsize=(10, 10))
        oRT.renderEnv()
        oRT.plotPath(visitDest)



def main():
    """Regenerate the reference images when invoked with the single argument ``save``.

    With any other command line, only a usage hint is printed.
    """
    cli_args = sys.argv[1:]
    if cli_args != ["save"]:
        print("Run 'python test_rendertools.py save' to regenerate images")
        return
    test_render_env(save_new_images=True)


if __name__ == "__main__":