Skip to content
Snippets Groups Projects
Commit c4456f93 authored by Erik Nygren's avatar Erik Nygren :bullettrain_front:
Browse files

remodeled the way malfunctions work: Now they are agent independent

parent 94de85fa
No related branches found
No related tags found
No related merge requests found
...@@ -59,7 +59,7 @@ schedule_generator = sparse_schedule_generator(speed_ration_map) ...@@ -59,7 +59,7 @@ schedule_generator = sparse_schedule_generator(speed_ration_map)
# during an episode. # during an episode.
stochastic_data = {'prop_malfunction': 0.3, # Percentage of defective agents stochastic_data = {'prop_malfunction': 0.3, # Percentage of defective agents
'malfunction_rate': 30, # Rate of malfunction occurence 'malfunction_rate': 50, # Rate of malfunction occurence
'min_duration': 3, # Minimal duration of malfunction 'min_duration': 3, # Minimal duration of malfunction
'max_duration': 20 # Max duration of malfunction 'max_duration': 20 # Max duration of malfunction
} }
...@@ -204,9 +204,8 @@ print("========================================") ...@@ -204,9 +204,8 @@ print("========================================")
for agent_idx, agent in enumerate(env.agents): for agent_idx, agent in enumerate(env.agents):
print( print(
"Agent {} will malfunction = {} at a rate of {}, the next malfunction will occur in {} step. Agent OK = {}".format( "Agent {} is OK = {}".format(
agent_idx, agent.malfunction_data['malfunction_rate'] > 0, agent.malfunction_data['malfunction_rate'], agent_idx, agent.malfunction_data['malfunction'] < 1))
agent.malfunction_data['next_malfunction'], agent.malfunction_data['malfunction'] < 1))
# Now that you have seen these novel concepts that were introduced you will realize that agents don't need to take # Now that you have seen these novel concepts that were introduced you will realize that agents don't need to take
# an action at every time step as it will only change the outcome when actions are chosen at cell entry. # an action at every time step as it will only change the outcome when actions are chosen at cell entry.
......
...@@ -39,8 +39,8 @@ class EnvAgentStatic(object): ...@@ -39,8 +39,8 @@ class EnvAgentStatic(object):
# number of time the agent had to stop, since the last time it broke down # number of time the agent had to stop, since the last time it broke down
malfunction_data = attrib( malfunction_data = attrib(
default=Factory( default=Factory(
lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0, 'nr_malfunctions': 0, lambda: dict({'malfunction': 0, 'nr_malfunctions': 0,
'moving_before_malfunction': False, 'fixed': False}))) 'moving_before_malfunction': False, 'fixed': True})))
status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus) status = attrib(default=RailAgentStatus.READY_TO_DEPART, type=RailAgentStatus)
position = attrib(default=None, type=Optional[Tuple[int, int]]) position = attrib(default=None, type=Optional[Tuple[int, int]])
...@@ -62,11 +62,8 @@ class EnvAgentStatic(object): ...@@ -62,11 +62,8 @@ class EnvAgentStatic(object):
malfunction_datas = [] malfunction_datas = []
for i in range(len(schedule.agent_positions)): for i in range(len(schedule.agent_positions)):
malfunction_datas.append({'malfunction': 0, malfunction_datas.append({'malfunction': 0,
'malfunction_rate': schedule.agent_malfunction_rates[
i] if schedule.agent_malfunction_rates is not None else 0.,
'next_malfunction': 0,
'nr_malfunctions': 0, 'nr_malfunctions': 0,
'fixed': False}) 'fixed': True})
return list(starmap(EnvAgentStatic, zip(schedule.agent_positions, return list(starmap(EnvAgentStatic, zip(schedule.agent_positions,
schedule.agent_directions, schedule.agent_directions,
......
...@@ -8,6 +8,7 @@ from typing import List, NamedTuple, Optional, Dict ...@@ -8,6 +8,7 @@ from typing import List, NamedTuple, Optional, Dict
import msgpack import msgpack
import msgpack_numpy as m import msgpack_numpy as m
import numpy as np import numpy as np
import random
from gym.utils import seeding from gym.utils import seeding
from flatland.core.env import Environment from flatland.core.env import Environment
...@@ -194,20 +195,15 @@ class RailEnv(Environment): ...@@ -194,20 +195,15 @@ class RailEnv(Environment):
# Stochastic train malfunctioning parameters # Stochastic train malfunctioning parameters
if stochastic_data is not None: if stochastic_data is not None:
prop_malfunction = stochastic_data['prop_malfunction']
mean_malfunction_rate = stochastic_data['malfunction_rate'] mean_malfunction_rate = stochastic_data['malfunction_rate']
malfunction_min_duration = stochastic_data['min_duration'] malfunction_min_duration = stochastic_data['min_duration']
malfunction_max_duration = stochastic_data['max_duration'] malfunction_max_duration = stochastic_data['max_duration']
else: else:
prop_malfunction = 0.
mean_malfunction_rate = 0. mean_malfunction_rate = 0.
malfunction_min_duration = 0. malfunction_min_duration = 0.
malfunction_max_duration = 0. malfunction_max_duration = 0.
# percentage of malfunctioning trains # Mean malfunction in number of time steps
self.proportion_malfunctioning_trains = prop_malfunction
# Mean malfunction in number of stops
self.mean_malfunction_rate = mean_malfunction_rate self.mean_malfunction_rate = mean_malfunction_rate
# Uniform distribution parameters for malfunction duration # Uniform distribution parameters for malfunction duration
...@@ -219,6 +215,7 @@ class RailEnv(Environment): ...@@ -219,6 +215,7 @@ class RailEnv(Environment):
def _seed(self, seed=None): def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed) self.np_random, seed = seeding.np_random(seed)
random.seed(seed)
return [seed] return [seed]
# no more agent_handles # no more agent_handles
...@@ -344,16 +341,8 @@ class RailEnv(Environment): ...@@ -344,16 +341,8 @@ class RailEnv(Environment):
if activate_agents: if activate_agents:
for i_agent in range(self.get_num_agents()): for i_agent in range(self.get_num_agents()):
self.set_agent_active(i_agent) self.set_agent_active(i_agent)
self._malfunction(self.mean_malfunction_rate)
for i_agent, agent in enumerate(self.agents): for i_agent, agent in enumerate(self.agents):
# A proportion of agent in the environment will receive a positive malfunction rate
if self.np_random.rand() < self.proportion_malfunctioning_trains:
agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate
next_breakdown = int(
self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate']))
agent.malfunction_data['next_malfunction'] = next_breakdown
agent.malfunction_data['malfunction'] = 0
initial_malfunction = self._agent_malfunction(i_agent) initial_malfunction = self._agent_malfunction(i_agent)
if initial_malfunction: if initial_malfunction:
...@@ -390,45 +379,39 @@ class RailEnv(Environment): ...@@ -390,45 +379,39 @@ class RailEnv(Environment):
""" """
agent = self.agents[i_agent] agent = self.agents[i_agent]
# Ignore agents that dont have positive malfunction rate # Reduce number of malfunction steps left
if agent.malfunction_data['malfunction_rate'] < 1:
return False
# Update malfunctioning agents
if agent.malfunction_data['malfunction'] > 0: if agent.malfunction_data['malfunction'] > 0:
agent.malfunction_data['malfunction'] -= 1 agent.malfunction_data['malfunction'] -= 1
return True return True
if agent.malfunction_data['malfunction'] < 1 and agent.malfunction_data['next_malfunction'] > 0: # Ignore agents that OK
# Restart fixed agents if agent.malfunction_data['fixed']:
if not agent.malfunction_data['fixed']: return False
agent.malfunction_data['next_malfunction'] -= 1
agent.malfunction_data['fixed'] = True # Restart agents at the end of their malfunction
if 'moving_before_malfunction' in agent.malfunction_data: agent.malfunction_data['fixed'] = True
self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction'] if 'moving_before_malfunction' in agent.malfunction_data:
return False self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction']
else: return False
# Agent has been running smoothly
agent.malfunction_data['next_malfunction'] -= 1
return False
# Break agents that have next_malfunction def _malfunction(self, rate) -> bool:
if agent.malfunction_data['malfunction'] < 1 and agent.malfunction_data['next_malfunction'] < 1: """
# Increase number of malfunctions Malfunction generator that breaks agents at a given rate. It does randomly chose agent to break during the run
agent.malfunction_data['nr_malfunctions'] += 1
agent.malfunction_data['fixed'] = False """
if np.random.random() < self._malfunction_prob(rate):
# Next malfunction in number of stops breaking_agent = random.choice(self.agents)
next_breakdown = int( while breaking_agent.status == RailAgentStatus.DONE_REMOVED:
self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate'])) breaking_agent = random.choice(self.agents)
agent.malfunction_data['next_malfunction'] = max(next_breakdown, 1)
# Duration of current malfunction
num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1) self.max_number_of_steps_broken + 1)
agent.malfunction_data['malfunction'] = num_broken_steps breaking_agent.malfunction_data['malfunction'] = num_broken_steps
agent.malfunction_data['moving_before_malfunction'] = agent.moving breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving
breaking_agent.malfunction_data['fixed'] = False
return True
...@@ -463,6 +446,9 @@ class RailEnv(Environment): ...@@ -463,6 +446,9 @@ class RailEnv(Environment):
"status" : {}, "status" : {},
} }
have_all_agents_ended = True # boolean flag to check if all agents are done have_all_agents_ended = True # boolean flag to check if all agents are done
# Evoke the malfunction generator
self._malfunction(self.mean_malfunction_rate)
for i_agent, agent in enumerate(self.agents): for i_agent, agent in enumerate(self.agents):
# Reset the step rewards # Reset the step rewards
self.rewards_dict[i_agent] = 0 self.rewards_dict[i_agent] = 0
...@@ -824,3 +810,14 @@ class RailEnv(Environment): ...@@ -824,3 +810,14 @@ class RailEnv(Environment):
u = self.np_random.rand() u = self.np_random.rand()
x = - np.log(1 - u) * rate x = - np.log(1 - u) * rate
return x return x
def _malfunction_prob(self, rate):
"""
Gives the cummulative exponential distribution at point x, with exp decay rate
:param rate:
:return:
"""
if rate <= 0:
return 0.
else:
return 1 - np.exp(-(1 / rate))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment