From cd33c5c7e0908c42b992b504496f58991287d6e6 Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Wed, 25 Sep 2019 09:37:56 +0200
Subject: [PATCH] add visualization for distance map

---
 flatland/envs/rail_env_utils.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py
index dc1cff12..c12a3882 100644
--- a/flatland/envs/rail_env_utils.py
+++ b/flatland/envs/rail_env_utils.py
@@ -1,3 +1,7 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+from flatland.envs.distance_map import DistanceMap
 from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv
@@ -17,3 +21,11 @@ def load_flatland_environment_from_file(file_name, load_from_package=None, obs_b
                           schedule_generator=schedule_from_file(file_name, load_from_package),
                           obs_builder_object=obs_builder_object)
     return environment
+
+
+def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0):
+    assert agent_handle < distance_map.get().shape[0]
+    # take min value of all 4 directions
+    min_distance_map = np.min(distance_map.get(), axis=3)
+    plt.imshow(min_distance_map[agent_handle][:][:])
+    plt.show()
-- 
GitLab