#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Tests for `flatland` package.
"""

import sys

import matplotlib.pyplot as plt
import numpy as np

import flatland.utils.rendertools as rt
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.rail_env import RailEnv, random_rail_generator


def checkFrozenImage(oRT, sFileImage, resave=False):
    """Compare the renderer's current image against a stored reference image.

    Args:
        oRT: Object exposing ``getImage()``; the returned array is the image
            under test.
        sFileImage: File name of the reference archive inside ``./images/``.
            The image is stored under the ``img`` key.
        resave: When true, overwrite the reference archive with the current
            image and return without comparing.

    Raises:
        AssertionError: If the shapes differ, or if the sum of squared
            differences divided by the element count and by 256 is not
            below ``1e-3``.
    """
    sDirRoot = "."
    sDirImages = sDirRoot + "/images/"

    img_test = oRT.getImage()

    if resave:
        np.savez_compressed(sDirImages + sFileImage, img=img_test)
        return

    # this is now just for convenience - the file is not read back
    np.savez_compressed(sDirImages + "test/" + sFileImage, img=img_test)

    # np.load on an .npz archive returns a file-backed object; the context
    # manager closes it once the array has been read into memory.
    with np.load(sDirImages + sFileImage) as image_store:
        img_expected = image_store["img"]

    assert img_test.shape == img_expected.shape, \
        "Image {} has shape {}, expected {}".format(
            sFileImage, img_test.shape, img_expected.shape)

    # Compute the difference in floating point. If the images are unsigned
    # integer arrays (dtype is not visible here), subtracting and squaring in
    # the native dtype wraps around and can hide real mismatches.
    diff = (np.asarray(img_test, dtype=np.float64)
            - np.asarray(img_expected, dtype=np.float64))

    assert (np.sum(np.square(diff)) / img_expected.size / 256) < 1e-3, \
        "Image {} does not match".format(sFileImage)


def test_render_env(save_new_images=False):
    """Render a fixed rail environment with two backends and check the images.

    Loads the transition map from ``env-data/tests/test1.npy`` into a 10x10
    environment with no agents, renders it once with the default
    ``RenderTool`` backend and once with ``gl="PIL"``, and checks each result
    with ``checkFrozenImage``.

    Args:
        save_new_images: Passed through as ``resave``; when true the
            reference images are regenerated instead of compared.
    """
    # random.seed(100)
    np.random.seed(100)
    oEnv = RailEnv(width=10, height=10,
                   rail_generator=random_rail_generator(),
                   number_of_agents=0,
                   # obs_builder_object=GlobalObsForRailEnv())
                   obs_builder_object=TreeObsForRailEnv(max_depth=2)
                   )
    sfTestEnv = "env-data/tests/test1.npy"
    oEnv.rail.load_transition_map(sfTestEnv)

    oRT = rt.RenderTool(oEnv)
    oRT.renderEnv()
    checkFrozenImage(oRT, "basic-env.npz", resave=save_new_images)

    oRT = rt.RenderTool(oEnv, gl="PIL")
    oRT.renderEnv()
    checkFrozenImage(oRT, "basic-env-PIL.npz", resave=save_new_images)

    # disable the tree / observation tests until env-agent save/load is available
    # NOTE(review): the checkFrozenImage calls below omit the renderer
    # argument and would fail if this branch were enabled as-is; update them
    # to the current signature before re-enabling.
    if False:
        lVisits = oRT.getTreeFromRail(
            oEnv.agents_position[0],
            oEnv.agents_direction[0],
            nDepth=17, bPlot=True)

        checkFrozenImage("env-tree-spatial.png")

        plt.figure(figsize=(8, 8))
        xyTarg = oRT.env.agents_target[0]
        visitDest = oRT.plotTree(lVisits, xyTarg)

        checkFrozenImage("env-tree-graph.png")

        plt.figure(figsize=(10, 10))
        oRT.renderEnv()
        oRT.plotPath(visitDest)

        checkFrozenImage("env-path.png")


def main():
    """Regenerate the reference images when invoked with the ``save`` argument.

    With any other command line, print a usage hint instead.
    """
    if len(sys.argv) == 2 and sys.argv[1] == "save":
        test_render_env(save_new_images=True)
    else:
        print("Run 'python test_rendertools.py save' to regenerate images")


if __name__ == "__main__":
    main()