Skip to content
Snippets Groups Projects
Commit ba0ae4f5 authored by Erik Nygren's avatar Erik Nygren
Browse files

simple tests for all generators

parent a5c371df
No related branches found
No related tags found
No related merge requests found
......@@ -43,6 +43,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
"""
def generator(width, height, num_agents, num_resets=0):
if num_agents > nr_start_goal:
num_agents = nr_start_goal
print("complex_rail_generator: num_agents > nr_start_goal, changing num_agents")
......@@ -108,7 +109,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist=
break
if not all_ok:
# we can might as well give up at this point
# we might as well give up at this point
break
new_path = connect_rail(rail_trans, rail_array, start, goal)
......
......@@ -3,14 +3,112 @@
import numpy as np
from flatland.envs.generators import rail_from_GridTransitionMap_generator, rail_from_file
from flatland.envs.generators import rail_from_GridTransitionMap_generator, rail_from_file, complex_rail_generator, \
random_rail_generator, empty_rail_generator
from flatland.envs.observations import TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from tests.simple_rail import make_simple_rail
def test_load_pkl():
def test_empty_rail_generator():
np.random.seed(0)
n_agents = 1
x_dim = 5
y_dim = 10
# Check that a random level at with correct parameters is generated
env = RailEnv(width=x_dim,
height=y_dim,
number_of_agents=n_agents,
rail_generator=empty_rail_generator()
)
# Check the dimensions
assert env.rail.grid.shape == (y_dim, x_dim)
# Check that no grid was generated
assert np.count_nonzero(env.rail.grid) == 0
# Check that no agents where placed
assert env.get_num_agents() == 0
return
def test_random_rail_generator():
np.random.seed(0)
n_agents = 1
x_dim = 5
y_dim = 10
# Check that a random level at with correct parameters is generated
env = RailEnv(width=x_dim,
height=y_dim,
number_of_agents=n_agents,
rail_generator=random_rail_generator()
)
assert env.rail.grid.shape == (y_dim, x_dim)
assert env.get_num_agents() == n_agents
return
def test_complex_rail_generator():
n_agents = 10
n_start = 2
x_dim = 10
y_dim = 10
min_dist = 4
# Check that agent number is changed to fit generated level
env = RailEnv(width=x_dim,
height=y_dim,
number_of_agents=n_agents,
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
)
assert env.get_num_agents() == 2
min_dist = 2 * x_dim
# Check that no agents are generated when level cannot be generated
env = RailEnv(width=x_dim,
height=y_dim,
number_of_agents=n_agents,
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
)
assert env.get_num_agents() == 0
# Check that everything stays the same when correct parameters are given
min_dist = 2
n_start = 5
n_agents = 5
env = RailEnv(width=x_dim,
height=y_dim,
number_of_agents=n_agents,
rail_generator=complex_rail_generator(nr_start_goal=n_start, nr_extra=0, min_dist=min_dist)
)
assert env.get_num_agents() == n_agents
return
def test_rail_from_GridTransitionMap_generator():
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
height=rail_map.shape[0],
rail_generator=rail_from_GridTransitionMap_generator(rail),
)
nr_rail_elements = np.count_nonzero(env.rail.grid)
# Check if the number of non-empty rail cells is ok
assert nr_rail_elements == 16
# Check that agents are placed on a rail
for a in env.agents:
assert env.rail.grid[a.position] != 0
return
def tests_rail_from_file():
file_name = "test_pkl.pkl"
rail, rail_map = make_simple_rail()
env = RailEnv(width=rail_map.shape[1],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment