From 212ee6a9ab74517c6fd3127ab81f21d584c1ebbd Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sat, 31 Aug 2019 16:40:57 -0400
Subject: [PATCH] updated load schedule from file to respect stored speed
 profiles

---
 examples/flatland_2_0_example.py     | 4 ++--
 flatland/envs/agent_utils.py         | 2 +-
 flatland/envs/rail_env.py            | 7 ++-----
 flatland/envs/schedule_generators.py | 7 ++++---
 4 files changed, 9 insertions(+), 11 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index b560b340..082308e6 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -24,8 +24,8 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor
 
 # Different agent types (trains) with different speeds.
 speed_ration_map = {1.: 0.25,  # Fast passenger train
-                    1. / 2.: 0.25,  # Slow commuter train
-                    1. / 3.: 0.25,  # Fast freight train
+                    1. / 2.: 0.25,  # Fast freight train
+                    1. / 3.: 0.25,  # Slow commuter train
                     1. / 4.: 0.25}  # Slow freight train
 
 env = RailEnv(width=50,
diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index 4c407008..66e7e7f5 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -44,7 +44,7 @@ class EnvAgentStatic(object):
         malfunction_datas = []
         for i in range(len(positions)):
             malfunction_datas.append({'malfunction': 0,
-                                 'malfunction_rate': 0,
+                                      'malfunction_rate': 0,
                                       'next_malfunction': 0,
                                       'nr_malfunctions': 0})
 
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 62efbdc5..6adcb305 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -510,11 +510,9 @@ class RailEnv(Environment):
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
         agent_data = [agent.to_list() for agent in self.agents]
-
         msgpack.packb(grid_data)
         msgpack.packb(agent_data)
         msgpack.packb(agent_static_data)
-
         msg_data = {
             "grid": grid_data,
             "agents_static": agent_static_data,
@@ -532,7 +530,7 @@ class RailEnv(Environment):
         self.rail.grid = np.array(data[b"grid"])
         # agents are always reset as not moving
         self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
-        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
+        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9]) for d in data[b"agents"]]
         # setup with loaded data
         self.height, self.width = self.rail.grid.shape
         self.rail.height = self.height
@@ -544,7 +542,7 @@ class RailEnv(Environment):
         self.rail.grid = np.array(data[b"grid"])
         # agents are always reset as not moving
         self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
-        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]]
+        self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9]) for d in data[b"agents"]]
         if hasattr(self.obs_builder, 'distance_map') and b"distance_maps" in data.keys():
             self.obs_builder.distance_map = data[b"distance_maps"]
         # setup with loaded data
@@ -557,7 +555,6 @@ class RailEnv(Environment):
         grid_data = self.rail.grid.tolist()
         agent_static_data = [agent.to_list() for agent in self.agents_static]
         agent_data = [agent.to_list() for agent in self.agents]
-
         msgpack.packb(grid_data)
         msgpack.packb(agent_data)
         msgpack.packb(agent_static_data)
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index a0f6825d..918b130a 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -227,12 +227,13 @@ def schedule_from_file(filename) -> ScheduleGenerator:
         data = msgpack.unpackb(load_data, use_list=False)
 
         # agents are always reset as not moving
-        agents_static = [EnvAgentStatic(d[0], d[1], d[2], moving=False) for d in data[b"agents_static"]]
+        agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data[b"agents_static"]]
+
         # setup with loaded data
         agents_position = [a.position for a in agents_static]
         agents_direction = [a.direction for a in agents_static]
         agents_target = [a.target for a in agents_static]
-
-        return agents_position, agents_direction, agents_target, [1.0] * len(agents_position)
+        agents_speed = [a.speed_data[b'speed'] for a in agents_static]
+        return agents_position, agents_direction, agents_target, agents_speed
 
     return generator
-- 
GitLab