diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index dc1cff12c0c8b1860859208a13d6403734a2d2ad..69cfce764fe124d9e3eb05019e1c734f2285bdb5 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -1,3 +1,7 @@ +import numpy as np +import matplotlib.pyplot as plt + +from flatland.envs.distance_map import DistanceMap from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -17,3 +21,13 @@ def load_flatland_environment_from_file(file_name, load_from_package=None, obs_b schedule_generator=schedule_from_file(file_name, load_from_package), obs_builder_object=obs_builder_object) return environment + + +def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0): + if agent_handle >= distance_map.get().shape[0]: + print("Error: agent_handle cannot be larger than actual number of agents") + return + # take min value of all 4 directions + min_distance_map = np.min(distance_map.get(), axis=3) + plt.imshow(min_distance_map[agent_handle][:][:]) + plt.show()