test_random_seeding.py 13.2 KB
Newer Older
1
2
3
4
import numpy as np

from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
5
from flatland.envs.rail_env import RailEnv
6
from flatland.envs.rail_generators import rail_from_grid_transition_map, sparse_rail_generator
7
from flatland.envs.line_generators import sparse_line_generator
Erik Nygren's avatar
Erik Nygren committed
8
9
10
from flatland.utils.simple_rail import make_simple_rail2


11
def ndom_seeding():
    """Check that fixed seeds give the same agent start positions on every rebuild.

    The environment is rebuilt and reset 100 times with identical seeds
    (``sparse_line_generator(seed=12)`` and ``reset(random_seed=1)``). The
    initial positions of agents 0-9 observed on the first iteration serve as
    the reference that every later iteration must reproduce.

    NOTE(review): the name has no ``test_`` prefix, so runners that collect
    ``test_*`` functions will skip it -- presumably disabled on purpose;
    confirm before renaming.

    Raises:
        AssertionError: if any rebuild yields different initial positions.
    """
    rail, rail_map, optionals = make_simple_rail2()

    # Values written down by the original author for agents 0-9. The original
    # code evaluated ``initial_position == <value>`` as bare expressions (no
    # ``assert``), so they were never enforced and may be stale. Kept for
    # reference only:
    #   (3, 2), (3, 5), (3, 6), (5, 6), (3, 4),
    #   (3, 1), (3, 9), (4, 6), (0, 3), (3, 7)
    reference_positions = None

    for idx in range(100):
        env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail, optionals),
                      line_generator=sparse_line_generator(seed=12), number_of_agents=10)
        env.reset(True, True, random_seed=1)

        # Move target to unreachable position in order to not interfere with test
        env.agents[0].target = (0, 0)
        for step in range(10):
            # Only agent 0 is given an explicit action (2); a fresh dict is
            # passed on every step, exactly as before.
            env.step({0: 2})

        # Index agents 0-9 explicitly so a missing agent still raises IndexError.
        initial_positions = [env.agents[a].initial_position for a in range(10)]
        if reference_positions is None:
            reference_positions = initial_positions
        else:
            assert initial_positions == reference_positions, (
                "iteration {}: initial positions {} differ from first iteration {}".format(
                    idx, initial_positions, reference_positions))
Erik Nygren's avatar
Erik Nygren committed
43
44
45


def test_seeding_and_observations():
    """Check that the observation builder does not influence seeded behaviour.

    Two environments are created from the same rail and the same seeds, one
    with ``GlobalObsForRailEnv`` and one with ``TreeObsForRailEnv``. Agents 0-9
    must start at the same positions and, after 10 steps driven by the same
    random actions, must still share the same positions.
    """
    rail, rail_map, optionals = make_simple_rail2()
    optionals['agents_hints']['num_agents'] = 10

    # Environment observed through the global observation builder.
    global_obs_env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail, optionals),
        line_generator=sparse_line_generator(seed=12),
        number_of_agents=10,
        obs_builder_object=GlobalObsForRailEnv(),
    )
    # Environment observed through the tree observation builder.
    tree_obs_env = RailEnv(
        width=25,
        height=30,
        rail_generator=rail_from_grid_transition_map(rail, optionals),
        line_generator=sparse_line_generator(seed=12),
        number_of_agents=10,
        obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
    )

    global_obs_env.reset(False, False, random_seed=12)
    tree_obs_env.reset(False, False, random_seed=12)

    # Both environments must produce the same initial start positions.
    for handle in range(10):
        assert global_obs_env.agents[handle].initial_position == tree_obs_env.agents[handle].initial_position

    # The same dict object is refilled and handed to both environments each step.
    action_dict = {}
    for _step in range(10):
        for handle in range(global_obs_env.get_num_agents()):
            action_dict[handle] = np.random.randint(4)
        global_obs_env.step(action_dict)
        tree_obs_env.step(action_dict)

    # Both environments must end up with the agents in the same positions.
    for handle in range(10):
        assert global_obs_env.agents[handle].position == tree_obs_env.agents[handle].position

    # Emits one ready-to-paste assertion line per agent.
    for handle in range(global_obs_env.get_num_agents()):
        print("assert env.agents[{}].position == env2.agents[{}].position".format(handle, handle))


def test_seeding_and_malfunction():
    """Check that two identically seeded environments evolve in lockstep.

    For every seed from 1 to 99, two environments are built with the same
    rail, line generator and ``GlobalObsForRailEnv`` observation builder,
    reset with that seed and driven with the same random actions for 10
    steps. Initial positions, per-step rewards and done flags, and final
    positions must agree between the two.
    """
    rail, rail_map, optionals = make_simple_rail2()
    optionals['agents_hints']['num_agents'] = 10

    # NOTE(review): this malfunction configuration is built but never passed to
    # either environment below -- confirm whether it was meant to be wired in.
    stochastic_data = {
        'prop_malfunction': 0.4,
        'malfunction_rate': 2,
        'min_duration': 10,
        'max_duration': 10,
    }

    for seed in range(1, 100):
        first_env = RailEnv(
            width=25,
            height=30,
            rail_generator=rail_from_grid_transition_map(rail, optionals),
            line_generator=sparse_line_generator(),
            number_of_agents=10,
            obs_builder_object=GlobalObsForRailEnv(),
        )
        second_env = RailEnv(
            width=25,
            height=30,
            rail_generator=rail_from_grid_transition_map(rail, optionals),
            line_generator=sparse_line_generator(),
            number_of_agents=10,
            obs_builder_object=GlobalObsForRailEnv(),
        )

        first_env.reset(True, False, random_seed=seed)
        second_env.reset(True, False, random_seed=seed)

        # Both environments must produce the same initial start positions.
        for handle in range(10):
            assert first_env.agents[handle].initial_position == second_env.agents[handle].initial_position

        # The same dict object is refilled and handed to both environments each step.
        action_dict = {}
        for _step in range(10):
            for handle in range(first_env.get_num_agents()):
                action_dict[handle] = np.random.randint(4)

            _, rewards_first, dones_first, _ = first_env.step(action_dict)
            _, rewards_second, dones_second, _ = second_env.step(action_dict)
            for handle in range(first_env.get_num_agents()):
                assert rewards_first[handle] == rewards_second[handle]
                assert dones_first[handle] == dones_second[handle]

        # Both environments must end up with the agents in the same positions.
        for handle in range(10):
            assert first_env.agents[handle].position == second_env.agents[handle].position
156
157
158
159
160
161
162
163
164
165
166
167
168


def test_reproducability_env():
    """
    Check that the generated rail grid depends only on the seeds given to the env.

    One environment is built and reset directly; a second, identically
    configured one is reset after the global numpy random state has been
    re-seeded and advanced. Both must produce the same reference grid.
    """
    speed_ratio_map = {1.: 1.,  # Fast passenger train
                       1. / 2.: 0.,  # Fast freight train
                       1. / 3.: 0.,  # Slow commuter train
                       1. / 4.: 0.}  # Slow freight train

    def _fresh_env():
        # Every call builds new generator objects with the same fixed settings.
        rail_generator = sparse_rail_generator(max_num_cities=5,
                                               max_rails_between_cities=3,
                                               seed=10,  # Random seed
                                               grid_mode=True)
        return RailEnv(width=25, height=30, rail_generator=rail_generator,
                       line_generator=sparse_line_generator(speed_ratio_map), number_of_agents=1)

    first_env = _fresh_env()
    first_env.reset(True, True, random_seed=1)

    # Reference grid: 30 rows of 25 cell values each.
    expected_grid = [
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 16386, 1025, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [16386, 17411, 1025, 5633, 17411, 3089, 1025, 1097, 5633, 17411, 1025, 5633, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 4608],
        [32800, 32800, 0, 72, 3089, 5633, 1025, 17411, 1097, 2064, 0, 72, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 37408],
        [32800, 32800, 0, 0, 0, 72, 1025, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [32800, 32872, 4608, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16386, 34864],
        [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [32800, 32800, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [72, 37408, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800, 32800],
        [0, 49186, 2064, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 72, 37408],
        [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [0, 32800, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32800],
        [0, 32872, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 1025, 1025, 1025, 5633, 17411, 1025, 1025, 1025, 5633, 17411, 1025, 34864],
        [0, 72, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1025, 1025, 1025, 1097, 3089, 1025, 1025, 1025, 1097, 3089, 1025, 2064],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ]

    assert first_env.rail.grid.tolist() == expected_grid

    # Disturb the global numpy random state before resetting the second
    # environment; the generated grid is expected to be unaffected.
    second_env = _fresh_env()
    np.random.seed(1)
    for _ in range(10):
        np.random.randn()
    second_env.reset(True, True, random_seed=1)
    assert second_env.rail.grid.tolist() == expected_grid