Skip to content
Snippets Groups Projects
Commit 5f4773a1 authored by u229589's avatar u229589
Browse files

Refactoring: distance_map is initialized in constructor, so checks like...

Refactoring: distance_map is initialized in constructor, so checks like hasattr(self, 'distance_map') are unnecessary
parent 60b98c1d
No related branches found
No related tags found
No related merge requests found
......@@ -575,7 +575,7 @@ class RailEnv(Environment):
# agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
if hasattr(self, 'distance_map') and "distance_maps" in data.keys():
if "distance_maps" in data.keys():
self.distance_map = data["distance_maps"]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
......@@ -590,24 +590,18 @@ class RailEnv(Environment):
msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True)
if hasattr(self, 'distance_map'):
distance_map_data = self.distance_map
msgpack.packb(distance_map_data, use_bin_type=True)
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data,
"distance_maps": distance_map_data}
else:
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data}
distance_map_data = self.distance_map
msgpack.packb(distance_map_data, use_bin_type=True)
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data,
"distance_maps": distance_map_data}
return msgpack.packb(msg_data, use_bin_type=True)
def save(self, filename):
if hasattr(self, 'distance_map') and self.distance_map is not None:
if self.distance_map is not None:
if len(self.distance_map) > 0:
with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_dist_msg())
......@@ -619,14 +613,9 @@ class RailEnv(Environment):
file_out.write(self.get_full_state_msg())
def load(self, filename):
if hasattr(self, 'distance_map'):
with open(filename, "rb") as file_in:
load_data = file_in.read()
self.set_full_state_dist_msg(load_data)
else:
with open(filename, "rb") as file_in:
load_data = file_in.read()
self.set_full_state_msg(load_data)
with open(filename, "rb") as file_in:
load_data = file_in.read()
self.set_full_state_dist_msg(load_data)
def load_pkl(self, pkl_data):
self.set_full_state_msg(pkl_data)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment