From 53dd3c55c3ae490f83d89aa7df0d423a4f92d348 Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Tue, 17 Sep 2019 10:37:05 +0200 Subject: [PATCH] Refactoring: rename distance_maps to distance_map --- flatland/envs/rail_env.py | 10 +++++----- flatland/envs/rail_generators.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 22886594..8265dffa 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -234,8 +234,8 @@ class RailEnv(Environment): # TODO can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition? rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) - if optionals and 'distance_maps' in optionals: - self.distance_map = optionals['distance_maps'] + if optionals and 'distance_map' in optionals: + self.distance_map = optionals['distance_map'] if regen_rail or self.rail is None: self.rail = rail @@ -575,8 +575,8 @@ class RailEnv(Environment): # agents are always reset as not moving self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]] self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]] - if "distance_maps" in data.keys(): - self.distance_map = data["distance_maps"] + if "distance_map" in data.keys(): + self.distance_map = data["distance_map"] # setup with loaded data self.height, self.width = self.rail.grid.shape self.rail.height = self.height @@ -596,7 +596,7 @@ class RailEnv(Environment): "grid": grid_data, "agents_static": agent_static_data, "agents": agent_data, - "distance_maps": distance_map_data} + "distance_map": distance_map_data} return msgpack.packb(msg_data, use_bin_type=True) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index a16fb601..e5f0a8e8 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -226,10 +226,10 @@ def rail_from_file(filename) -> RailGenerator: grid = np.array(data[b"grid"]) rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) rail.grid = grid - if b"distance_maps" in data.keys(): - distance_maps = data[b"distance_maps"] - if len(distance_maps) > 0: - return rail, {'distance_maps': distance_maps} + if b"distance_map" in data.keys(): + distance_map = data[b"distance_map"] + if len(distance_map) > 0: + return rail, {'distance_map': distance_map} return [rail, None] return generator -- GitLab