Skip to content
Snippets Groups Projects
Commit 4d538a41 authored by Erik Nygren's avatar Erik Nygren
Browse files

updated load and save function. Now also distance maps are stored.

Additional package msgpack-numpy needed for ndarray.
This saves tons of time when loading precomputed files.
parent a0791a3f
No related branches found
No related tags found
No related merge requests found
...@@ -227,8 +227,11 @@ def rail_from_file(filename): ...@@ -227,8 +227,11 @@ def rail_from_file(filename):
agents_position = [a.position for a in agents_static] agents_position = [a.position for a in agents_static]
agents_direction = [a.direction for a in agents_static] agents_direction = [a.direction for a in agents_static]
agents_target = [a.target for a in agents_static] agents_target = [a.target for a in agents_static]
return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position) if len(data) > 3:
distance_maps = data[b"distance_maps"]
return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position), distance_maps
else:
return rail, agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
return generator return generator
......
...@@ -41,22 +41,28 @@ class TreeObsForRailEnv(ObservationBuilder): ...@@ -41,22 +41,28 @@ class TreeObsForRailEnv(ObservationBuilder):
self.agents_previous_reset = None self.agents_previous_reset = None
self.tree_explored_actions = [1, 2, 3, 0] self.tree_explored_actions = [1, 2, 3, 0]
self.tree_explorted_actions_char = ['L', 'F', 'R', 'B'] self.tree_explorted_actions_char = ['L', 'F', 'R', 'B']
self.distance_map = None
def reset(self): def reset(self):
agents = self.env.agents agents = self.env.agents
nb_agents = len(agents) nb_agents = len(agents)
compute_distance_map = True compute_distance_map = True
if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset): if self.agents_previous_reset is not None and nb_agents == len(self.agents_previous_reset):
compute_distance_map = False compute_distance_map = False
for i in range(nb_agents): for i in range(nb_agents):
if agents[i].target != self.agents_previous_reset[i].target: if agents[i].target != self.agents_previous_reset[i].target:
compute_distance_map = True compute_distance_map = True
self.agents_previous_reset = agents
# Don't compute the distance map if it was loaded
if self.agents_previous_reset is None and self.distance_map is not None:
self.location_has_target = {tuple(agent.target): 1 for agent in agents}
compute_distance_map = False
if compute_distance_map: if compute_distance_map:
self._compute_distance_map() self._compute_distance_map()
self.agents_previous_reset = agents
def _compute_distance_map(self): def _compute_distance_map(self):
agents = self.env.agents agents = self.env.agents
nb_agents = len(agents) nb_agents = len(agents)
......
...@@ -6,6 +6,7 @@ Definition of the RailEnv environment. ...@@ -6,6 +6,7 @@ Definition of the RailEnv environment.
from enum import IntEnum from enum import IntEnum
import msgpack import msgpack
import msgpack_numpy as m
import numpy as np import numpy as np
from flatland.core.env import Environment from flatland.core.env import Environment
...@@ -14,6 +15,8 @@ from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent ...@@ -14,6 +15,8 @@ from flatland.envs.agent_utils import EnvAgentStatic, EnvAgent
from flatland.envs.generators import random_rail_generator from flatland.envs.generators import random_rail_generator
from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.observations import TreeObsForRailEnv
m.patch()
class RailEnvActions(IntEnum): class RailEnvActions(IntEnum):
DO_NOTHING = 0 # implies change of direction in a dead-end! DO_NOTHING = 0 # implies change of direction in a dead-end!
...@@ -170,6 +173,10 @@ class RailEnv(Environment): ...@@ -170,6 +173,10 @@ class RailEnv(Environment):
""" """
tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets) tRailAgents = self.rail_generator(self.width, self.height, self.get_num_agents(), self.num_resets)
# Check if generator provided a distance map TODO: Make this check safer!
if len(tRailAgents) > 5:
self.obs_builder.distance_map = tRailAgents[-1]
if regen_rail or self.rail is None: if regen_rail or self.rail is None:
self.rail = tRailAgents[0] self.rail = tRailAgents[0]
self.height, self.width = self.rail.grid.shape self.height, self.width = self.rail.grid.shape
...@@ -424,6 +431,8 @@ class RailEnv(Environment): ...@@ -424,6 +431,8 @@ class RailEnv(Environment):
# agents are always reset as not moving # agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
if hasattr(self.obs_builder, 'distance_map'):
self.obs_builder.distance_map = data[b"distance_maps"]
# setup with loaded data # setup with loaded data
self.height, self.width = self.rail.grid.shape self.height, self.width = self.rail.grid.shape
self.rail.height = self.height self.rail.height = self.height
...@@ -438,22 +447,39 @@ class RailEnv(Environment): ...@@ -438,22 +447,39 @@ class RailEnv(Environment):
msgpack.packb(grid_data) msgpack.packb(grid_data)
msgpack.packb(agent_data) msgpack.packb(agent_data)
msgpack.packb(agent_static_data) msgpack.packb(agent_static_data)
if hasattr(self.obs_builder, 'distance_map'):
distance_map_data = self.obs_builder.distance_map
msgpack.packb(distance_map_data)
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data,
"distance_maps": distance_map_data}
else:
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data}
msg_data = {
"grid": grid_data,
"agents_static": agent_static_data,
"agents": agent_data}
return msgpack.packb(msg_data, use_bin_type=True) return msgpack.packb(msg_data, use_bin_type=True)
def save(self, filename): def save(self, filename):
with open(filename, "wb") as file_out: if hasattr(self.obs_builder, 'distance_map'):
file_out.write(self.get_full_state_msg()) with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_dist_msg())
else:
with open(filename, "wb") as file_out:
file_out.write(self.get_full_state_msg())
def load(self, filename): def load(self, filename):
with open(filename, "rb") as file_in: if hasattr(self.obs_builder, 'distance_map'):
load_data = file_in.read() with open(filename, "rb") as file_in:
self.set_full_state_msg(load_data) load_data = file_in.read()
self.set_full_state_dist_msg(load_data)
else:
with open(filename, "rb") as file_in:
load_data = file_in.read()
self.set_full_state_msg(load_data)
def load_pkl(self, pkl_data): def load_pkl(self, pkl_data):
self.set_full_state_msg(pkl_data) self.set_full_state_msg(pkl_data)
......
...@@ -3,13 +3,14 @@ tox>=3.5.2 ...@@ -3,13 +3,14 @@ tox>=3.5.2
twine>=1.12.1 twine>=1.12.1
pytest>=3.8.2 pytest>=3.8.2
pytest-runner>=4.2 pytest-runner>=4.2
numpy>=1.16.4 numpy>=1.16.2
recordtype>=1.3 recordtype>=1.3
xarray>=0.11.3 xarray>=0.11.3
matplotlib>=3.0.2 matplotlib>=3.0.2
Pillow>=5.4.1 Pillow>=5.4.1
CairoSVG>=2.3.1 CairoSVG>=2.3.1
msgpack>=0.6.1 msgpack>=0.6.1
msgpack-numpy>=0.4.4.0
svgutils>=0.3.1 svgutils>=0.3.1
screeninfo>=0.3.1 screeninfo>=0.3.1
pyarrow>=0.13.0 pyarrow>=0.13.0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment