diff --git a/flatland/envs/rail_env_utils.py b/flatland/envs/rail_env_utils.py index 8b29bbe851478450fae91e002222eef0099c8736..0e305f48e14498382e1ad826b74e09cb469fee24 100644 --- a/flatland/envs/rail_env_utils.py +++ b/flatland/envs/rail_env_utils.py @@ -1,6 +1,7 @@ import math from typing import Tuple, Set, Dict, List, NamedTuple +import matplotlib.pyplot as plt import numpy as np from flatland.core.grid.grid4 import Grid4TransitionsEnum @@ -116,3 +117,13 @@ def get_shortest_paths(distance_map: DistanceMap) -> Dict[int, List[WalkingEleme RailEnvNextAction(RailEnvActions.STOP_MOVING, position, direction))) return shortest_paths + + +def visualize_distance_map(distance_map: DistanceMap, agent_handle: int = 0): + if agent_handle >= distance_map.get().shape[0]: + print("Error: agent_handle cannot be larger than actual number of agents") + return + # take min value of all 4 directions + min_distance_map = np.min(distance_map.get(), axis=3) + plt.imshow(min_distance_map[agent_handle][:][:]) + plt.show()