From abfbe464e1ff6e519508dba4d4e0fee236375da5 Mon Sep 17 00:00:00 2001 From: u229589 <christian.baumberger@sbb.ch> Date: Thu, 7 Nov 2019 11:48:14 +0100 Subject: [PATCH] implement loading of legacy static agents and enable unit tests again --- flatland/envs/agent_utils.py | 10 ++++++++++ flatland/envs/rail_env.py | 10 ++++++++-- flatland/envs/schedule_generators.py | 5 ++++- tests/test_flatland_envs_rail_env_shortest_paths.py | 4 ---- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index dd639997..b895c6d2 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -103,3 +103,13 @@ class EnvAgent: speed_datas, malfunction_datas, range(len(schedule.agent_positions))))) + + @classmethod + def load_legacy_static_agent(cls, static_agents_data: Tuple): + agents = [] + for i, static_agent in enumerate(static_agents_data): + agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1], + direction=static_agent[1], target=static_agent[2], moving=static_agent[3], + speed_data=static_agent[4], malfunction_data=static_agent[5], handle=i) + agents.append(agent) + return agents diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 0606863a..cab96ab0 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -869,7 +869,10 @@ class RailEnv(Environment): data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving - self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] + if "agents_static" in data: + self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) + else: + self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] # setup with loaded data self.height, self.width = self.rail.grid.shape self.rail.height = self.height @@ -887,7 +890,10 @@ class RailEnv(Environment): data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8') self.rail.grid = np.array(data["grid"]) # agents are always reset as not moving - self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] + if "agents_static" in data: + self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) + else: + self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]] if "distance_map" in data.keys(): self.distance_map.set(data["distance_map"]) # setup with loaded data diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index a19501df..ceafa802 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -291,7 +291,10 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator: with open(filename, "rb") as file_in: load_data = file_in.read() data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') - agents = [EnvAgent(*d[0:12]) for d in data["agents"]] + if "agents_static" in data: + agents = EnvAgent.load_legacy_static_agent(data["agents_static"]) + else: + agents = [EnvAgent(*d[0:12]) for d in data["agents"]] # setup with loaded data agents_position = [a.initial_position for a in agents] diff --git a/tests/test_flatland_envs_rail_env_shortest_paths.py b/tests/test_flatland_envs_rail_env_shortest_paths.py index bb3e5e2a..0fffd035 100644 --- a/tests/test_flatland_envs_rail_env_shortest_paths.py +++ b/tests/test_flatland_envs_rail_env_shortest_paths.py @@ -1,7 +1,6 @@ import sys import numpy as np -import pytest from flatland.core.grid.grid4 import Grid4TransitionsEnum from flatland.envs.observations import TreeObsForRailEnv @@ -40,7 +39,6 @@ def test_get_shortest_paths_unreachable(): # todo file test_002.pkl has to be generated automatically # see https://gitlab.aicrowd.com/flatland/flatland/issues/279 -@pytest.mark.skip def test_get_shortest_paths(): env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') env.reset() @@ -172,7 +170,6 @@ def test_get_shortest_paths(): # todo file test_002.pkl has to be generated automatically # see https://gitlab.aicrowd.com/flatland/flatland/issues/279 -@pytest.mark.skip def test_get_shortest_paths_max_depth(): env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests') env.reset() @@ -204,7 +201,6 @@ def test_get_shortest_paths_max_depth(): # todo file Level_distance_map_shortest_path.pkl has to be generated automatically # see https://gitlab.aicrowd.com/flatland/flatland/issues/279 -@pytest.mark.skip def test_get_shortest_paths_agent_handle(): env = load_flatland_environment_from_file('Level_distance_map_shortest_path.pkl', 'env_data.tests') env.reset() -- GitLab