diff --git a/flatland/envs/malfunction_generators.py b/flatland/envs/malfunction_generators.py index 1810b12c00bb84dd0668bc26f7075c36fc4e7655..89af545ba26e59b3cc399b4b775ba5846a7b90a5 100644 --- a/flatland/envs/malfunction_generators.py +++ b/flatland/envs/malfunction_generators.py @@ -41,14 +41,15 @@ def malfunction_from_file(filename) -> Tuple[MalfunctionGenerator, MalfunctionPr with open(filename, "rb") as file_in: load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') - + # TODO: make this better by using namedtuple in the pickle file + data['malfunction'] =MalfunctionProcessData._make(data['malfunction']) if "malfunction" in data: # Mean malfunction in number of time steps - mean_malfunction_rate = data["malfunction"]["malfunction_rate"] + mean_malfunction_rate = data["malfunction"].malfunction_rate # Uniform distribution parameters for malfunction duration - min_number_of_steps_broken = data["malfunction"]["min_duration"] - max_number_of_steps_broken = data["malfunction"]["max_duration"] + min_number_of_steps_broken = data["malfunction"].min_duration + max_number_of_steps_broken = data["malfunction"].max_duration else: # Mean malfunction in number of time steps mean_malfunction_rate = 0. diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index cba49f73d80a6aecd1a3135d47eca2ffcd1d8e5f..9025d415eef78c3b68495b1a5ae09c45652811bb 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -19,7 +19,8 @@ from flatland.core.grid.grid_utils import IntVector2D from flatland.core.transition_map import GridTransitionMap from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus from flatland.envs.distance_map import DistanceMap -from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction +from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionGenerator, \ + MalfunctionProcessData from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_generators import random_rail_generator, RailGenerator from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator @@ -159,7 +160,7 @@ class RailEnv(Environment): """ super().__init__() - self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data + self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data self.rail_generator: RailGenerator = rail_generator self.schedule_generator: ScheduleGenerator = schedule_generator self.rail: Optional[GridTransitionMap] = None @@ -802,8 +803,7 @@ class RailEnv(Environment): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] - malfunction_data = {"malfunction_process_data": self.malfunction_process_data} - + malfunction_data: MalfunctionProcessData = self.malfunction_process_data msgpack.packb(grid_data, use_bin_type=True) msgpack.packb(agent_data, use_bin_type=True) msgpack.packb(agent_static_data, use_bin_type=True) @@ -825,7 +825,7 @@ class RailEnv(Environment): msgpack.packb(agent_data, use_bin_type=True) msgpack.packb(agent_static_data, use_bin_type=True) distance_map_data = self.distance_map.get() - malfunction_data = {"malfunction_process_data": self.malfunction_process_data} + malfunction_data: MalfunctionProcessData = self.malfunction_process_data msgpack.packb(distance_map_data, use_bin_type=True) msg_data = { "grid": grid_data, diff --git a/tests/test_malfunction_generators.py b/tests/test_malfunction_generators.py index 779af28c0824e108285412bd67a601dced6e6fe9..e01c74830959a392a5892ab8a43336613463e937 100644 --- a/tests/test_malfunction_generators.py +++ b/tests/test_malfunction_generators.py @@ -45,18 +45,29 @@ def test_malfanction_to_and_from_file(): rail, rail_map = make_simple_rail2() - env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=10), number_of_agents=1) - + env = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=10, + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data) + ) env.reset() env.save("./malfunction_saving_loading_tests.pkl") malfunction_generator, malfunction_process_data = malfunction_from_file("./malfunction_saving_loading_tests.pkl") - env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail), - schedule_generator=random_schedule_generator(seed=10), number_of_agents=1) + env2 = RailEnv(width=25, + height=30, + rail_generator=rail_from_grid_transition_map(rail), + schedule_generator=random_schedule_generator(), + number_of_agents=10, + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data) + ) env2.reset() - assert env2.mean_malfunction_rate == 1000 - assert env2.min_number_of_steps_broken == 2 - assert env2.max_number_of_steps_broken == 5 + assert env2.malfunction_process_data == env.malfunction_process_data + assert env2.malfunction_process_data.malfunction_rate == 1000 + assert env2.malfunction_process_data.min_duration == 2 + assert env2.malfunction_process_data.max_duration == 5 +