Skip to content
Snippets Groups Projects
Commit 5f4773a1 authored by u229589's avatar u229589
Browse files

Refactoring: distance_map is initialized in constructor, so checks like...

Refactoring: distance_map is initialized in constructor, so checks like hasattr(self, 'distance_map') are unnecessary
parent 60b98c1d
No related branches found
No related tags found
No related merge requests found
...@@ -575,7 +575,7 @@ class RailEnv(Environment): ...@@ -575,7 +575,7 @@ class RailEnv(Environment):
# agents are always reset as not moving # agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]] self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]] self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
if hasattr(self, 'distance_map') and "distance_maps" in data.keys(): if "distance_maps" in data.keys():
self.distance_map = data["distance_maps"] self.distance_map = data["distance_maps"]
# setup with loaded data # setup with loaded data
self.height, self.width = self.rail.grid.shape self.height, self.width = self.rail.grid.shape
...@@ -590,24 +590,18 @@ class RailEnv(Environment): ...@@ -590,24 +590,18 @@ class RailEnv(Environment):
msgpack.packb(grid_data, use_bin_type=True) msgpack.packb(grid_data, use_bin_type=True)
msgpack.packb(agent_data, use_bin_type=True) msgpack.packb(agent_data, use_bin_type=True)
msgpack.packb(agent_static_data, use_bin_type=True) msgpack.packb(agent_static_data, use_bin_type=True)
if hasattr(self, 'distance_map'): distance_map_data = self.distance_map
distance_map_data = self.distance_map msgpack.packb(distance_map_data, use_bin_type=True)
msgpack.packb(distance_map_data, use_bin_type=True) msg_data = {
msg_data = { "grid": grid_data,
"grid": grid_data, "agents_static": agent_static_data,
"agents_static": agent_static_data, "agents": agent_data,
"agents": agent_data, "distance_maps": distance_map_data}
"distance_maps": distance_map_data}
else:
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data}
return msgpack.packb(msg_data, use_bin_type=True) return msgpack.packb(msg_data, use_bin_type=True)
def save(self, filename): def save(self, filename):
if hasattr(self, 'distance_map') and self.distance_map is not None: if self.distance_map is not None:
if len(self.distance_map) > 0: if len(self.distance_map) > 0:
with open(filename, "wb") as file_out: with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_dist_msg()) file_out.write(self.get_full_state_dist_msg())
...@@ -619,14 +613,9 @@ class RailEnv(Environment): ...@@ -619,14 +613,9 @@ class RailEnv(Environment):
file_out.write(self.get_full_state_msg()) file_out.write(self.get_full_state_msg())
def load(self, filename): def load(self, filename):
if hasattr(self, 'distance_map'): with open(filename, "rb") as file_in:
with open(filename, "rb") as file_in: load_data = file_in.read()
load_data = file_in.read() self.set_full_state_dist_msg(load_data)
self.set_full_state_dist_msg(load_data)
else:
with open(filename, "rb") as file_in:
load_data = file_in.read()
self.set_full_state_msg(load_data)
def load_pkl(self, pkl_data): def load_pkl(self, pkl_data):
self.set_full_state_msg(pkl_data) self.set_full_state_msg(pkl_data)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment