Skip to content
Snippets Groups Projects
tests_generators.py 6.68 KiB
Newer Older
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np

from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \
    random_rail_generator, empty_rail_generator
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
from tests.simple_rail import make_simple_rail


def test_empty_rail_generator():
    """The empty rail generator yields a correctly sized grid with no rails and no agents."""
    np.random.seed(0)
    width, height = 5, 10
    requested_agents = 1

    generator = empty_rail_generator()
    env = RailEnv(width=width,
                  height=height,
                  number_of_agents=requested_agents,
                  rail_generator=generator)

    # The grid is indexed (rows, columns), i.e. (height, width).
    assert env.rail.grid.shape == (height, width)
    # Every cell is expected to be zero: no rail was laid.
    assert np.count_nonzero(env.rail.grid) == 0
    # Even though one agent was requested, none is expected to be placed.
    assert env.get_num_agents() == 0


def test_random_rail_generator():
    """The random rail generator honours the requested grid size and agent count."""
    np.random.seed(0)
    width, height = 5, 10
    requested_agents = 1

    generator = random_rail_generator()
    env = RailEnv(width=width,
                  height=height,
                  number_of_agents=requested_agents,
                  rail_generator=generator)

    # The grid is indexed (rows, columns), i.e. (height, width).
    assert env.rail.grid.shape == (height, width)
    assert env.get_num_agents() == requested_agents


def test_complex_rail_generator():
    """Check how the complex rail generator reconciles requested agents with the level it builds.

    Three scenarios on a 10x10 grid:

    1. More agents requested than start/goal pairs: the agent count is expected
       to shrink to ``nr_start_goal``.
    2. ``min_dist`` set to twice the grid width: no level is expected, hence no agents.
    3. Agents and start/goal pairs match with a small ``min_dist``: the agent
       count is expected to be kept.
    """
    width = height = 10

    def build_env(agents, start_goal_pairs, min_distance):
        # Every scenario uses nr_extra=0; only the three arguments vary.
        return RailEnv(width=width,
                       height=height,
                       number_of_agents=agents,
                       rail_generator=complex_rail_generator(nr_start_goal=start_goal_pairs,
                                                             nr_extra=0,
                                                             min_dist=min_distance))

    # Scenario 1: the agent number is changed to fit the generated level.
    start_goal_pairs = 2
    env = build_env(agents=10, start_goal_pairs=start_goal_pairs, min_distance=4)
    assert env.get_num_agents() == start_goal_pairs
    assert env.rail.grid.shape == (height, width)

    # Scenario 2: no agents are generated when the level cannot be generated.
    env = build_env(agents=10, start_goal_pairs=start_goal_pairs, min_distance=2 * width)
    assert env.get_num_agents() == 0
    assert env.rail.grid.shape == (height, width)

    # Scenario 3: with feasible parameters the agent number stays as requested.
    agents = 5
    env = build_env(agents=agents, start_goal_pairs=5, min_distance=2)
    assert env.get_num_agents() == agents
    assert env.rail.grid.shape == (height, width)


def test_rail_from_grid_transition_map():
    """Building from a fixed transition map keeps its rails and puts agents on rail cells."""
    rail, rail_map = make_simple_rail()
    height, width = rail_map.shape[0], rail_map.shape[1]
    agent_count = 3

    env = RailEnv(width=width,
                  height=height,
                  rail_generator=rail_from_grid_transition_map(rail),
                  number_of_agents=agent_count)

    # The simple rail is expected to contain exactly 16 non-empty cells.
    assert np.count_nonzero(env.rail.grid) == 16

    # Each agent must start on a non-empty (rail) cell.
    assert all(env.rail.grid[agent.position] != 0 for agent in env.agents)

    assert env.get_num_agents() == agent_count


def _reload_env(file_name, obs_builder):
    """Build a RailEnv whose rails and agents come from the saved file ``file_name``.

    The 1x1 size and single agent passed here are placeholders. The caller
    expects the rails and agents loaded from the file to replace them.
    """
    return RailEnv(width=1,
                   height=1,
                   rail_generator=rail_from_file(file_name),
                   number_of_agents=1,
                   obs_builder_object=obs_builder,
                   )


def tests_rail_from_file():
    """Round-trip environments through ``RailEnv.save`` and ``rail_from_file``.

    The test saves one file with a distance map (TreeObsForRailEnv) and one
    without (GlobalObsForRailEnv). It reloads each file with each kind of
    observation builder and checks that the rails and agents survive.
    """
    file_name = "test_with_distance_map.pkl"
    file_name_2 = "test_without_distance_map.pkl"

    rail, rail_map = make_simple_rail()

    # 1. Save with a distance map and reload with the same kind of builder.
    env = RailEnv(width=rail_map.shape[1],
                  height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail),
                  number_of_agents=3,
                  obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                  )
    env.save(file_name)
    rails_initial = env.rail.grid
    agents_initial = env.agents

    env = _reload_env(file_name,
                      TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
    assert np.array_equal(rails_initial, env.rail.grid)
    assert agents_initial == env.agents
    assert env.obs_builder.distance_map is not None

    # 2. Save without a distance map and reload with the same kind of builder.
    env2 = RailEnv(width=rail_map.shape[1],
                   height=rail_map.shape[0],
                   rail_generator=rail_from_grid_transition_map(rail),
                   number_of_agents=3,
                   obs_builder_object=GlobalObsForRailEnv(),
                   )
    env2.save(file_name_2)
    rails_initial_2 = env2.rail.grid
    agents_initial_2 = env2.agents

    env2 = _reload_env(file_name_2, GlobalObsForRailEnv())
    assert np.array_equal(rails_initial_2, env2.rail.grid)
    assert agents_initial_2 == env2.agents
    assert not hasattr(env2.obs_builder, "distance_map")

    # 3. Load the file saved WITH a distance map into a global-observation env.
    env3 = _reload_env(file_name, GlobalObsForRailEnv())
    assert np.array_equal(rails_initial, env3.rail.grid)
    assert agents_initial == env3.agents
    # NOTE(review): this re-checks env2, which scenario 2 already asserted, so
    # it says nothing about env3. Pointing it at env3 looks like the intent.
    # Before changing it, confirm how RailEnv handles a distance map supplied
    # by the generator; it may attach the map to any observation builder.
    assert not hasattr(env2.obs_builder, "distance_map")

    # 4. Load the file saved WITHOUT a distance map into a tree-observation env.
    # The builder is expected to end up with a distance map of its own.
    env4 = _reload_env(file_name_2, TreeObsForRailEnv(max_depth=2))
    assert np.array_equal(rails_initial_2, env4.rail.grid)
    assert agents_initial_2 == env4.agents
    # Check env4 itself, not the scenario-1 env, so this scenario is covered.
    assert env4.obs_builder.distance_map is not None