diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 2a1b72281f8b70bb326a3454a305d081d266948d..a66a27bf0e6dce55dfa878687ac1328aee63a6ea 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -29,13 +29,13 @@ class EnvAgentStatic(object): position = attrib() direction = attrib() target = attrib() - moving = False + moving = attrib() - def __init__(self, position, direction, target): + def __init__(self, position, direction, target, moving=False): self.position = position self.direction = direction self.target = target - self.moving = False + self.moving = moving @classmethod def from_lists(cls, positions, directions, targets): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 73d0147bc784243176faf7a19bad36005ad9662e..b93737709abf97e1364c389665431c375b2c16ac 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -460,7 +460,7 @@ class RailEnv(Environment): def set_full_state_msg(self, msg_data): data = msgpack.unpackb(msg_data, use_list=False) self.rail.grid = np.array(data[b"grid"]) - self.agents_static = [EnvAgentStatic(d[0], d[1], d[2]) for d in data[b"agents_static"]] + self.agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data[b"agents_static"]] self.agents = [EnvAgent(d[0], d[1], d[2], d[3], d[4]) for d in data[b"agents"]] # setup with loaded data self.height, self.width = self.rail.grid.shape diff --git a/tests/test_env_edit.py b/tests/test_env_edit.py index 2d8a4e087a0ab1d4425c05cf3153abcded8a8ceb..57dad857546d21009b881f4bf26e085873eaa655 100644 --- a/tests/test_env_edit.py +++ b/tests/test_env_edit.py @@ -8,7 +8,7 @@ def test_load_env(): env = RailEnv(10, 10) env.load("env-data/tests/test-10x10.mpk") - agent_static = EnvAgentStatic((0, 0), 2, (5, 5)) + agent_static = EnvAgentStatic((0, 0), 2, (5, 5), False) env.add_agent_static(agent_static) assert env.get_num_agents() == 1 diff --git a/tests/test_environments.py b/tests/test_environments.py index 9c7b53b9b5876a99d7deea20da10816d81f02b65..ba885d1ae931582adbf543e0014b6c1c4c533d1c 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -185,14 +185,14 @@ def test_dead_end(): # rail_env.agents_target[0] = (0, 0) # rail_env.agents_position[0] = (0, 2) # rail_env.agents_direction[0] = 1 - rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0))] + rail_env.agents = [EnvAgent(position=(0, 2), direction=1, target=(0, 0), moving=False)] check_consistency(rail_env) rail_env.reset() # rail_env.agents_target[0] = (0, 4) # rail_env.agents_position[0] = (0, 2) # rail_env.agents_direction[0] = 3 - rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4))] + rail_env.agents = [EnvAgent(position=(0, 2), direction=3, target=(0, 4), moving=False)] check_consistency(rail_env) # In the vertical configuration: