diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index e353af29ddbee16c208e2059767c18fa7880cb64..6b881f9cda5d566b2989b938bb234d0560cdf7cb 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -15,6 +15,7 @@ class EnvAgentStatic(object): direction = attrib() target = attrib() moving = attrib(default=False) + # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0, # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous # cell if speed=1, as default) @@ -22,6 +23,11 @@ class EnvAgentStatic(object): speed_data = attrib( default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}))) + # if broken>0, the agent's actions are ignored for 'broken' steps + # number of times the agent had to stop, since the last time it broke down + broken_data = attrib( + default=Factory(lambda: dict({'broken': 0, 'number_of_halts': 0}))) + @classmethod def from_lists(cls, positions, directions, targets, speeds=None): """ Create a list of EnvAgentStatics from lists of positions, directions and targets @@ -31,7 +37,14 @@ class EnvAgentStatic(object): speed_datas.append({'position_fraction': 0.0, 'speed': speeds[i] if speeds is not None else 1.0, 'transition_action_on_cellexit': 0}) - return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas))) + + # TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set some as broken? 
+ broken_datas = [] + for i in range(len(positions)): + broken_datas.append({'broken': 0, + 'number_of_halts': 0}) + + return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas, broken_datas))) def to_list(self): @@ -45,7 +58,7 @@ class EnvAgentStatic(object): if type(lTarget) is np.ndarray: lTarget = lTarget.tolist() - return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data] + return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.broken_data] @attrs @@ -63,7 +76,7 @@ class EnvAgent(EnvAgentStatic): def to_list(self): return [ self.position, self.direction, self.target, self.handle, - self.old_direction, self.old_position, self.moving, self.speed_data] + self.old_direction, self.old_position, self.moving, self.speed_data, self.broken_data] @classmethod def from_static(cls, oStatic): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index abe623ae173a593e265cff7d4d88eb323e16b08e..ca003a80460a2ccdd4bfb01402a9cddc06261e1e 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -196,6 +196,8 @@ class RailEnv(Environment): for i_agent in range(self.get_num_agents()): agent = self.agents[i_agent] agent.speed_data['position_fraction'] = 0.0 + agent.broken_data['broken'] = 0 + agent.broken_data['number_of_halts'] = 0 self.num_resets += 1 self._elapsed_steps = 0