Skip to content
Snippets Groups Projects
Commit 672f0f3c authored by spiglerg's avatar spiglerg
Browse files

first push stochastic breaking -- added flags to agents

parent a447381f
No related branches found
No related tags found
No related merge requests found
...@@ -15,6 +15,7 @@ class EnvAgentStatic(object): ...@@ -15,6 +15,7 @@ class EnvAgentStatic(object):
direction = attrib() direction = attrib()
target = attrib() target = attrib()
moving = attrib(default=False) moving = attrib(default=False)
# speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0, # speed_data: speed is added to position_fraction on each moving step, until position_fraction>=1.0,
# after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous # after which 'transition_action_on_cellexit' is executed (equivalent to executing that action in the previous
# cell if speed=1, as default) # cell if speed=1, as default)
...@@ -22,6 +23,11 @@ class EnvAgentStatic(object): ...@@ -22,6 +23,11 @@ class EnvAgentStatic(object):
speed_data = attrib( speed_data = attrib(
default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0}))) default=Factory(lambda: dict({'position_fraction': 0.0, 'speed': 1.0, 'transition_action_on_cellexit': 0})))
# if broken>0, the agent's actions are ignored for 'broken' steps
# number of time the agent had to stop, since the last time it broke down
broken_data = attrib(
default=Factory(lambda: dict({'broken': 0, 'number_of_halts': 0})))
@classmethod @classmethod
def from_lists(cls, positions, directions, targets, speeds=None): def from_lists(cls, positions, directions, targets, speeds=None):
""" Create a list of EnvAgentStatics from lists of positions, directions and targets """ Create a list of EnvAgentStatics from lists of positions, directions and targets
...@@ -31,7 +37,14 @@ class EnvAgentStatic(object): ...@@ -31,7 +37,14 @@ class EnvAgentStatic(object):
speed_datas.append({'position_fraction': 0.0, speed_datas.append({'position_fraction': 0.0,
'speed': speeds[i] if speeds is not None else 1.0, 'speed': speeds[i] if speeds is not None else 1.0,
'transition_action_on_cellexit': 0}) 'transition_action_on_cellexit': 0})
return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas)))
# TODO: on initialization, all agents are re-set as non-broken. Perhaps it may be desirable to set some as broken?
broken_datas = []
for i in range(len(positions)):
broken_datas.append({'broken': 0,
'number_of_halts': 0})
return list(starmap(EnvAgentStatic, zip(positions, directions, targets, [False] * len(positions), speed_datas, broken_datas)))
def to_list(self): def to_list(self):
...@@ -45,7 +58,7 @@ class EnvAgentStatic(object): ...@@ -45,7 +58,7 @@ class EnvAgentStatic(object):
if type(lTarget) is np.ndarray: if type(lTarget) is np.ndarray:
lTarget = lTarget.tolist() lTarget = lTarget.tolist()
return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data] return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.broken_data]
@attrs @attrs
...@@ -63,7 +76,7 @@ class EnvAgent(EnvAgentStatic): ...@@ -63,7 +76,7 @@ class EnvAgent(EnvAgentStatic):
def to_list(self): def to_list(self):
return [ return [
self.position, self.direction, self.target, self.handle, self.position, self.direction, self.target, self.handle,
self.old_direction, self.old_position, self.moving, self.speed_data] self.old_direction, self.old_position, self.moving, self.speed_data, self.broken_data]
@classmethod @classmethod
def from_static(cls, oStatic): def from_static(cls, oStatic):
......
...@@ -196,6 +196,8 @@ class RailEnv(Environment): ...@@ -196,6 +196,8 @@ class RailEnv(Environment):
for i_agent in range(self.get_num_agents()): for i_agent in range(self.get_num_agents()):
agent = self.agents[i_agent] agent = self.agents[i_agent]
agent.speed_data['position_fraction'] = 0.0 agent.speed_data['position_fraction'] = 0.0
agent.broken_data['broken'] = 0
agent.broken_data['number_of_halts'] = 0
self.num_resets += 1 self.num_resets += 1
self._elapsed_steps = 0 self._elapsed_steps = 0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment