From ba0ae4f5a2e7b533ac207e9a5a155405ec7389d6 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Fri, 12 Jul 2019 10:56:24 -0400
Subject: [PATCH] simple tests for all generators

---
 flatland/envs/generators.py |   3 +-
 tests/tests_generators.py   | 102 +++++++++++++++++++++++++++++++++++-
 2 files changed, 102 insertions(+), 3 deletions(-)

diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py
index 907b4a2..c6b4b5a 100644
--- a/flatland/envs/generators.py
+++ b/flatland/envs/generators.py
@@ -43,6 +43,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
     """
 
     def generator(width, height, num_agents, num_resets=0):
+
         if num_agents > nr_start_goal:
             num_agents = nr_start_goal
             print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
@@ -108,7 +109,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
                     break
 
             if not all_ok:
-                # we can might as well give up at this point
+                # we might as well give up at this point
                 break
 
             new_path = connect_rail(rail_trans, rail_array, start, goal)
diff --git a/tests/tests_generators.py b/tests/tests_generators.py
index 57fa45c..79a780d 100644
--- a/tests/tests_generators.py
+++ b/tests/tests_generators.py
@@ -3,14 +3,112 @@
 
 import numpy as np
 
-from flatland.envs.generators import rail_from_GridTransitionMap_generator, rail_from_file
+from flatland.envs.generators import rail_from_GridTransitionMap_generator, rail_from_file, complex_rail_generator, \
+    random_rail_generator, empty_rail_generator
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
 from tests.simple_rail import make_simple_rail
 
 
-def test_load_pkl():
+def test_empty_rail_generator():
+    np.random.seed(0)
+    n_agents = 1
+    x_dim = 5
+    y_dim = 10
+
+    # Check that a random level at with correct parameters is generated
+    env = RailEnv(width=x_dim,
+                  height=y_dim,
+                  number_of_agents=n_agents,
+                  rail_generator=empty_rail_generator()
+                  )
+    # Check the dimensions
+    assert env.rail.grid.shape == (y_dim, x_dim)
+    # Check that no grid was generated
+    assert np.count_nonzero(env.rail.grid) == 0
+    # Check that no agents where placed
+    assert env.get_num_agents() == 0
+
+    return
+
+
+def test_random_rail_generator():
+    np.random.seed(0)
+    n_agents = 1
+    x_dim = 5
+    y_dim = 10
+
+    # Check that a random level at with correct parameters is generated
+    env = RailEnv(width=x_dim,
+                  height=y_dim,
+                  number_of_agents=n_agents,
+                  rail_generator=random_rail_generator()
+                  )
+    assert env.rail.grid.shape == (y_dim, x_dim)
+    assert env.get_num_agents() == n_agents
+
+    return
+
+
+def test_complex_rail_generator():
+    n_agents = 10
+    n_start = 2
+    x_dim = 10
+    y_dim = 10
+    min_dist = 4
+
+    # Check that agent number is changed to fit generated level
+    env = RailEnv(width=x_dim,
+                  height=y_dim,
+                  number_of_agents=n_agents,
+                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
+                  )
+    assert env.get_num_agents() == 2
+
+    min_dist = 2 * x_dim
+
+    # Check that no agents are generated when level cannot be generated
+    env = RailEnv(width=x_dim,
+                  height=y_dim,
+                  number_of_agents=n_agents,
+                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
+                  )
+    assert env.get_num_agents() == 0
+
+    # Check that everything stays the same when correct parameters are given
+    min_dist = 2
+    n_start = 5
+    n_agents = 5
+
+    env = RailEnv(width=x_dim,
+                  height=y_dim,
+                  number_of_agents=n_agents,
+                  rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
+                  )
+    assert env.get_num_agents() == n_agents
+
+    return
+
+
+def test_rail_from_GridTransitionMap_generator():
+    rail, rail_map = make_simple_rail()
+    env = RailEnv(width=rail_map.shape[1],
+                  height=rail_map.shape[0],
+                  rail_generator=rail_from_GridTransitionMap_generator(rail),
+                  )
+    nr_rail_elements = np.count_nonzero(env.rail.grid)
+
+    # Check if the number of non-empty rail cells is ok
+    assert nr_rail_elements == 16
+
+    # Check that agents are placed on a rail
+    for a in env.agents:
+        assert env.rail.grid[a.position] != 0
+    return
+
+
+def tests_rail_from_file():
     file_name = "test_pkl.pkl"
     rail, rail_map = make_simple_rail()
     env = RailEnv(width=rail_map.shape[1],
-- 
GitLab