From 23ff4c38f6cb35f898a32f462b4516e3b13fa486 Mon Sep 17 00:00:00 2001 From: Giacomo Spigler <spiglerg@gmail.com> Date: Wed, 19 Jun 2019 12:57:42 +0200 Subject: [PATCH] test1 --- flatland/envs/agent_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 8e9ffb99..9d75dfef 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -29,18 +29,23 @@ class EnvAgentStatic(object): direction = attrib() target = attrib() moving = attrib() + speed_data = attrib() - def __init__(self, position, direction, target, moving=False): + def __init__(self, position, direction, target, moving=False, speed_data={'position_fraction':0.0, 'speed':1.0, 'transition_action_on_cellexit':2}): self.position = position self.direction = direction self.target = target self.moving = moving + self.speed_data = speed_data @classmethod def from_lists(cls, positions, directions, targets): """ Create a list of EnvAgentStatics from lists of positions, directions and targets """ - return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions)))) + speed_datas = [] + for i in range(len(positions)): + speed_datas.append( speed_data={'position_fraction':0.0, 'speed':1.0, 'transition_action_on_cellexit':2} ) + return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas))) def to_list(self): @@ -54,7 +59,7 @@ class EnvAgentStatic(object): if type(lTarget) is np.ndarray: lTarget = lTarget.tolist() - return [lPos, int(self.direction), lTarget, int(self.moving)] + return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data] @attrs @@ -78,7 +83,7 @@ class EnvAgent(EnvAgentStatic): def to_list(self): return [ self.position, self.direction, self.target, self.handle, - self.old_direction, self.old_position, self.moving] + self.old_direction, self.old_position, self.moving, self.speed_data] @classmethod def from_static(cls, oStatic): -- GitLab