diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index c72fc5190c106b0d28a8b2df84b7c3009d7404c9..a5a46923ddeafe66b653632eb8c4f29a7f3466ad 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -151,7 +151,7 @@ def test_malfunction_process_statistically(): obs_builder_object=SingleAgentNavigationObs() ) # reset to initialize agents_static - env.reset(False, False, False, random_seed=0) + env.reset(True, True, False, random_seed=0) env.agents[0].target = (0, 0) nb_malfunction = 0 diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py new file mode 100644 index 0000000000000000000000000000000000000000..67c02e863f94a342f6e72eac1a0f465f2c980927 --- /dev/null +++ b/tests/test_random_seeding.py @@ -0,0 +1,47 @@ +import random + +import numpy as np + +from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator +from flatland.utils.simple_rail import make_simple_rail2 + + +def test_random_seeding(): + # Set fixed malfunction duration for this test + stochastic_data = {'prop_malfunction': 1., + 'malfunction_rate': 1000, + 'min_duration': 3, + 'max_duration': 3} + + rail, rail_map = make_simple_rail2() + + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=1, + stochastic_data=stochastic_data, # Malfunction data generator + ) + # reset to initialize agents_static + obs, info = env.reset(True, True, False, random_seed=0) + env.agents[0].target = (0, 0) + assert env.agents[0].initial_position == (3, 3) + # Move target to unreachable position in order to not interfere with test + for idx in range(2): + env.reset(True, True, False, random_seed=0) + # Test generation print + # print("assert env.agents[0].initial_position == {}".format(env.agents[0].initial_position)) + env.agents[0].target = (0, 0) + assert env.agents[0].initial_position == (3, 3) + for step in range(3): + actions = {} + + for i in range(len(obs)): + actions[i] = np.random.randint(4) + env.step(actions) + assert env.agents[0].position == (3, 9) + # Test generation print + # print("assert env.agents[0].position == {}".format(env.agents[0].position))