Skip to content
Snippets Groups Projects
Commit ac9dae25 authored by Erik Nygren
Browse files

fixed saving and loading of agents with speed and malfunction rates

parent 212ee6a9
No related branches found
No related tags found
No related merge requests found
......@@ -30,10 +30,11 @@ class EnvAgentStatic(object):
lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0})))
@classmethod
def from_lists(cls, positions, directions, targets, speeds=None):
def from_lists(cls, positions, directions, targets, speeds=None, malfunction_rates=None):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets
"""
speed_datas = []
for i in range(len(positions)):
speed_datas.append({'position_fraction': 0.0,
'speed': speeds[i] if speeds is not None else 1.0,
......@@ -41,10 +42,11 @@ class EnvAgentStatic(object):
# TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set
# some as broken?
malfunction_datas = []
for i in range(len(positions)):
malfunction_datas.append({'malfunction': 0,
'malfunction_rate': 0,
'malfunction_rate': malfunction_rates[i] if malfunction_rates is not None else 0.,
'next_malfunction': 0,
'nr_malfunctions': 0})
......
......@@ -238,7 +238,6 @@ class RailEnv(Environment):
agents_hints = optionals['agents_hints']
self.agents_static = EnvAgentStatic.from_lists(
*self.schedule_generator(self.rail, self.get_num_agents(), hints=agents_hints))
self.restart_agents()
for i_agent in range(self.get_num_agents()):
......@@ -248,7 +247,6 @@ class RailEnv(Environment):
if np.random.random() < self.proportion_malfunctioning_trains:
agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate
agent.speed_data['position_fraction'] = 0.0
agent.malfunction_data['malfunction'] = 0
self._agent_malfunction(agent)
......@@ -510,9 +508,9 @@ class RailEnv(Environment):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
msgpack.packb(grid_data)
msgpack.packb(agent_data)
msgpack.packb(agent_static_data)
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True)
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
......@@ -526,11 +524,11 @@ class RailEnv(Environment):
return msgpack.packb(msg_data, use_bin_type=True)
def set_full_state_msg(self, msg_data):
data = msgpack.unpackb(msg_data, use_list=False)
self.rail.grid = np.array(data[b"grid"])
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9]) for d in data[b"agents"]]
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
......@@ -538,13 +536,13 @@ class RailEnv(Environment):
self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False)
def set_full_state_dist_msg(self, msg_data):
data = msgpack.unpackb(msg_data, use_list=False)
self.rail.grid = np.array(data[b"grid"])
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9]) for d in data[b"agents"]]
if hasattr(self.obs_builder, 'distance_map') and b"distance_maps" in data.keys():
self.obs_builder.distance_map = data[b"distance_maps"]
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
if hasattr(self.obs_builder, 'distance_map') and "distance_maps" in data.keys():
self.obs_builder.distance_map = data["distance_maps"]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
......@@ -555,12 +553,12 @@ class RailEnv(Environment):
grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents]
msgpack.packb(grid_data)
msgpack.packb(agent_data)
msgpack.packb(agent_static_data)
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True)
if hasattr(self.obs_builder, 'distance_map'):
distance_map_data = self.obs_builder.distance_map
msgpack.packb(distance_map_data)
msgpack.packb(distance_map_data, use_bin_type=True)
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
......
......@@ -224,16 +224,19 @@ def schedule_from_file(filename) -> ScheduleGenerator:
def generator(rail: GridTransitionMap, num_agents: int, hints: Any = None) -> ScheduleGeneratorProduct:
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False)
data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
# agents are always reset as not moving
agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data[b"agents_static"]]
agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data["agents_static"]]
# setup with loaded data
agents_position = [a.position for a in agents_static]
agents_direction = [a.direction for a in agents_static]
agents_target = [a.target for a in agents_static]
agents_speed = [a.speed_data[b'speed'] for a in agents_static]
return agents_position, agents_direction, agents_target, agents_speed
agents_speed = [a.speed_data['speed'] for a in agents_static]
agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents_static]
return agents_position, agents_direction, agents_target, agents_speed, agents_malfunction
return generator
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment