Skip to content
Snippets Groups Projects
Commit 212ee6a9 authored by Erik Nygren's avatar Erik Nygren
Browse files

updated load schedule from file to respect stored speed profiles

parent 8dd6a225
No related branches found
No related tags found
No related merge requests found
...@@ -24,8 +24,8 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor ...@@ -24,8 +24,8 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor
# Different agent types (trains) with different speeds. # Different agent types (trains) with different speeds.
speed_ration_map = {1.: 0.25, # Fast passenger train speed_ration_map = {1.: 0.25, # Fast passenger train
1. / 2.: 0.25, # Slow commuter train 1. / 2.: 0.25, # Fast freight train
1. / 3.: 0.25, # Fast freight train 1. / 3.: 0.25, # Slow commuter train
1. / 4.: 0.25} # Slow freight train 1. / 4.: 0.25} # Slow freight train
env = RailEnv(width=50, env = RailEnv(width=50,
......
...@@ -44,7 +44,7 @@ class EnvAgentStatic(object): ...@@ -44,7 +44,7 @@ class EnvAgentStatic(object):
malfunction_datas = [] malfunction_datas = []
for i in range(len(positions)): for i in range(len(positions)):
malfunction_datas.append({'malfunction': 0, malfunction_datas.append({'malfunction': 0,
'malfunction_rate': 0, 'malfunction_rate': 0,
'next_malfunction': 0, 'next_malfunction': 0,
'nr_malfunctions': 0}) 'nr_malfunctions': 0})
......
...@@ -510,11 +510,9 @@ class RailEnv(Environment): ...@@ -510,11 +510,9 @@ class RailEnv(Environment):
grid_data = self.rail.grid.tolist() grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static] agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents] agent_data = [agent.to_list() for agent in self.agents]
msgpack.packb(grid_data) msgpack.packb(grid_data)
msgpack.packb(agent_data) msgpack.packb(agent_data)
msgpack.packb(agent_static_data) msgpack.packb(agent_static_data)
msg_data = { msg_data = {
"grid": grid_data, "grid": grid_data,
"agents_static": agent_static_data, "agents_static": agent_static_data,
...@@ -532,7 +530,7 @@ class RailEnv(Environment): ...@@ -532,7 +530,7 @@ class RailEnv(Environment):
self.rail.grid = np.array(data[b"grid"]) self.rail.grid = np.array(data[b"grid"])
# agents are always reset as not moving # agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9]) for d in data[b"agents"]]
# setup with loaded data # setup with loaded data
self.height, self.width = self.rail.grid.shape self.height, self.width = self.rail.grid.shape
self.rail.height = self.height self.rail.height = self.height
...@@ -544,7 +542,7 @@ class RailEnv(Environment): ...@@ -544,7 +542,7 @@ class RailEnv(Environment):
self.rail.grid = np.array(data[b"grid"]) self.rail.grid = np.array(data[b"grid"])
# agents are always reset as not moving # agents are always reset as not moving
self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9]) for d in data[b"agents"]]
if hasattr(self.obs_builder, 'distance_map') and b"distance_maps" in data.keys(): if hasattr(self.obs_builder, 'distance_map') and b"distance_maps" in data.keys():
self.obs_builder.distance_map = data[b"distance_maps"] self.obs_builder.distance_map = data[b"distance_maps"]
# setup with loaded data # setup with loaded data
...@@ -557,7 +555,6 @@ class RailEnv(Environment): ...@@ -557,7 +555,6 @@ class RailEnv(Environment):
grid_data = self.rail.grid.tolist() grid_data = self.rail.grid.tolist()
agent_static_data = [agent.to_list() for agent in self.agents_static] agent_static_data = [agent.to_list() for agent in self.agents_static]
agent_data = [agent.to_list() for agent in self.agents] agent_data = [agent.to_list() for agent in self.agents]
msgpack.packb(grid_data) msgpack.packb(grid_data)
msgpack.packb(agent_data) msgpack.packb(agent_data)
msgpack.packb(agent_static_data) msgpack.packb(agent_static_data)
......
...@@ -227,12 +227,13 @@ def schedule_from_file(filename) -> ScheduleGenerator: ...@@ -227,12 +227,13 @@ def schedule_from_file(filename) -> ScheduleGenerator:
data = msgpack.unpackb(load_data, use_list=False) data = msgpack.unpackb(load_data, use_list=False)
# agents are always reset as not moving # agents are always reset as not moving
agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]] agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data[b"agents_static"]]
# setup with loaded data # setup with loaded data
agents_position = [a.position for a in agents_static] agents_position = [a.position for a in agents_static]
agents_direction = [a.direction for a in agents_static] agents_direction = [a.direction for a in agents_static]
agents_target = [a.target for a in agents_static] agents_target = [a.target for a in agents_static]
agents_speed = [a.speed_data[b'speed'] for a in agents_static]
return agents_position, agents_direction, agents_target, [1.0] * len(agents_position) return agents_position, agents_direction, agents_target, agents_speed
return generator return generator
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment