Skip to content
Snippets Groups Projects
Commit 0dd22f5e authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

updated test for seeding

parent 145e254e
No related branches found
No related tags found
No related merge requests found
import numpy as np
from flatland.envs.rail_env import RailEnv from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import rail_from_grid_transition_map from flatland.envs.rail_generators import rail_from_grid_transition_map
from flatland.envs.schedule_generators import random_schedule_generator from flatland.envs.schedule_generators import random_schedule_generator
...@@ -8,38 +6,30 @@ from flatland.utils.simple_rail import make_simple_rail2 ...@@ -8,38 +6,30 @@ from flatland.utils.simple_rail import make_simple_rail2
def test_random_seeding(): def test_random_seeding():
# Set fixed malfunction duration for this test # Set fixed malfunction duration for this test
stochastic_data = {'prop_malfunction': 1.,
'malfunction_rate': 1000,
'min_duration': 3,
'max_duration': 3}
rail, rail_map = make_simple_rail2() rail, rail_map = make_simple_rail2()
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=1,
stochastic_data=stochastic_data, # Malfunction data generator
)
# reset to initialize agents_static
obs, info = env.reset(True, True, False, random_seed=0)
env.agents[0].target = (0, 0)
# Move target to unreachable position in order to not interfere with test # Move target to unreachable position in order to not interfere with test
for idx in range(4): for idx in range(1000):
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=0),
number_of_agents=10
)
env.reset(True, True, False, random_seed=0) env.reset(True, True, False, random_seed=0)
np.random.seed(0)
# Test generation print # Test generation print
print("assert env.agents[0].initial_position == {}".format(env.agents[0].initial_position))
env.agents[0].target = (0, 0) env.agents[0].target = (0, 0)
# assert env.agents[0].initial_position == (3, 3)
for step in range(10): for step in range(10):
actions = {} actions = {}
actions[0] = 2
for i in range(len(obs)):
actions[i] = np.random.randint(4)
env.step(actions) env.step(actions)
#assert env.agents[0].position == (3, 9) agent_positions = []
for a in range(env.get_num_agents()):
agent_positions += env.agents[a].initial_position
# print(agent_positions)
assert agent_positions == [1, 3, 3, 3, 3, 5, 3, 6, 4, 6, 3, 1, 2, 3, 5, 6, 3, 7, 3, 4]
# Test generation print # Test generation print
print("assert env.agents[0].position == {}".format(env.agents[0].position)) assert env.agents[0].position == (3, 7)
# print("env.agents[0].initial_position == {}".format(env.agents[0].initial_position))
#print("assert env.agents[0].position == {}".format(env.agents[0].position))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment