From abfbe464e1ff6e519508dba4d4e0fee236375da5 Mon Sep 17 00:00:00 2001
From: u229589 <christian.baumberger@sbb.ch>
Date: Thu, 7 Nov 2019 11:48:14 +0100
Subject: [PATCH] implement loading of legacy static agents and enable unit
 tests again

---
 flatland/envs/agent_utils.py                        | 10 ++++++++++
 flatland/envs/rail_env.py                           | 10 ++++++++--
 flatland/envs/schedule_generators.py                |  5 ++++-
 tests/test_flatland_envs_rail_env_shortest_paths.py |  4 ----
 4 files changed, 22 insertions(+), 7 deletions(-)

diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index dd639997..b895c6d2 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -103,3 +103,13 @@ class EnvAgent:
                                           speed_datas,
                                           malfunction_datas,
                                           range(len(schedule.agent_positions)))))
+
+    @classmethod
+    def load_legacy_static_agent(cls, static_agents_data: Tuple):
+        agents = []
+        for i, static_agent in enumerate(static_agents_data):
+            agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
+                             direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
+                             speed_data=static_agent[4], malfunction_data=static_agent[5], handle=i)
+            agents.append(agent)
+        return agents
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 0606863a..cab96ab0 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -869,7 +869,10 @@ class RailEnv(Environment):
         data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
         self.rail.grid = np.array(data["grid"])
         # agents are always reset as not moving
-        self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
+        if "agents_static" in data:
+            self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
+        else:
+            self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
         # setup with loaded data
         self.height, self.width = self.rail.grid.shape
         self.rail.height = self.height
@@ -887,7 +890,10 @@ class RailEnv(Environment):
         data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
         self.rail.grid = np.array(data["grid"])
         # agents are always reset as not moving
-        self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
+        if "agents_static" in data:
+            self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
+        else:
+            self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
         if "distance_map" in data.keys():
             self.distance_map.set(data["distance_map"])
         # setup with loaded data
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index a19501df..ceafa802 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -291,7 +291,10 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
             with open(filename, "rb") as file_in:
                 load_data = file_in.read()
         data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
-        agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
+        if "agents_static" in data:
+            agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
+        else:
+            agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
 
         # setup with loaded data
         agents_position = [a.initial_position for a in agents]
diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py
index bb3e5e2a..0fffd035 100644
--- a/tests/test_flatland_envs_rail_env_shortest_paths.py
+++ b/tests/test_flatland_envs_rail_env_shortest_paths.py
@@ -1,7 +1,6 @@
 import sys
 
 import numpy as np
-import pytest
 
 from flatland.core.grid.grid4 import Grid4TransitionsEnum
 from flatland.envs.observations import TreeObsForRailEnv
@@ -40,7 +39,6 @@ def test_get_shortest_paths_unreachable():
 
 # todo file test_002.pkl has to be generated automatically
 # see https://gitlab.aicrowd.com/flatland/flatland/issues/279
-@pytest.mark.skip
 def test_get_shortest_paths():
     env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
     env.reset()
@@ -172,7 +170,6 @@ def test_get_shortest_paths():
 
 # todo file test_002.pkl has to be generated automatically
 # see https://gitlab.aicrowd.com/flatland/flatland/issues/279
-@pytest.mark.skip
 def test_get_shortest_paths_max_depth():
     env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
     env.reset()
@@ -204,7 +201,6 @@ def test_get_shortest_paths_max_depth():
 
 # todo file Level_distance_map_shortest_path.pkl has to be generated automatically
 # see https://gitlab.aicrowd.com/flatland/flatland/issues/279
-@pytest.mark.skip
 def test_get_shortest_paths_agent_handle():
     env = load_flatland_environment_from_file('Level_distance_map_shortest_path.pkl', 'env_data.tests')
     env.reset()
-- 
GitLab