From 212ee6a9ab74517c6fd3127ab81f21d584c1ebbd Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Sat, 31 Aug 2019 16:40:57 -0400 Subject: [PATCH] updated load schedule from file to respect stored speed profiles --- examples/flatland_2_0_example.py | 4 ++-- flatland/envs/agent_utils.py | 2 +- flatland/envs/rail_env.py | 7 ++----- flatland/envs/schedule_generators.py | 7 ++++--- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py index b560b340..082308e6 100644 --- a/examples/flatland_2_0_example.py +++ b/examples/flatland_2_0_example.py @@ -24,8 +24,8 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor # Different agent types (trains) with different speeds. speed_ration_map = {1.: 0.25, # Fast passenger train - 1. / 2.: 0.25, # Slow commuter train - 1. / 3.: 0.25, # Fast freight train + 1. / 2.: 0.25, # Fast freight train + 1. / 3.: 0.25, # Slow commuter train 1. / 4.: 0.25} # Slow freight train env = RailEnv(width=50, diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 4c407008..66e7e7f5 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -44,7 +44,7 @@ class EnvAgentStatic(object): malfunction_datas = [] for i in range(len(positions)): malfunction_datas.append({'malfunction': 0, - 'malfunction_rate': 0, + 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0}) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 62efbdc5..6adcb305 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -510,11 +510,9 @@ class RailEnv(Environment): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] - msgpack.packb(grid_data) msgpack.packb(agent_data) msgpack.packb(agent_static_data) - msg_data = { "grid": grid_data, "agents_static": agent_static_data, @@ -532,7 +530,7 @@ class RailEnv(Environment): self.rail.grid = np.array(data[b"grid"]) # agents are always reset as not moving self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] - self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] + self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9]) for d in data[b"agents"]] # setup with loaded data self.height, self.width = self.rail.grid.shape self.rail.height = self.height @@ -544,7 +542,7 @@ class RailEnv(Environment): self.rail.grid = np.array(data[b"grid"]) # agents are always reset as not moving self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] - self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] + self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9]) for d in data[b"agents"]] if hasattr(self.obs_builder, 'distance_map') and b"distance_maps" in data.keys(): self.obs_builder.distance_map = data[b"distance_maps"] # setup with loaded data @@ -557,7 +555,6 @@ class RailEnv(Environment): grid_data = self.rail.grid.tolist() agent_static_data = [agent.to_list() for agent in self.agents_static] agent_data = [agent.to_list() for agent in self.agents] - msgpack.packb(grid_data) msgpack.packb(agent_data) msgpack.packb(agent_static_data) diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index a0f6825d..918b130a 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -227,12 +227,13 @@ def schedule_from_file(filename) -> ScheduleGenerator: data = msgpack.unpackb(load_data, use_list=False) # agents are always reset as not moving - agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] + agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data[b"agents_static"]] + # setup with loaded data agents_position = [a.position for a in agents_static] agents_direction = [a.direction for a in agents_static] agents_target = [a.target for a in agents_static] - - return agents_position, agents_direction, agents_target, [1.0] * len(agents_position) + agents_speed = [a.speed_data[b'speed'] for a in agents_static] + return agents_position, agents_direction, agents_target, agents_speed return generator -- GitLab