diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 9b33a55c4f5f57a8f92135d95f69696ee3df0be7..e208a46d1cba9f49ecb7410cd46ddc346566d354 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -218,6 +218,7 @@ def rail_from_file(filename): with open(filename, "rb") as file_in: load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False) + grid = np.array(data[b"grid"]) rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) rail.grid = grid @@ -227,7 +228,7 @@ def rail_from_file(filename): agents_position = [a.position for a in agents_static] agents_direction = [a.direction for a in agents_static] agents_target = [a.target for a in agents_static] - if len(data) > 3: + if b"distance_maps" in data.keys(): distance_maps = data[b"distance_maps"] return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps else: diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 38aadf939c2e2ac0b11b9fb3039817fa38786ffd..3e07c05e8a600bb173d65f438b75d746db34178b 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -41,6 +41,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.tree_explored_actions = [1, 2, 3, 0] self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] self.distance_map = None + self.distance_map_computed = False def reset(self): agents = self.env.agents @@ -64,6 +65,9 @@ class TreeObsForRailEnv(ObservationBuilder): def _compute_distance_map(self): agents = self.env.agents + + # For testing only --> To assert if a distance map need to be recomputed. + self.distance_map_computed = True nb_agents = len(agents) self.distance_map = np.inf * np.ones(shape=(nb_agents, self.env.height, diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 8270685e40e2babc0ad1c001445d3312b656c870..f97b071e6b33c099efa5af36766e159e57716443 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -138,6 +138,9 @@ def tests_rail_from_file(): assert np.all(np.array_equal(rails_initial, rails_loaded)) assert agents_initial == agents_loaded + + # Check that distance map was not recomputed + assert env.obs_builder.distance_map_computed is False assert np.shape(env.obs_builder.distance_map) == dist_map_shape assert env.obs_builder.distance_map is not None @@ -207,5 +210,6 @@ def tests_rail_from_file(): assert agents_initial_2 == agents_loaded_4 # Check that distance map was generated with correct shape + assert env4.obs_builder.distance_map_computed is True assert env4.obs_builder.distance_map is not None assert np.shape(env4.obs_builder.distance_map) == dist_map_shape