Skip to content
Snippets Groups Projects
Commit 30e36369 authored by Erik Nygren's avatar Erik Nygren
Browse files

updated how generator check for distance map data.

Updated test for generators to check that distance map is not recomputed when loaded correctly
parent 2a12614c
No related branches found
No related tags found
No related merge requests found
...@@ -218,6 +218,7 @@ def rail_from_file(filename): ...@@ -218,6 +218,7 @@ def rail_from_file(filename):
with open(filename, "rb") as file_in: with open(filename, "rb") as file_in:
load_data = file_in.read() load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False) data = msgpack.unpackb(load_data, use_list=False)
grid = np.array(data[b"grid"]) grid = np.array(data[b"grid"])
rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions)
rail.grid = grid rail.grid = grid
...@@ -227,7 +228,7 @@ def rail_from_file(filename): ...@@ -227,7 +228,7 @@ def rail_from_file(filename):
agents_position = [a.position for a in agents_static] agents_position = [a.position for a in agents_static]
agents_direction = [a.direction for a in agents_static] agents_direction = [a.direction for a in agents_static]
agents_target = [a.target for a in agents_static] agents_target = [a.target for a in agents_static]
if len(data) > 3: if b"distance_maps" in data.keys():
distance_maps = data[b"distance_maps"] distance_maps = data[b"distance_maps"]
return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps
else: else:
......
...@@ -41,6 +41,7 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -41,6 +41,7 @@ class TreeObsForRailEnv(ObservationBuilder):
self.tree_explored_actions = [1, 2, 3, 0] self.tree_explored_actions = [1, 2, 3, 0]
self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
self.distance_map = None self.distance_map = None
self.distance_map_computed = False
def reset(self): def reset(self):
agents = self.env.agents agents = self.env.agents
...@@ -64,6 +65,9 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -64,6 +65,9 @@ class TreeObsForRailEnv(ObservationBuilder):
def _compute_distance_map(self): def _compute_distance_map(self):
agents = self.env.agents agents = self.env.agents
# For testing only --> To assert if a distance map need to be recomputed.
self.distance_map_computed = True
nb_agents = len(agents) nb_agents = len(agents)
self.distance_map = np.inf * np.ones(shape=(nb_agents, self.distance_map = np.inf * np.ones(shape=(nb_agents,
self.env.height, self.env.height,
......
...@@ -138,6 +138,9 @@ def tests_rail_from_file(): ...@@ -138,6 +138,9 @@ def tests_rail_from_file():
assert np.all(np.array_equal(rails_initial, rails_loaded)) assert np.all(np.array_equal(rails_initial, rails_loaded))
assert agents_initial == agents_loaded assert agents_initial == agents_loaded
# Check that distance map was not recomputed
assert env.obs_builder.distance_map_computed is False
assert np.shape(env.obs_builder.distance_map) == dist_map_shape assert np.shape(env.obs_builder.distance_map) == dist_map_shape
assert env.obs_builder.distance_map is not None assert env.obs_builder.distance_map is not None
...@@ -207,5 +210,6 @@ def tests_rail_from_file(): ...@@ -207,5 +210,6 @@ def tests_rail_from_file():
assert agents_initial_2 == agents_loaded_4 assert agents_initial_2 == agents_loaded_4
# Check that distance map was generated with correct shape # Check that distance map was generated with correct shape
assert env4.obs_builder.distance_map_computed is True
assert env4.obs_builder.distance_map is not None assert env4.obs_builder.distance_map is not None
assert np.shape(env4.obs_builder.distance_map) == dist_map_shape assert np.shape(env4.obs_builder.distance_map) == dist_map_shape
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment