diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index 9b33a55c4f5f57a8f92135d95f69696ee3df0be7..355f5502992a34f8d58d4dbd80028eb4dd71cc48 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -145,7 +145,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= agents_target = [sg[1] for sg in start_goal[:num_agents]] agents_direction = start_dir[:num_agents] - return grid_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position) + return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator @@ -193,7 +193,7 @@ def rail_from_manual_specifications_generator(rail_spec): rail, num_agents) - return rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position) + return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator @@ -218,6 +218,7 @@ def rail_from_file(filename): with open(filename, "rb") as file_in: load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False) + grid = np.array(data[b"grid"]) rail = GridTransitionMap(width=np.shape(grid)[1], height=np.shape(grid)[0], transitions=rail_env_transitions) rail.grid = grid @@ -227,11 +228,16 @@ def rail_from_file(filename): agents_position = [a.position for a in agents_static] agents_direction = [a.direction for a in agents_static] agents_target = [a.target for a in agents_static] - if len(data) > 3: + if b"distance_maps" in data.keys(): distance_maps = data[b"distance_maps"] - return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps + if len(distance_maps) > 0: + return rail, agents_position, agents_direction, agents_target, [1.0] * len( + agents_position), distance_maps + else: + return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) else: return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return generator @@ -256,7 +262,7 @@ def rail_from_grid_transition_map(rail_map): rail_map, num_agents) - return rail_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position) + return rail_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator @@ -529,6 +535,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): return_rail, num_agents) - return return_rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position) + return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 38aadf939c2e2ac0b11b9fb3039817fa38786ffd..80dc73adf417b43d2740b80e9b86fe5ba5ce257f 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -41,6 +41,7 @@ class TreeObsForRailEnv(ObservationBuilder): self.tree_explored_actions = [1, 2, 3, 0] self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] self.distance_map = None + self.distance_map_computed = False def reset(self): agents = self.env.agents @@ -51,7 +52,6 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(nb_agents): if agents[i].target != self.agents_previous_reset[i].target: compute_distance_map = True - # Don't compute the distance map if it was loaded if self.agents_previous_reset is None and self.distance_map is not None: self.location_has_target = {tuple(agent.target): 1 for agent in agents} @@ -64,6 +64,8 @@ class TreeObsForRailEnv(ObservationBuilder): def _compute_distance_map(self): agents = self.env.agents + # For testing only --> To assert if a distance map need to be recomputed. + self.distance_map_computed = True nb_agents = len(agents) self.distance_map = np.inf * np.ones(shape=(nb_agents, self.env.height, diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index a0aa2e64933c2f9e8d53a5bcd08bb9f8845484e6..e098160d7a14f9b7cc07b7f365e6af573c3d7d56 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -480,8 +480,12 @@ class RailEnv(Environment): def save(self, filename): if hasattr(self.obs_builder, 'distance_map'): - with open(filename, "wb") as file_out: - file_out.write(self.get_full_state_dist_msg()) + if len(self.obs_builder.distance_map) > 0: + with open(filename, "wb") as file_out: + file_out.write(self.get_full_state_dist_msg()) + else: + with open(filename, "wb") as file_out: + file_out.write(self.get_full_state_msg()) else: with open(filename, "wb") as file_out: file_out.write(self.get_full_state_msg()) diff --git a/notebooks/Scene_Editor.ipynb b/notebooks/Scene_Editor.ipynb index 877f82001dfeef3732f824e06899f62f68ba85d4..852fddb422fd2d8e6a6c51df993a183e45511982 100644 --- a/notebooks/Scene_Editor.ipynb +++ b/notebooks/Scene_Editor.ipynb @@ -11,7 +11,20 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "<style>.container { width:95% !important; }</style>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "from IPython.core.display import display, HTML\n", "display(HTML(\"<style>.container { width:95% !important; }</style>\"))" @@ -57,7 +70,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "bd784815032c4d89803fa93399def9cf", + "model_id": "5a05a5f0569846dfa58fc6cfc222619a", "version_major": 2, "version_minor": 0 }, diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 31dff253126bec041241aaba44f33dd9c494f2a1..f97b071e6b33c099efa5af36766e159e57716443 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -8,7 +8,7 @@ from flatland.envs.generators import rail_from_grid_transition_map, rail_from_fi from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv -from tests.simple_rail import make_simple_rail +from flatland.utils.simple_rail import make_simple_rail def test_empty_rail_generator(): @@ -122,7 +122,7 @@ def tests_rail_from_file(): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) env.save(file_name) - + dist_map_shape = np.shape(env.obs_builder.distance_map) # initialize agents_static rails_initial = env.rail.grid agents_initial = env.agents @@ -138,6 +138,10 @@ def tests_rail_from_file(): assert np.all(np.array_equal(rails_initial, rails_loaded)) assert agents_initial == agents_loaded + + # Check that distance map was not recomputed + assert env.obs_builder.distance_map_computed is False + assert np.shape(env.obs_builder.distance_map) == dist_map_shape assert env.obs_builder.distance_map is not None # Test to save and load file without distance map. @@ -173,7 +177,6 @@ def tests_rail_from_file(): # Test to save with distance map and load without - # initialize agents_static env3 = RailEnv(width=1, height=1, rail_generator=rail_from_file(file_name), @@ -201,6 +204,12 @@ def tests_rail_from_file(): rails_loaded_4 = env4.rail.grid agents_loaded_4 = env4.agents + # Check that no distance map was saved + assert not hasattr(env2.obs_builder, "distance_map") assert np.all(np.array_equal(rails_initial_2, rails_loaded_4)) assert agents_initial_2 == agents_loaded_4 - assert env.obs_builder.distance_map is not None + + # Check that distance map was generated with correct shape + assert env4.obs_builder.distance_map_computed is True + assert env4.obs_builder.distance_map is not None + assert np.shape(env4.obs_builder.distance_map) == dist_map_shape