Skip to content
Snippets Groups Projects
Commit 9e99adb7 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

fixed loading and saving of new level gernator objects

parent e3165bb0
No related branches found
No related tags found
No related merge requests found
......@@ -41,14 +41,15 @@ def malfunction_from_file(filename) -> Tuple[MalfunctionGenerator, MalfunctionPr
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
# TODO: make this better by using namedtuple in the pickle file
data['malfunction'] =MalfunctionProcessData._make(data['malfunction'])
if "malfunction" in data:
# Mean malfunction in number of time steps
mean_malfunction_rate = data["malfunction"]["malfunction_rate"]
mean_malfunction_rate = data["malfunction"].malfunction_rate
# Uniform distribution parameters for malfunction duration
min_number_of_steps_broken = data["malfunction"]["min_duration"]
max_number_of_steps_broken = data["malfunction"]["max_duration"]
min_number_of_steps_broken = data["malfunction"].min_duration
max_number_of_steps_broken = data["malfunction"].max_duration
else:
# Mean malfunction in number of time steps
mean_malfunction_rate = 0.
......
......@@ -19,7 +19,8 @@ from flatland.core.grid.grid_utils import IntVector2D
from flatland.core.transition_map import GridTransitionMap
from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent, RailAgentStatus
from flatland.envs.distance_map import DistanceMap
from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction
from flatland.envs.malfunction_generators import no_malfunction_generator, Malfunction, MalfunctionGenerator, \
MalfunctionProcessData
from flatland.envs.observations import GlobalObsForRailEnv
from flatland.envs.rail_generators import random_rail_generator, RailGenerator
from flatland.envs.schedule_generators import random_schedule_generator, ScheduleGenerator
......@@ -159,7 +160,7 @@ class RailEnv(Environment):
"""
super().__init__()
self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
self.malfunction_generator, self.malfunction_process_data = malfunction_generator_and_process_data
self.rail_generator: RailGenerator = rail_generator
self.schedule_generator: ScheduleGenerator = schedule_generator
self.rail: Optional[GridTransitionMap] = None
......@@ -802,8 +803,7 @@ class RailEnv(Environment):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
malfunction_data = {"malfunction_process_data": self.malfunction_process_data}
malfunction_data: MalfunctionProcessData = self.malfunction_process_data
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True)
......@@ -825,7 +825,7 @@ class RailEnv(Environment):
msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True)
distance_map_data = self.distance_map.get()
malfunction_data = {"malfunction_process_data": self.malfunction_process_data}
malfunction_data: MalfunctionProcessData = self.malfunction_process_data
msgpack.packb(distance_map_data, use_bin_type=True)
msg_data = {
"grid": grid_data,
......
......@@ -45,18 +45,29 @@ def test_malfanction_to_and_from_file():
rail, rail_map = make_simple_rail2()
env = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=10), number_of_agents=1)
env = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
)
env.reset()
env.save("./malfunction_saving_loading_tests.pkl")
malfunction_generator, malfunction_process_data = malfunction_from_file("./malfunction_saving_loading_tests.pkl")
env2 = RailEnv(width=25, height=30, rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(seed=10), number_of_agents=1)
env2 = RailEnv(width=25,
height=30,
rail_generator=rail_from_grid_transition_map(rail),
schedule_generator=random_schedule_generator(),
number_of_agents=10,
malfunction_generator_and_process_data=malfunction_from_params(stochastic_data)
)
env2.reset()
assert env2.mean_malfunction_rate == 1000
assert env2.min_number_of_steps_broken == 2
assert env2.max_number_of_steps_broken == 5
assert env2.malfunction_process_data == env.malfunction_process_data
assert env2.malfunction_process_data.malfunction_rate == 1000
assert env2.malfunction_process_data.min_duration == 2
assert env2.malfunction_process_data.max_duration == 5
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment