diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py index ee80320bc5eb03f8568e38fa876f5cf18818c999..a3a6dc1e4813a8c41284d46b6c03c5898cdcc63e 100644 --- a/flatland/envs/schedule_generators.py +++ b/flatland/envs/schedule_generators.py @@ -228,14 +228,22 @@ def schedule_from_file(filename) -> ScheduleGenerator: data = msgpack.unpackb(load_data, use_list=False, encoding='utf-8') # agents are always reset as not moving - agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data["agents_static"]] + if len(data['agents_static'][0]) > 5: + print(len(data['agents_static'][0])) + agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3], d[4], d[5]) for d in data["agents_static"]] + else: + agents_static = [EnvAgentStatic(d[0], d[1], d[2], d[3]) for d in data["agents_static"]] # setup with loaded data agents_position = [a.position for a in agents_static] agents_direction = [a.direction for a in agents_static] agents_target = [a.target for a in agents_static] - agents_speed = [a.speed_data['speed'] for a in agents_static] - agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents_static] + if len(data['agents_static'][0]) > 5: + agents_speed = [a.speed_data['speed'] for a in agents_static] + agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents_static] + else: + agents_speed = None + agents_malfunction = None return agents_position, agents_direction, agents_target, agents_speed, agents_malfunction return generator