Skip to content
Snippets Groups Projects
Commit ada01747 authored by adrian_egli's avatar adrian_egli
Browse files

Merge branch '291-converter-for-old-pkl-files' into 'master'

implement loading of legacy static agents and enable unit tests again

Closes #291

See merge request flatland/flatland!267
parents 64362a51 abfbe464
No related branches found
No related tags found
No related merge requests found
......@@ -103,3 +103,13 @@ class EnvAgent:
speed_datas,
malfunction_datas,
range(len(schedule.agent_positions)))))
@classmethod
def load_legacy_static_agent(cls, static_agents_data: Tuple):
agents = []
for i, static_agent in enumerate(static_agents_data):
agent = EnvAgent(initial_position=static_agent[0], initial_direction=static_agent[1],
direction=static_agent[1], target=static_agent[2], moving=static_agent[3],
speed_data=static_agent[4], malfunction_data=static_agent[5], handle=i)
agents.append(agent)
return agents
......@@ -869,7 +869,10 @@ class RailEnv(Environment):
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
if "agents_static" in data:
self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
else:
self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
# setup with loaded data
self.height, self.width = self.rail.grid.shape
self.rail.height = self.height
......@@ -887,7 +890,10 @@ class RailEnv(Environment):
data = msgpack.unpackb(msg_data, use_list=False, encoding='utf-8')
self.rail.grid = np.array(data["grid"])
# agents are always reset as not moving
self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
if "agents_static" in data:
self.agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
else:
self.agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
if "distance_map" in data.keys():
self.distance_map.set(data["distance_map"])
# setup with loaded data
......
......@@ -291,7 +291,10 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
with open(filename, "rb") as file_in:
load_data = file_in.read()
data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8')
agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
if "agents_static" in data:
agents = EnvAgent.load_legacy_static_agent(data["agents_static"])
else:
agents = [EnvAgent(*d[0:12]) for d in data["agents"]]
# setup with loaded data
agents_position = [a.initial_position for a in agents]
......
import sys
import numpy as np
import pytest
from flatland.core.grid.grid4 import Grid4TransitionsEnum
from flatland.envs.observations import TreeObsForRailEnv
......@@ -40,7 +39,6 @@ def test_get_shortest_paths_unreachable():
# todo file test_002.pkl has to be generated automatically
# see https://gitlab.aicrowd.com/flatland/flatland/issues/279
@pytest.mark.skip
def test_get_shortest_paths():
env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
env.reset()
......@@ -172,7 +170,6 @@ def test_get_shortest_paths():
# todo file test_002.pkl has to be generated automatically
# see https://gitlab.aicrowd.com/flatland/flatland/issues/279
@pytest.mark.skip
def test_get_shortest_paths_max_depth():
env = load_flatland_environment_from_file('test_002.pkl', 'env_data.tests')
env.reset()
......@@ -204,7 +201,6 @@ def test_get_shortest_paths_max_depth():
# todo file Level_distance_map_shortest_path.pkl has to be generated automatically
# see https://gitlab.aicrowd.com/flatland/flatland/issues/279
@pytest.mark.skip
def test_get_shortest_paths_agent_handle():
env = load_flatland_environment_from_file('Level_distance_map_shortest_path.pkl', 'env_data.tests')
env.reset()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment