From 42d8682e60b543d1416859926697a7b83886384b Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Mon, 4 Nov 2019 14:32:44 -0500 Subject: [PATCH] fixed tests --- flatland/envs/malfunction_generators.py | 2 +- flatland/envs/rail_env.py | 5 ++--- tests/test_malfunction_generators.py | 15 ++++++++++----- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index 89af545b..129da81c 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -42,7 +42,7 @@ def malfunction_from_file(filename) -> Tuple[MalfunctionGenerator, MalfunctionPr load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') # TODO: make this better by using namedtuple in the pickle file - data['malfunction'] =MalfunctionProcessData._make(data['malfunction']) + data['malfunction'] = MalfunctionProcessData._make(data['malfunction']) if "malfunction" in data: # Mean malfunction in number of time steps mean_malfunction_rate = data["malfunction"].malfunction_rate diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 9025d415..81c3a356 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -19,8 +19,7 @@ from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap -from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionGenerator, \ - MalfunctionProcessData +from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionProcessData from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator @@ -160,7 +159,7 @@ class RailEnv(Environment): """ super().__init__() - self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data + self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data self.rail_generator: RailGenerator = rail_generator self.schedule_generator: ScheduleGenerator = schedule_generator self.rail: Optional[GridTransitionMap] = None diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py index e01c7483..075edc13 100644 --- a/tests/test_malfunction_generators.py +++ b/tests/test_malfunction_generators.py @@ -23,12 +23,17 @@ def test_malfanction_from_params(): rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=10), number_of_agents=1) + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=10, + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data) + ) env.reset() - assert env.mean_malfunction_rate == 1000 - assert env.min_number_of_steps_broken == 2 - assert env.max_number_of_steps_broken == 5 + assert env.malfunction_process_data.malfunction_rate == 1000 + assert env.malfunction_process_data.min_duration == 2 + assert env.malfunction_process_data.max_duration == 5 def test_malfanction_to_and_from_file(): -- GitLab