diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 8e9ffb99d06176416dfbe2b65bcca09723b1a56c..9d75dfef3b01ad87d7fe8ef302f1cb15fdaba5fc 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -29,18 +29,23 @@ class EnvAgentStatic(object): direction = attrib() target = attrib() moving = attrib() + speed_data = attrib() - def __init__(self, position, direction, target, moving=False): + def __init__(self, position, direction, target, moving=False, speed_data={'position_fraction':0.0, 'speed':1.0, 'transition_action_on_cellexit':2}): self.position = position self.direction = direction self.target = target self.moving = moving + self.speed_data = speed_data @classmethod def from_lists(cls, positions, directions, targets): """ Create a list of EnvAgentStatics from lists of positions, directions and targets """ - return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions)))) + speed_datas = [] + for i in range(len(positions)): + speed_datas.append( speed_data={'position_fraction':0.0, 'speed':1.0, 'transition_action_on_cellexit':2} ) + return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas))) def to_list(self): @@ -54,7 +59,7 @@ class EnvAgentStatic(object): if type(lTarget) is np.ndarray: lTarget = lTarget.tolist() - return [lPos, int(self.direction), lTarget, int(self.moving)] + return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data] @attrs @@ -78,7 +83,7 @@ class EnvAgent(EnvAgentStatic): def to_list(self): return [ self.position, self.direction, self.target, self.handle, - self.old_direction, self.old_position, self.moving] + self.old_direction, self.old_position, self.moving, self.speed_data] @classmethod def from_static(cls, oStatic):