diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 22886594e47e4f1777aed01f200c2218ede511ab..8265dffa5b5a6998a740202d1d0a7d7431b1c518 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -234,8 +234,8 @@ class RailEnv(Environment):
         # TODO can we not put 'self.rail_generator(..)' into 'if regen_rail or self.rail is None' condition?
         rail, optionals = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
 
-        if optionals and 'distance_maps' in optionals:
-            self.distance_map = optionals['distance_maps']
+        if optionals and 'distance_map' in optionals:
+            self.distance_map = optionals['distance_map']
 
         if regen_rail or self.rail is None:
             self.rail = rail
@@ -575,8 +575,8 @@ class RailEnv(Environment):
         # agents are always reset as not moving
         self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data["agents_static"]]
         self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8]) for d in data["agents"]]
-        if "distance_maps" in data.keys():
-            self.distance_map = data["distance_maps"]
+        if "distance_map" in data.keys():
+            self.distance_map = data["distance_map"]
         # setup with loaded data
         self.height, self.width = self.rail.grid.shape
         self.rail.height = self.height
@@ -596,7 +596,7 @@ class RailEnv(Environment):
             "grid": grid_data,
             "agents_static": agent_static_data,
             "agents": agent_data,
-            "distance_maps": distance_map_data}
+            "distance_map": distance_map_data}
 
         return msgpack.packb(msg_data, use_bin_type=True)
 
diff --git a/flatland/envs/rail_generators.py b/flatland/envs/rail_generators.py
index a16fb6018a6354665a44c1b44cafd6975bb4e680..e5f0a8e8dcbe81660727e4eb04f5d4a0f636b5d4 100644
--- a/flatland/envs/rail_generators.py
+++ b/flatland/envs/rail_generators.py
@@ -226,10 +226,10 @@ def rail_from_file(filename) -> RailGenerator:
         grid = np.array(data[b"grid"])
         rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
         rail.grid = grid
-        if b"distance_maps" in data.keys():
-            distance_maps = data[b"distance_maps"]
-            if len(distance_maps) > 0:
-                return rail, {'distance_maps': distance_maps}
+        if b"distance_map" in data.keys():
+            distance_map = data[b"distance_map"]
+            if len(distance_map) > 0:
+                return rail, {'distance_map': distance_map}
         return [rail, None]
 
     return generator