From ba0ae4f5a2e7b533ac207e9a5a155405ec7389d6 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Fri, 12 Jul 2019 10:56:24 -0400 Subject: [PATCH] simple tests for all generators --- flatland/envs/generators.py | 3 +- tests/tests_generators.py | 102 +++++++++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 907b4a2..c6b4b5a 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -43,6 +43,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= """ def generator(width, height, num_agents, num_resets=0): + if num_agents > nr_start_goal: num_agents = nr_start_goal print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents") @@ -108,7 +109,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= break if not all_ok: - # we can might as well give up at this point + # we might as well give up at this point break new_path = connect_rail(rail_trans, rail_array, start, goal) diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 57fa45c..79a780d 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -3,14 +3,112 @@ import numpy as np -from flatland.envs.generators import rail_from_GridTransitionMap_generator, rail_from_file +from flatland.envs.generators import rail_from_GridTransitionMap_generator, rail_from_file, complex_rail_generator, \ + random_rail_generator, empty_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from tests.simple_rail import make_simple_rail -def test_load_pkl(): +def test_empty_rail_generator(): + np.random.seed(0) + n_agents = 1 + x_dim = 5 + y_dim = 10 + + # Check that a random level at with correct parameters is generated + env = RailEnv(width=x_dim, + height=y_dim, + number_of_agents=n_agents, + rail_generator=empty_rail_generator() + ) + # Check the dimensions + assert env.rail.grid.shape == (y_dim, x_dim) + # Check that no grid was generated + assert np.count_nonzero(env.rail.grid) == 0 + # Check that no agents where placed + assert env.get_num_agents() == 0 + + return + + +def test_random_rail_generator(): + np.random.seed(0) + n_agents = 1 + x_dim = 5 + y_dim = 10 + + # Check that a random level at with correct parameters is generated + env = RailEnv(width=x_dim, + height=y_dim, + number_of_agents=n_agents, + rail_generator=random_rail_generator() + ) + assert env.rail.grid.shape == (y_dim, x_dim) + assert env.get_num_agents() == n_agents + + return + + +def test_complex_rail_generator(): + n_agents = 10 + n_start = 2 + x_dim = 10 + y_dim = 10 + min_dist = 4 + + # Check that agent number is changed to fit generated level + env = RailEnv(width=x_dim, + height=y_dim, + number_of_agents=n_agents, + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + ) + assert env.get_num_agents() == 2 + + min_dist = 2 * x_dim + + # Check that no agents are generated when level cannot be generated + env = RailEnv(width=x_dim, + height=y_dim, + number_of_agents=n_agents, + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + ) + assert env.get_num_agents() == 0 + + # Check that everything stays the same when correct parameters are given + min_dist = 2 + n_start = 5 + n_agents = 5 + + env = RailEnv(width=x_dim, + height=y_dim, + number_of_agents=n_agents, + rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist) + ) + assert env.get_num_agents() == n_agents + + return + + +def test_rail_from_GridTransitionMap_generator(): + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_GridTransitionMap_generator(rail), + ) + nr_rail_elements = np.count_nonzero(env.rail.grid) + + # Check if the number of non-empty rail cells is ok + assert nr_rail_elements == 16 + + # Check that agents are placed on a rail + for a in env.agents: + assert env.rail.grid[a.position] != 0 + return + + +def tests_rail_from_file(): file_name = "test_pkl.pkl" rail, rail_map = make_simple_rail() env = RailEnv(width=rail_map.shape[1], -- GitLab