From 454809751687be83a240a0835a1752714ce36939 Mon Sep 17 00:00:00 2001 From: Giacomo Spigler <spiglerg@gmail.com> Date: Wed, 5 Jun 2019 22:26:41 +0200 Subject: [PATCH] fix? --- flatland/envs/agent_utils.py | 6 +++--- flatland/envs/rail_env.py | 2 +- tests/test_env_edit.py | 2 +- tests/test_environments.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 2a1b7228..a66a27bf 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -29,13 +29,13 @@ class EnvAgentStatic(object): position = attrib() direction = attrib() target = attrib() - moving = False + moving = attrib() - def __init__(self, position, direction, target): + def __init__(self, position, direction, target, moving=False): self.position = position self.direction = direction self.target = target - self.moving = False + self.moving = moving @classmethod def from_lists(cls, positions, directions, targets): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 73d0147b..b9373770 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -460,7 +460,7 @@ class RailEnv(Environment): def set_full_state_msg(self, msg_data): data = msgpack.unpackb(msg_data, use_list=False) self.rail.grid = np.array(data[b"grid"]) - self.agents_static = [EnvAgentStatic(d[0], d[1], d[2]) for d in data[b"agents_static"]] + self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data[b"agents_static"]] self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] # setup with loaded data self.height, self.width = self.rail.grid.shape diff --git a/tests/test_env_edit.py b/tests/test_env_edit.py index 2d8a4e08..57dad857 100644 --- a/tests/test_env_edit.py +++ b/tests/test_env_edit.py @@ -8,7 +8,7 @@ def test_load_env(): env = RailEnv(10, 10) env.load("env-data/tests/test-10x10.mpk") - agent_static = EnvAgentStatic((0, 0), 2, (5, 5)) + agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False) env.add_agent_static(agent_static) assert env.get_num_agents() == 1 diff --git a/tests/test_environments.py b/tests/test_environments.py index 9c7b53b9..ba885d1a 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -185,14 +185,14 @@ def test_dead_end(): # rail_env.agents_target[0] = (0, 0) # rail_env.agents_position[0] = (0, 2) # rail_env.agents_direction[0] = 1 - rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0))] + rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0), moving=False)] check_consistency(rail_env) rail_env.reset() # rail_env.agents_target[0] = (0, 4) # rail_env.agents_position[0] = (0, 2) # rail_env.agents_direction[0] = 3 - rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4))] + rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4), moving=False)] check_consistency(rail_env) # In the vertical configuration: -- GitLab