From 02e9d8d0a98ca4a986c7c9c8effbdcc604e85c6b Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Tue, 16 Jul 2019 11:49:59 -0400 Subject: [PATCH] updated how generator check for distance map data. Updated test for generators to check that distance map is not recomputed when loaded correctly --- flatland/envs/generators.py | 16 +++++++++++----- flatland/envs/observations.py | 3 +-- flatland/envs/rail_env.py | 8 ++++++-- notebooks/Scene_Editor.ipynb | 17 +++++++++++++++-- 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index e208a46d..ae2968f8 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -145,7 +145,7 @@ def complex_rail_generator(nr_start_goal=1, nr_extra=100, min_dist=20, max_dist= agents_target = [sg[1] for sg in start_goal[:num_agents]] agents_direction = start_dir[:num_agents] - return grid_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position) + return grid_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator @@ -193,7 +193,7 @@ def rail_from_manual_specifications_generator(rail_spec): rail, num_agents) - return rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position) + return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator @@ -230,9 +230,15 @@ def rail_from_file(filename): agents_target = [a.target for a in agents_static] if b"distance_maps" in data.keys(): distance_maps = data[b"distance_maps"] - return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps + if len(distance_maps) > 0: + print("Loading distance map") + return rail, agents_position, agents_direction, agents_target, [1.0] * len( + agents_position), distance_maps + else: + return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) else: return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + return generator @@ -257,7 +263,7 @@ def rail_from_grid_transition_map(rail_map): rail_map, num_agents) - return rail_map, agents_position, agents_direction, agents_target, [1.0]*len(agents_position) + return rail_map, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator @@ -530,6 +536,6 @@ def random_rail_generator(cell_type_relative_proportion=[1.0] * 11): return_rail, num_agents) - return return_rail, agents_position, agents_direction, agents_target, [1.0]*len(agents_position) + return return_rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 3e07c05e..d0d678b8 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -52,7 +52,6 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(nb_agents): if agents[i].target != self.agents_previous_reset[i].target: compute_distance_map = True - # Don't compute the distance map if it was loaded if self.agents_previous_reset is None and self.distance_map is not None: self.location_has_target = {tuple(agent.target): 1 for agent in agents} @@ -65,7 +64,7 @@ class TreeObsForRailEnv(ObservationBuilder): def _compute_distance_map(self): agents = self.env.agents - + print("Computing distance map") # For testing only --> To assert if a distance map need to be recomputed. self.distance_map_computed = True nb_agents = len(agents) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index a0aa2e64..e098160d 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -480,8 +480,12 @@ class RailEnv(Environment): def save(self, filename): if hasattr(self.obs_builder, 'distance_map'): - with open(filename, "wb") as file_out: - file_out.write(self.get_full_state_dist_msg()) + if len(self.obs_builder.distance_map) > 0: + with open(filename, "wb") as file_out: + file_out.write(self.get_full_state_dist_msg()) + else: + with open(filename, "wb") as file_out: + file_out.write(self.get_full_state_msg()) else: with open(filename, "wb") as file_out: file_out.write(self.get_full_state_msg()) diff --git a/notebooks/Scene_Editor.ipynb b/notebooks/Scene_Editor.ipynb index 877f8200..852fddb4 100644 --- a/notebooks/Scene_Editor.ipynb +++ b/notebooks/Scene_Editor.ipynb @@ -11,7 +11,20 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "<style>.container { width:95% !important; }</style>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "from IPython.core.display import display, HTML\n", "display(HTML(\"<style>.container { width:95% !important; }</style>\"))" @@ -57,7 +70,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "bd784815032c4d89803fa93399def9cf", + "model_id": "5a05a5f0569846dfa58fc6cfc222619a", "version_major": 2, "version_minor": 0 }, -- GitLab