test_generators.py 5.47 KB
Newer Older
1
2
3
#!/usr/bin/env python
# -*- coding: utf-8 -*-

4
5
import numpy as np

6
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
7
8
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.envs.rail_env import RailEnv
9
10
from flatland.envs.rail_generators import rail_from_grid_transition_map, rail_from_file, empty_rail_generator
from flatland.envs.line_generators import sparse_line_generator, line_from_file
11
from flatland.utils.simple_rail import make_simple_rail
12
from flatland.envs.persistence import RailEnvPersister
13
from flatland.envs.step_utils.states import TrainState
14
15


Erik Nygren's avatar
Erik Nygren committed
16
def test_empty_rail_generator():
    """The empty rail generator yields a grid of the requested size with no rails."""
    n_agents = 2
    x_dim = 5
    y_dim = 10

    # Generate an (empty) level with the requested dimensions
    rail, _ = empty_rail_generator().generate(width=x_dim, height=y_dim, num_agents=n_agents)

    # The grid is indexed (rows, cols), i.e. (height, width)
    assert rail.grid.shape == (y_dim, x_dim)

    # No cell may contain any rail transition
    assert np.count_nonzero(rail.grid) == 0
Erik Nygren's avatar
Erik Nygren committed
27
28


29
def test_rail_from_grid_transition_map():
    """Build an env from the hand-made simple rail and check its layout and agents.

    Every agent is moved onto its initial position and set to MOVING, then the
    test checks the rail-cell count, that each agent sits on a rail cell, and
    the agent count.
    """
    rail, rail_map, optionals = make_simple_rail()
    n_agents = 2
    env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                  rail_generator=rail_from_grid_transition_map(rail, optionals),
                  line_generator=sparse_line_generator(), number_of_agents=n_agents)
    env.reset(False, False)

    # Put each agent on its initial cell; the rail check below indexes the
    # grid with ``agent.position``.
    for agent in env.agents:
        agent.position = agent.initial_position
        agent._set_state(TrainState.MOVING)

    # Check if the number of non-empty rail cells is ok
    assert np.count_nonzero(env.rail.grid) == 16

    # Check that agents are placed on a rail
    for agent in env.agents:
        assert env.rail.grid[agent.position] != 0

    assert env.get_num_agents() == n_agents
Erik Nygren's avatar
Erik Nygren committed
50
51
52


def tests_rail_from_file():
    """Round-trip RailEnvs through RailEnvPersister and reload them from file.

    Scenarios:

    1. Save with a distance map (TreeObs + predictor), load with the same
       observation builder: rail, agents and distance-map shape must survive.
    2. Save without a distance map (GlobalObs), load with GlobalObs.
    3. Load the file saved *with* a distance map using GlobalObs.
    4. Load the file saved *without* a distance map using TreeObs; a distance
       map of the same shape as in scenario 1 must be generated.

    The pickle files live in a temporary directory, so the test leaves no
    artifacts in the working directory and cannot collide with parallel runs.
    """
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "test_with_distance_map.pkl")
        file_name_2 = os.path.join(tmp_dir, "test_without_distance_map.pkl")

        rail, rail_map, optionals = make_simple_rail()

        # --- 1. Save and load a file with a distance map ---
        env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                      rail_generator=rail_from_grid_transition_map(rail, optionals),
                      line_generator=sparse_line_generator(), number_of_agents=3,
                      obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
        env.reset()
        RailEnvPersister.save(env, file_name)
        dist_map_shape = np.shape(env.distance_map.get())

        # Keep references to the saved state; ``env`` is rebound below.
        rails_initial = env.rail.grid
        agents_initial = env.agents

        # width/height/number_of_agents are placeholders: the assertions below
        # require rail and agents to match the saved env, i.e. come from the file.
        env = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
                      line_generator=line_from_file(file_name), number_of_agents=1,
                      obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()))
        env.reset()

        assert np.array_equal(rails_initial, env.rail.grid)
        assert agents_initial == env.agents

        # The loaded distance map must be present with the original shape
        assert np.shape(env.distance_map.get()) == dist_map_shape
        assert env.distance_map.get() is not None

        # --- 2. Save and load a file without a distance map ---
        # ``optionals`` is passed as in scenario 1 so the line generator gets
        # the same agent hints for the same rail.
        env2 = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0],
                       rail_generator=rail_from_grid_transition_map(rail, optionals),
                       line_generator=sparse_line_generator(),
                       number_of_agents=3, obs_builder_object=GlobalObsForRailEnv())
        env2.reset()
        RailEnvPersister.save(env2, file_name_2)

        rails_initial_2 = env2.rail.grid
        agents_initial_2 = env2.agents

        env2 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name_2),
                       line_generator=line_from_file(file_name_2), number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())
        env2.reset()

        assert np.array_equal(rails_initial_2, env2.rail.grid)
        assert agents_initial_2 == env2.agents
        assert not hasattr(env2.obs_builder, "distance_map")

        # --- 3. Save with a distance map, load without ---
        env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name),
                       line_generator=line_from_file(file_name), number_of_agents=1,
                       obs_builder_object=GlobalObsForRailEnv())
        env3.reset()

        assert np.array_equal(rails_initial, env3.rail.grid)
        assert agents_initial == env3.agents
        assert not hasattr(env3.obs_builder, "distance_map")

        # --- 4. Save without a distance map, load and generate one ---
        env4 = RailEnv(width=1,
                       height=1,
                       rail_generator=rail_from_file(file_name_2),
                       line_generator=line_from_file(file_name_2),
                       number_of_agents=1,
                       obs_builder_object=TreeObsForRailEnv(max_depth=2),
                       )
        env4.reset()

        assert np.array_equal(rails_initial_2, env4.rail.grid)
        assert agents_initial_2 == env4.agents

        # Check that a distance map was generated with the same shape as the one
        # computed in scenario 1 (same rail, same number of agents)
        assert env4.distance_map.get() is not None
        assert np.shape(env4.distance_map.get()) == dist_map_shape