From 8b5566f8bc7bbaa0b13656e6dee26c074e5b29a0 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Thu, 31 Oct 2019 11:14:13 -0400 Subject: [PATCH] Introducing malfunction_generators This resolves issue #273 added test for saving and loading malfunction parameters --- tests/tests_malfunction_generators.py | 78 +++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/tests_malfunction_generators.py diff --git a/tests/tests_malfunction_generators.py b/tests/tests_malfunction_generators.py new file mode 100644 index 00000000..fa455b75 --- /dev/null +++ b/tests/tests_malfunction_generators.py @@ -0,0 +1,78 @@ +import random +from typing import Dict, List + +import numpy as np +from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay + +from flatland.core.env_observation_builder import ObservationBuilder +from flatland.core.grid.grid4 import Grid4TransitionsEnum +from flatland.core.grid.grid4_utils import get_new_position +from flatland.envs.agent_utils import RailAgentStatus +from flatland.envs.malfunction_generators import malfunction_from_params, malfunction_from_file +from flatland.envs.rail_env import RailEnv, RailEnvActions +from flatland.envs.rail_generators import rail_from_grid_transition_map +from flatland.envs.schedule_generators import random_schedule_generator +from flatland.utils.simple_rail import make_simple_rail2 + + +def test_malfanction_from_params(): + """ + Test loading malfunction from + Returns + ------- + + """ + stochastic_data = {'malfunction_rate': 1000, # Rate of malfunction occurence + 'min_duration': 2, # Minimal duration of malfunction + 'max_duration': 5 # Max duration of malfunction + } + + rail, rail_map = make_simple_rail2() + + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=10), + number_of_agents=1, + malfunction_generator=malfunction_from_params(stochastic_data)) + env.reset() + assert env.mean_malfunction_rate == 1000 + assert env.min_number_of_steps_broken == 2 + assert env.max_number_of_steps_broken == 5 + +def test_malfanction_to_and_from_file(): + """ + Test loading malfunction from + Returns + ------- + + """ + stochastic_data = {'malfunction_rate': 1000, # Rate of malfunction occurence + 'min_duration': 2, # Minimal duration of malfunction + 'max_duration': 5 # Max duration of malfunction + } + + rail, rail_map = make_simple_rail2() + + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=10), + number_of_agents=1, + malfunction_generator=malfunction_from_params(stochastic_data)) + + env.reset() + env.save("./malfunction_saving_loading_tests.pkl") + + env2 = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(seed=10), + number_of_agents=1, + malfunction_generator=malfunction_from_file("./malfunction_saving_loading_tests.pkl")) + + env2.reset() + + assert env2.mean_malfunction_rate == 1000 + assert env2.min_number_of_steps_broken == 2 + assert env2.max_number_of_steps_broken == 5 -- GitLab