Skip to content
Snippets Groups Projects
Commit 30e36369 authored by Erik Nygren's avatar Erik Nygren
Browse files

updated how generator check for distance map data.

Updated test for generators to check that distance map is not recomputed when loaded correctly
parent 2a12614c
No related branches found
No related tags found
No related merge requests found
......@@ -218,6 +218,7 @@ def rail_from_file(filename):
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False)
grid = np.array(data[b"grid"])
rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
rail.grid = grid
......@@ -227,7 +228,7 @@ def rail_from_file(filename):
agents_position = [a.position for a in agents_static]
agents_direction = [a.direction for a in agents_static]
agents_target = [a.target for a in agents_static]
if len(data) > 3:
if b"distance_maps" in data.keys():
distance_maps = data[b"distance_maps"]
return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps
else:
......
......@@ -41,6 +41,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self.tree_explored_actions = [1, 2, 3, 0]
self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
self.distance_map = None
self.distance_map_computed = False
def reset(self):
agents = self.env.agents
......@@ -64,6 +65,9 @@ class TreeObsForRailEnv(ObservationBuilder):
def _compute_distance_map(self):
agents = self.env.agents
# For testing only --> To assert if a distance map need to be recomputed.
self.distance_map_computed = True
nb_agents = len(agents)
self.distance_map = np.inf * np.ones(shape=(nb_agents,
self.env.height,
......
......@@ -138,6 +138,9 @@ def tests_rail_from_file():
assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded
# Check that distance map was not recomputed
assert env.obs_builder.distance_map_computed is False
assert np.shape(env.obs_builder.distance_map) == dist_map_shape
assert env.obs_builder.distance_map is not None
......@@ -207,5 +210,6 @@ def tests_rail_from_file():
assert agents_initial_2 == agents_loaded_4
# Check that distance map was generated with correct shape
assert env4.obs_builder.distance_map_computed is True
assert env4.obs_builder.distance_map is not None
assert np.shape(env4.obs_builder.distance_map) == dist_map_shape
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment