diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index f8d1bc66b4f9a66c9657902aaa67ae42c9fd8c71..99082ddd764c62e16ad6b71bef7901333732cc4b 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -6,7 +6,7 @@ import numpy as np from numpy.random.mtrand import RandomState from flatland.envs.agent_utils import EnvAgent, RailAgentStatus -from flatland.envs import persistence +from flatland.envs import persistence Malfunction = NamedTuple('Malfunction', [('num_broken_steps', int)]) MalfunctionParameters = NamedTuple('MalfunctionParameters', @@ -25,7 +25,7 @@ def _malfunction_prob(rate: float) -> float: if rate <= 0: return 0. else: - return 1 - np.exp(- (1 / rate)) + return 1 - np.exp(-rate) def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[MalfunctionGenerator, MalfunctionProcessData]: @@ -42,7 +42,7 @@ def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[Malfun """ # with open(filename, "rb") as file_in: # load_data = file_in.read() - + # if filename.endswith("mpk"): # data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') # elif filename.endswith("pkl"): @@ -52,7 +52,7 @@ def malfunction_from_file(filename: str, load_from_package=None) -> Tuple[Malfun if "malfunction" in env_dict: env_dict['malfunction'] = oMPD = MalfunctionProcessData._make(env_dict['malfunction']) else: - oMPD=None + oMPD = None if oMPD is not None: # Mean malfunction in number of time steps mean_malfunction_rate = oMPD.malfunction_rate diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py index 464e7523db805bd1a45441c0666f5d37245439c1..eaa3112708f3f0e5d255b7e454078d9a59e7ca22 100644 --- a/tests/test_flatland_malfunction.py +++ b/tests/test_flatland_malfunction.py @@ -119,9 +119,9 @@ def test_malfunction_process(): def test_malfunction_process_statistically(): - """Tests hat malfunctions are produced by stochastic_data!""" + """Tests that malfunctions are produced by stochastic_data!""" # Set fixed malfunction duration for this test - stochastic_data = MalfunctionParameters(malfunction_rate=5, # Rate of malfunction occurence + stochastic_data = MalfunctionParameters(malfunction_rate=1/5, # Rate of malfunction occurence min_duration=5, # Minimal duration of malfunction max_duration=5 # Max duration of malfunction ) @@ -168,7 +168,7 @@ def test_malfunction_process_statistically(): def test_malfunction_before_entry(): """Tests that malfunctions are working properly for agents before entering the environment!""" # Set fixed malfunction duration for this test - stochastic_data = MalfunctionParameters(malfunction_rate=2, # Rate of malfunction occurence + stochastic_data = MalfunctionParameters(malfunction_rate=1/2, # Rate of malfunction occurrence min_duration=10, # Minimal duration of malfunction max_duration=10 # Max duration of malfunction ) @@ -215,7 +215,7 @@ def test_malfunction_values_and_behavior(): rail, rail_map = make_simple_rail2() action_dict: Dict[int, RailEnvActions] = {} - stochastic_data = MalfunctionParameters(malfunction_rate=0.001, # Rate of malfunction occurence + stochastic_data = MalfunctionParameters(malfunction_rate=1/0.001, # Rate of malfunction occurence min_duration=10, # Minimal duration of malfunction max_duration=10 # Max duration of malfunction ) @@ -241,7 +241,7 @@ def test_malfunction_values_and_behavior(): def test_initial_malfunction(): - stochastic_data = MalfunctionParameters(malfunction_rate=1000, # Rate of malfunction occurence + stochastic_data = MalfunctionParameters(malfunction_rate=1/1000, # Rate of malfunction occurence min_duration=2, # Minimal duration of malfunction max_duration=5 # Max duration of malfunction ) @@ -390,7 +390,7 @@ def test_initial_malfunction_stop_moving(): def test_initial_malfunction_do_nothing(): - stochastic_data = MalfunctionParameters(malfunction_rate=70, # Rate of malfunction occurence + stochastic_data = MalfunctionParameters(malfunction_rate=1/70, # Rate of malfunction occurence min_duration=2, # Minimal duration of malfunction max_duration=5 # Max duration of malfunction )