diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 22886594e47e4f1777aed01f200c2218ede511ab..8265dffa5b5a6998a740202d1d0a7d7431b1c518 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -234,8 +234,8 @@ class RailEnv(Environment): # TODO can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition? rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) - if optionals and 'distance_maps' in optionals: - self.distance_map = optionals['distance_maps'] + if optionals and 'distance_map' in optionals: + self.distance_map = optionals['distance_map'] if regen_rail or self.rail is None: self.rail = rail @@ -575,8 +575,8 @@ class RailEnv(Environment): # agents are always reset as not moving self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]] self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]] - if "distance_maps" in data.keys(): - self.distance_map = data["distance_maps"] + if "distance_map" in data.keys(): + self.distance_map = data["distance_map"] # setup with loaded data self.height, self.width = self.rail.grid.shape self.rail.height = self.height @@ -596,7 +596,7 @@ class RailEnv(Environment): "grid": grid_data, "agents_static": agent_static_data, "agents": agent_data, - "distance_maps": distance_map_data} + "distance_map": distance_map_data} return msgpack.packb(msg_data, use_bin_type=True) diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py index a16fb6018a6354665a44c1b44cafd6975bb4e680..e5f0a8e8dcbe81660727e4eb04f5d4a0f636b5d4 100644 --- a/flatland/envs/rail_generators.py +++ b/flatland/envs/rail_generators.py @@ -226,10 +226,10 @@ def rail_from_file(filename) -> RailGenerator: grid = np.array(data[b"grid"]) rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) rail.grid = grid - if b"distance_maps" in data.keys(): - distance_maps = data[b"distance_maps"] - if len(distance_maps) > 0: - return rail, {'distance_maps': distance_maps} + if b"distance_map" in data.keys(): + distance_map = data[b"distance_map"] + if len(distance_map) > 0: + return rail, {'distance_map': distance_map} return [rail, None] return generator