diff --git a/flatland/envs/generators.py b/flatland/envs/generators.py index ec579c1dbd080dc53504421e6a58673e205f6725..9b33a55c4f5f57a8f92135d95f69696ee3df0be7 100644 --- a/flatland/envs/generators.py +++ b/flatland/envs/generators.py @@ -227,8 +227,11 @@ def rail_from_file(filename): agents_position = [a.position for a in agents_static] agents_direction = [a.direction for a in agents_static] agents_target = [a.target for a in agents_static] - return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) - + if len(data) > 3: + distance_maps = data[b"distance_maps"] + return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps + else: + return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return generator diff --git a/flatland/envs/observations.py b/flatland/envs/observations.py index 3929b9e191615ba91fb0e13df6d8ae040b401a5a..2f8ffb80858809a7e601a250265e7336786003b2 100644 --- a/flatland/envs/observations.py +++ b/flatland/envs/observations.py @@ -22,8 +22,6 @@ class TreeObsForRailEnv(ObservationBuilder): For details about the features in the tree observation see the get() function. """ - observation_dim = 9 - def __init__(self, max_depth, predictor=None): super().__init__() self.max_depth = max_depth @@ -34,6 +32,7 @@ class TreeObsForRailEnv(ObservationBuilder): for i in range(self.max_depth + 1): size += pow4 pow4 *= 4 + self.observation_dim = 9 self.observation_space = [size * self.observation_dim] self.location_has_agent = {} self.location_has_agent_direction = {} @@ -41,22 +40,28 @@ class TreeObsForRailEnv(ObservationBuilder): self.agents_previous_reset = None self.tree_explored_actions = [1, 2, 3, 0] self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] + self.distance_map = None def reset(self): agents = self.env.agents nb_agents = len(agents) - compute_distance_map = True if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset): compute_distance_map = False for i in range(nb_agents): if agents[i].target != self.agents_previous_reset[i].target: compute_distance_map = True - self.agents_previous_reset = agents + + # Don't compute the distance map if it was loaded + if self.agents_previous_reset is None and self.distance_map is not None: + self.location_has_target = {tuple(agent.target): 1 for agent in agents} + compute_distance_map = False if compute_distance_map: self._compute_distance_map() + self.agents_previous_reset = agents + def _compute_distance_map(self): agents = self.env.agents nb_agents = len(agents) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 996301a84f06f185ba9ed605ea1145f404c8b16e..f082f0801ba782bab8d4106856739fc798a3efb8 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -6,6 +6,7 @@ Definition of the RailEnv environment. from enum import IntEnum import msgpack +import msgpack_numpy as m import numpy as np from flatland.core.env import Environment @@ -14,6 +15,8 @@ from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent from flatland.envs.generators import random_rail_generator from flatland.envs.observations import TreeObsForRailEnv +m.patch() + class RailEnvActions(IntEnum): DO_NOTHING = 0 # implies change of direction in a dead-end! @@ -170,6 +173,10 @@ class RailEnv(Environment): """ tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) + # Check if generator provided a distance map TODO: Make this check safer! + if len(tRailAgents) > 5: + self.obs_builder.distance_map = tRailAgents[-1] + if regen_rail or self.rail is None: self.rail = tRailAgents[0] self.height, self.width = self.rail.grid.shape @@ -418,14 +425,61 @@ class RailEnv(Environment): self.rail.width = self.width self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + def set_full_state_dist_msg(self, msg_data): + data = msgpack.unpackb(msg_data, use_list=False) + self.rail.grid = np.array(data[b"grid"]) + # agents are always reset as not moving + self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] + self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] + if hasattr(self.obs_builder, 'distance_map'): + self.obs_builder.distance_map = data[b"distance_maps"] + # setup with loaded data + self.height, self.width = self.rail.grid.shape + self.rail.height = self.height + self.rail.width = self.width + self.dones = dict.fromkeys(list(range(self.get_num_agents())) + ["__all__"], False) + + def get_full_state_dist_msg(self): + grid_data = self.rail.grid.tolist() + agent_static_data = [agent.to_list() for agent in self.agents_static] + agent_data = [agent.to_list() for agent in self.agents] + + msgpack.packb(grid_data) + msgpack.packb(agent_data) + msgpack.packb(agent_static_data) + if hasattr(self.obs_builder, 'distance_map'): + distance_map_data = self.obs_builder.distance_map + msgpack.packb(distance_map_data) + msg_data = { + "grid": grid_data, + "agents_static": agent_static_data, + "agents": agent_data, + "distance_maps": distance_map_data} + else: + msg_data = { + "grid": grid_data, + "agents_static": agent_static_data, + "agents": agent_data} + + return msgpack.packb(msg_data, use_bin_type=True) + def save(self, filename): - with open(filename, "wb") as file_out: - file_out.write(self.get_full_state_msg()) + if hasattr(self.obs_builder, 'distance_map'): + with open(filename, "wb") as file_out: + file_out.write(self.get_full_state_dist_msg()) + else: + with open(filename, "wb") as file_out: + file_out.write(self.get_full_state_msg()) def load(self, filename): - with open(filename, "rb") as file_in: - load_data = file_in.read() - self.set_full_state_msg(load_data) + if hasattr(self.obs_builder, 'distance_map'): + with open(filename, "rb") as file_in: + load_data = file_in.read() + self.set_full_state_dist_msg(load_data) + else: + with open(filename, "rb") as file_in: + load_data = file_in.read() + self.set_full_state_msg(load_data) def load_pkl(self, pkl_data): self.set_full_state_msg(pkl_data) diff --git a/requirements_dev.txt b/requirements_dev.txt index ea46eb245842881f1c51ededa5284300836a36ba..edd6ee2842dfb196db90f5334c6318f801785e0e 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -3,13 +3,14 @@ tox>=3.5.2 twine>=1.12.1 pytest>=3.8.2 pytest-runner>=4.2 -numpy>=1.16.4 +numpy>=1.16.2 recordtype>=1.3 xarray>=0.11.3 matplotlib>=3.0.2 Pillow>=5.4.1 CairoSVG>=2.3.1 msgpack>=0.6.1 +msgpack-numpy>=0.4.4.0 svgutils>=0.3.1 screeninfo>=0.3.1 pyarrow>=0.13.0 diff --git a/tests/tests_generators.py b/tests/tests_generators.py index 449b83294173c9665f54c34a668579b222f0c281..31dff253126bec041241aaba44f33dd9c494f2a1 100644 --- a/tests/tests_generators.py +++ b/tests/tests_generators.py @@ -5,7 +5,7 @@ import numpy as np from flatland.envs.generators import rail_from_grid_transition_map, rail_from_file, complex_rail_generator, \ random_rail_generator, empty_rail_generator -from flatland.envs.observations import TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from tests.simple_rail import make_simple_rail @@ -109,8 +109,12 @@ def test_rail_from_grid_transition_map(): def tests_rail_from_file(): - file_name = "test_pkl.pkl" + file_name = "test_with_distance_map.pkl" + + # Test to save and load file with distance map. + rail, rail_map = make_simple_rail() + env = RailEnv(width=rail_map.shape[1], height=rail_map.shape[0], rail_generator=rail_from_grid_transition_map(rail), @@ -118,6 +122,7 @@ def tests_rail_from_file(): obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()), ) env.save(file_name) + # initialize agents_static rails_initial = env.rail.grid agents_initial = env.agents @@ -133,4 +138,69 @@ def tests_rail_from_file(): assert np.all(np.array_equal(rails_initial, rails_loaded)) assert agents_initial == agents_loaded + assert env.obs_builder.distance_map is not None + + # Test to save and load file without distance map. + + file_name_2 = "test_without_distance_map.pkl" + + env2 = RailEnv(width=rail_map.shape[1], + height=rail_map.shape[0], + rail_generator=rail_from_grid_transition_map(rail), + number_of_agents=3, + obs_builder_object=GlobalObsForRailEnv(), + ) + + env2.save(file_name_2) + + # initialize agents_static + rails_initial_2 = env2.rail.grid + agents_initial_2 = env2.agents + env2 = RailEnv(width=1, + height=1, + rail_generator=rail_from_file(file_name_2), + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv(), + ) + + rails_loaded_2 = env2.rail.grid + agents_loaded_2 = env2.agents + + assert np.all(np.array_equal(rails_initial_2, rails_loaded_2)) + assert agents_initial_2 == agents_loaded_2 + assert not hasattr(env2.obs_builder, "distance_map") + + # Test to save with distance map and load without + + # initialize agents_static + env3 = RailEnv(width=1, + height=1, + rail_generator=rail_from_file(file_name), + number_of_agents=1, + obs_builder_object=GlobalObsForRailEnv(), + ) + + rails_loaded_3 = env3.rail.grid + agents_loaded_3 = env3.agents + + assert np.all(np.array_equal(rails_initial, rails_loaded_3)) + assert agents_initial == agents_loaded_3 + assert not hasattr(env2.obs_builder, "distance_map") + + # Test to save without distance map and load with generating distance map + + # initialize agents_static + env4 = RailEnv(width=1, + height=1, + rail_generator=rail_from_file(file_name_2), + number_of_agents=1, + obs_builder_object=TreeObsForRailEnv(max_depth=2), + ) + + rails_loaded_4 = env4.rail.grid + agents_loaded_4 = env4.agents + + assert np.all(np.array_equal(rails_initial_2, rails_loaded_4)) + assert agents_initial_2 == agents_loaded_4 + assert env.obs_builder.distance_map is not None