diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py index afaf2b7f496ef31c7d5228d1f389c254cc16df2e..bb5cd34e1b79ef9933984fbe480761cec1fd5711 100644 --- a/tests/test_global_observation.py +++ b/tests/test_global_observation.py @@ -1,6 +1,7 @@ import numpy as np from flatland.envs.agent_utils import EnvAgent, RailAgentStatus +from flatland.envs.malfunction_generators import malfunction_from_params from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import sparse_rail_generator @@ -28,7 +29,7 @@ def test_get_global_observation(): grid_mode=False ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents, - obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data) + obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data)) env.reset() obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)}) diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index 4ce04e5e5e07ca5b5e864efffea951221322f306..75634a2299a599998dfeed0dd48e401d7795a794 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -1,5 +1,6 @@ import numpy as np +from flatland.envs.malfunction_generators import malfunction_from_params from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -108,12 +109,12 @@ def test_seeding_and_malfunction(): for tests in range(1, 100): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data) + obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data)) # Tree Observation env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data) + obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data)) env.reset(True, False, True, random_seed=tests) env2.reset(True, False, True, random_seed=tests)