diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index 89af545ba26e59b3cc399b4b775ba5846a7b90a5..129da81cb04cc3a3c02fa626036d8e09cb942cd2 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -42,7 +42,7 @@ def malfunction_from_file(filename) -> Tuple[MalfunctionGenerator, MalfunctionPr load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') # TODO: make this better by using namedtuple in the pickle file - data['malfunction'] =MalfunctionProcessData._make(data['malfunction']) + data['malfunction'] = MalfunctionProcessData._make(data['malfunction']) if "malfunction" in data: # Mean malfunction in number of time steps mean_malfunction_rate = data["malfunction"].malfunction_rate diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 9025d415eef78c3b68495b1a5ae09c45652811bb..81c3a3569642c680b5b0c7246fb6d0b94685d9ed 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -19,8 +19,7 @@ from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap -from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionGenerator, \ - MalfunctionProcessData +from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator @@ -160,7 +159,7 @@ class RailEnv(Environment): """ super().__init__() - self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data + self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data self.rail_generator: RailGenerator = rail_generator self.schedule_generator: ScheduleGenerator = schedule_generator self.rail: Optional[GridTransitionMap] = None diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py index e01c74830959a392a5892ab8a43336613463e937..075edc139b6786933a32c915998c0fe56cb7a76c 100644 --- a/tests/test_malfunction_generators.py +++ b/tests/test_malfunction_generators.py @@ -23,12 +23,17 @@ def test_malfanction_from_params(): rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=10), number_of_agents=1) + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=10, + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data) + ) env.reset() - assert env.mean_malfunction_rate == 1000 - assert env.min_number_of_steps_broken == 2 - assert env.max_number_of_steps_broken == 5 + assert env.malfunction_process_data.malfunction_rate == 1000 + assert env.malfunction_process_data.min_duration == 2 + assert env.malfunction_process_data.max_duration == 5 def test_malfanction_to_and_from_file():