From 4d538a4155be094ec868282aa520dec1992f5fb3 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sat, 13 Jul 2019 12:08:18 -0400 Subject: [PATCH] updated load and save function. Now also distance maps are stored. Additional package msgpack-numpy needed for ndarray. This saves tons of time when loading precomputed files. --- flatland/envs/generators.py | 7 ++++-- flatland/envs/observations.py | 10 ++++++-- flatland/envs/rail_env.py | 46 +++++++++++++++++++++++++++-------- requirements_dev.txt | 3 ++- 4 files changed, 51 insertions(+), 15 deletions(-) diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index ec579c1d..9b33a55c 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -227,8 +227,11 @@ def rail_from_file(filename): agents_position = [a.position for a in agents_static] agents_direction = [a.direction for a in agents_static] agents_target = [a.target for a in agents_static] - return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) - + if len(data) > 3: + distance_maps = data[b"distance_maps"] + return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps + else: + return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 3929b9e1..4385d4da 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -41,22 +41,28 @@ class TreeObsForRailEnv(ObservationBuilder): self.agents_previous_reset = None self.tree_explored_actions = [1, 2, 3, 0] self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] + self.distance_map = None def reset(self): agents = self.env.agents nb_agents = len(agents) - compute_distance_map = True if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset): compute_distance_map = False for i in range(nb_agents): if agents[i].target != self.agents_previous_reset[i].target: compute_distance_map = True - self.agents_previous_reset = agents + + # Don't compute the distance map if it was loaded + if self.agents_previous_reset is None and self.distance_map is not None: + self.location_has_target = {tuple(agent.target): 1 for agent in agents} + compute_distance_map = False if compute_distance_map: self._compute_distance_map() + self.agents_previous_reset = agents + def _compute_distance_map(self): agents = self.env.agents nb_agents = len(agents) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 43351099..f082f080 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -6,6 +6,7 @@ Definition of the RailEnv environment. from enum import IntEnum import msgpack +import msgpack_numpy as m import numpy as np from flatland.core.env import Environment @@ -14,6 +15,8 @@ from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent from flatland.envs.generators import random_rail_generator from flatland.envs.observations import TreeObsForRailEnv +m.patch() + class RailEnvActions(IntEnum): DO_NOTHING = 0 # implies change of direction in a dead-end! @@ -170,6 +173,10 @@ class RailEnv(Environment): """ tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) + # Check if generator provided a distance map TODO: Make this check safer! + if len(tRailAgents) > 5: + self.obs_builder.distance_map = tRailAgents[-1] + if regen_rail or self.rail is None: self.rail = tRailAgents[0] self.height, self.width = self.rail.grid.shape @@ -424,6 +431,8 @@ class RailEnv(Environment): # agents are always reset as not moving self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] + if hasattr(self.obs_builder, 'distance_map'): + self.obs_builder.distance_map = data[b"distance_maps"] # setup with loaded data self.height, self.width = self.rail.grid.shape self.rail.height = self.height @@ -438,22 +447,39 @@ class RailEnv(Environment): msgpack.packb(grid_data) msgpack.packb(agent_data) msgpack.packb(agent_static_data) + if hasattr(self.obs_builder, 'distance_map'): + distance_map_data = self.obs_builder.distance_map + msgpack.packb(distance_map_data) + msg_data = { + "grid": grid_data, + "agents_static": agent_static_data, + "agents": agent_data, + "distance_maps": distance_map_data} + else: + msg_data = { + "grid": grid_data, + "agents_static": agent_static_data, + "agents": agent_data} - msg_data = { - "grid": grid_data, - "agents_static": agent_static_data, - "agents": agent_data} return msgpack.packb(msg_data, use_bin_type=True) - def save(self, filename): - with open(filename, "wb") as file_out: - file_out.write(self.get_full_state_msg()) + if hasattr(self.obs_builder, 'distance_map'): + with open(filename, "wb") as file_out: + file_out.write(self.get_full_state_dist_msg()) + else: + with open(filename, "wb") as file_out: + file_out.write(self.get_full_state_msg()) def load(self, filename): - with open(filename, "rb") as file_in: - load_data = file_in.read() - self.set_full_state_msg(load_data) + if hasattr(self.obs_builder, 'distance_map'): + with open(filename, "rb") as file_in: + load_data = file_in.read() + self.set_full_state_dist_msg(load_data) + else: + with open(filename, "rb") as file_in: + load_data = file_in.read() + self.set_full_state_msg(load_data) def load_pkl(self, pkl_data): self.set_full_state_msg(pkl_data) diff --git a/requirements_dev.txt b/requirements_dev.txt index ea46eb24..edd6ee28 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -3,13 +3,14 @@ tox>=3.5.2 twine>=1.12.1 pytest>=3.8.2 pytest-runner>=4.2 -numpy>=1.16.4 +numpy>=1.16.2 recordtype>=1.3 xarray>=0.11.3 matplotlib>=3.0.2 Pillow>=5.4.1 CairoSVG>=2.3.1 msgpack>=0.6.1 +msgpack-numpy>=0.4.4.0 svgutils>=0.3.1 screeninfo>=0.3.1 pyarrow>=0.13.0 -- GitLab