"""Tests for ``RailEnv`` built with the sparse rail and sparse schedule generators."""
import numpy as np

from flatland.core.grid.grid_utils import Vec2dOperations as Vec2d
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from flatland.utils.rendertools import RenderTool

def test_sparse_rail_generator():
    """Regression-check the sparse rail generator on a 50x50 map with 10 agents.

    The assertions are an approximate test that replaces an image comparison:
    they pin the sum of the generated rail grid and the distances of the last
    agent's position from the origin.

    NOTE(review): unlike the other tests in this module, no explicit ``seed``
    is passed to ``sparse_rail_generator`` here, so the hard-coded expected
    values presumably rely on the generator's default seed being
    deterministic -- confirm before changing the generator parameters.
    """
    env = RailEnv(width=50,
                  height=50,
                  rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
                                                       num_intersections=10,  # Number of intersections in map
                                                       num_trainstations=50,  # Number of possible start/targets on map
                                                       min_node_dist=6,  # Minimal distance of nodes
                                                       node_radius=3,  # Proximity of stations to city center
                                                       num_neighb=3,  # Number of connections to other cities
                                                       grid_mode=False  # False: no ordered (grid) distribution of nodes
                                                       ),
                  schedule_generator=sparse_schedule_generator(),
                  number_of_agents=10,
                  obs_builder_object=GlobalObsForRailEnv())

    # approximative test (replace image comparison)
    rail_grid_sum = np.sum(env.rail.grid)
    assert rail_grid_sum == 10746925, "actual={}".format(rail_grid_sum)

    # Only the last agent's position is pinned by the expected values below.
    last_agent = env.agents[env.get_num_agents() - 1]
    manhattan = Vec2d.get_manhattan_distance(last_agent.position, (0, 0))
    chebyshev = Vec2d.get_chebyshev_distance(last_agent.position, (0, 0))
    assert manhattan == 33, "actual={}".format(manhattan)
    assert chebyshev == 31, "actual={}".format(chebyshev)


def test_rail_env_action_required_info():
    """Check that acting only when ``action_required`` is set matches acting every step.

    Two environments are built from identical generator parameters after
    seeding numpy with the same value:

    * ``env_always_action`` receives a random action for every agent on
      every step;
    * ``env_only_if_action_required`` receives that same action only on the
      first step, or when its previous ``info['action_required']`` flag was
      set for the agent.

    Observations, rewards, dones and ``action_required`` flags must be
    identical between the two environments on every step.
    """
    np.random.seed(0)
    speed_ratio_map = {1.: 0.25,  # Fast passenger train
                       1. / 2.: 0.25,  # Fast freight train
                       1. / 3.: 0.25,  # Slow commuter train
                       1. / 4.: 0.25}  # Slow freight train
    env_always_action = RailEnv(width=50,
                                height=50,
                                rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
                                                                     num_intersections=10,
                                                                     # Number of intersections in map
                                                                     num_trainstations=50,
                                                                     # Number of possible start/targets on map
                                                                     min_node_dist=6,  # Minimal distance of nodes
                                                                     node_radius=3,
                                                                     # Proximity of stations to city center
                                                                     num_neighb=3,
                                                                     # Number of connections to other cities
                                                                     seed=5,  # Random seed
                                                                     grid_mode=False
                                                                     # False: no ordered (grid) distribution of nodes
                                                                     ),
                                schedule_generator=sparse_schedule_generator(speed_ratio_map),
                                number_of_agents=10,
                                obs_builder_object=GlobalObsForRailEnv())
    # Re-seed so the second environment is generated from the same random state.
    np.random.seed(0)
    env_only_if_action_required = RailEnv(width=50,
                                          height=50,
                                          rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
                                                                               num_intersections=10,
                                                                               # Number of intersections in map
                                                                               num_trainstations=50,
                                                                               # Number of possible start/targets on map
                                                                               min_node_dist=6,
                                                                               # Minimal distance of nodes
                                                                               node_radius=3,
                                                                               # Proximity of stations to city center
                                                                               num_neighb=3,
                                                                               # Number of connections to other cities
                                                                               seed=5,  # Random seed
                                                                               grid_mode=False
                                                                               # False: no ordered (grid) distribution
                                                                               ),
                                          schedule_generator=sparse_schedule_generator(speed_ratio_map),
                                          number_of_agents=10,
                                          obs_builder_object=GlobalObsForRailEnv())
    env_renderer = RenderTool(env_always_action, gl="PILSVG", )

    # Set by the first step; the `step == 0` short-circuit below guarantees it
    # is never indexed before then.
    info_only_if_action_required = None
    try:
        for step in range(100):
            print("step {}".format(step))

            action_dict_always_action = dict()
            action_dict_only_if_action_required = dict()
            # Choose an action for each agent in the environment
            for a in range(env_always_action.get_num_agents()):
                action = np.random.choice(np.arange(4))
                action_dict_always_action.update({a: action})
                if step == 0 or info_only_if_action_required['action_required'][a]:
                    action_dict_only_if_action_required.update({a: action})
                    print("[{}] action_required {}, speed_data={}".format(step, a,
                                                                          env_always_action.agents[a].speed_data))

            (obs_always_action, rewards_always_action,
             done_always_action, info_always_action) = env_always_action.step(action_dict_always_action)
            (obs_only_if_action_required, rewards_only_if_action_required,
             done_only_if_action_required, info_only_if_action_required) = env_only_if_action_required.step(
                action_dict_only_if_action_required)

            for a in range(env_always_action.get_num_agents()):
                assert len(obs_always_action[a]) == len(obs_only_if_action_required[a])
                for i in range(len(obs_always_action[a])):
                    assert np.array_equal(obs_always_action[a][i], obs_only_if_action_required[a][i])
                assert np.array_equal(rewards_always_action[a], rewards_only_if_action_required[a])
                assert np.array_equal(done_always_action[a], done_only_if_action_required[a])
                assert info_always_action['action_required'][a] == info_only_if_action_required['action_required'][a]

            env_renderer.render_env(show=True, show_observations=False, show_predictions=False)

            if done_always_action['__all__']:
                break
    finally:
        # Close the render window even when an assertion above fails.
        env_renderer.close_window()


def test_rail_env_malfunction_speed_info():
    """Check that ``RailEnv.step`` reports per-agent malfunction and speed info.

    With stochastic malfunctions configured, every step's ``info`` must
    contain a ``malfunction`` entry that is non-negative for each agent. It
    must also contain a ``speed`` entry that lies in [0, 1] and equals the
    agent's ``speed_data['speed']``.
    """
    np.random.seed(0)
    stochastic_data = {'prop_malfunction': 0.5,  # Proportion of defective agents
                       'malfunction_rate': 30,  # Rate of malfunction occurrence
                       'min_duration': 3,  # Minimal duration of malfunction
                       'max_duration': 10  # Max duration of malfunction
                       }
    env = RailEnv(width=50,
                  height=50,
                  rail_generator=sparse_rail_generator(num_cities=10,  # Number of cities in map
                                                       num_intersections=10,
                                                       # Number of intersections in map
                                                       num_trainstations=50,
                                                       # Number of possible start/targets on map
                                                       min_node_dist=6,  # Minimal distance of nodes
                                                       node_radius=3,
                                                       # Proximity of stations to city center
                                                       num_neighb=3,
                                                       # Number of connections to other cities
                                                       seed=5,  # Random seed
                                                       grid_mode=False  # False: no ordered (grid) distribution of nodes
                                                       ),
                  schedule_generator=sparse_schedule_generator(),
                  number_of_agents=10,
                  obs_builder_object=GlobalObsForRailEnv(),
                  stochastic_data=stochastic_data)

    env_renderer = RenderTool(env, gl="PILSVG", )
    try:
        for _ in range(100):
            action_dict = dict()
            # Choose an action for each agent in the environment
            for a in range(env.get_num_agents()):
                action = np.random.choice(np.arange(4))
                action_dict.update({a: action})

            obs, rewards, done, info = env.step(action_dict)

            assert 'malfunction' in info
            assert 'speed' in info
            for a in range(env.get_num_agents()):
                assert info['malfunction'][a] >= 0
                assert 0 <= info['speed'][a] <= 1
                assert info['speed'][a] == env.agents[a].speed_data['speed']

            env_renderer.render_env(show=True, show_observations=False, show_predictions=False)

            if done['__all__']:
                break
    finally:
        # Close the render window even when an assertion above fails.
        env_renderer.close_window()


def test_sparse_generator_with_too_man_cities_does_not_break_down():
    """Constructing a RailEnv with an excessive city count must not raise.

    100 cities are requested on a 50x50 map. The test has no assertions:
    it passes as long as environment construction completes.
    """
    np.random.seed(0)

    crowded_rail_generator = sparse_rail_generator(
        num_cities=100,  # Deliberately excessive number of cities in map
        num_intersections=10,  # Number of intersections in map
        num_trainstations=50,  # Number of possible start/targets on map
        min_node_dist=6,  # Minimal distance of nodes
        node_radius=3,  # Proximity of stations to city center
        num_neighb=3,  # Number of connections to other cities
        seed=5,  # Random seed
        grid_mode=False  # False: no ordered (grid) distribution of nodes
    )
    schedule = sparse_schedule_generator()
    observation_builder = GlobalObsForRailEnv()

    # Construction alone is the check.
    RailEnv(width=50,
            height=50,
            rail_generator=crowded_rail_generator,
            schedule_generator=schedule,
            number_of_agents=10,
            obs_builder_object=observation_builder)