From c7f100730b9ecb72363b61bb1516899d07c903f4 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Thu, 31 Oct 2019 11:06:39 -0400 Subject: [PATCH] Introducing malfunction_generators This resolves issue #273 fixed tests --- tests/test_global_observation.py | 3 ++- tests/test_random_seeding.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_global_observation.py b/tests/test_global_observation.py index afaf2b7f..bb5cd34e 100644 --- a/tests/test_global_observation.py +++ b/tests/test_global_observation.py @@ -1,6 +1,7 @@ import numpy as np from flatland.envs.agent_utils import EnvAgent, RailAgentStatus +from flatland.envs.malfunction_generators import malfunction_from_params from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv, RailEnvActions from flatland.envs.rail_generators import sparse_rail_generator @@ -28,7 +29,7 @@ def test_get_global_observation(): grid_mode=False ), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=number_of_agents, - obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data) + obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data)) env.reset() obs, all_rewards, done, _ = env.step({i: RailEnvActions.MOVE_FORWARD for i in range(number_of_agents)}) diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py index 4ce04e5e..75634a22 100644 --- a/tests/test_random_seeding.py +++ b/tests/test_random_seeding.py @@ -1,5 +1,6 @@ import numpy as np +from flatland.envs.malfunction_generators import malfunction_from_params from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -108,12 +109,12 @@ def test_seeding_and_malfunction(): for tests in range(1, 100): env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data) + obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data)) # Tree Observation env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), schedule_generator=random_schedule_generator(), number_of_agents=10, - obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=stochastic_data) + obs_builder_object=GlobalObsForRailEnv(), malfunction_generator=malfunction_from_params(stochastic_data)) env.reset(True, False, True, random_seed=tests) env2.reset(True, False, True, random_seed=tests) -- GitLab