Skip to content
Snippets Groups Projects
Commit 79c8d7aa authored by Erik Nygren :bullettrain_front:
Browse files

Fixed merge comments by christian baumberger

parent 0dcd1e57
No related branches found
No related tags found
No related merge requests found
...@@ -367,8 +367,6 @@ class RailEnv(Environment): ...@@ -367,8 +367,6 @@ class RailEnv(Environment):
for i_agent in range(self.get_num_agents()): for i_agent in range(self.get_num_agents()):
self.set_agent_active(i_agent) self.set_agent_active(i_agent)
for agent in self.agents: for agent in self.agents:
# Induce malfunctions # Induce malfunctions
self._break_agent(self.mean_malfunction_rate, agent) self._break_agent(self.mean_malfunction_rate, agent)
...@@ -377,7 +375,7 @@ class RailEnv(Environment): ...@@ -377,7 +375,7 @@ class RailEnv(Environment):
agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
# Fix agents that finished their malfunction # Fix agents that finished their malfunction
self._fix_agent(agent) self._fix_agent_after_malfunction(agent)
self.num_resets += 1 self.num_resets += 1
self._elapsed_steps = 0 self._elapsed_steps = 0
...@@ -401,7 +399,7 @@ class RailEnv(Environment): ...@@ -401,7 +399,7 @@ class RailEnv(Environment):
observation_dict: Dict = self._get_observations() observation_dict: Dict = self._get_observations()
return observation_dict, info_dict return observation_dict, info_dict
def _fix_agent(self, agent): def _fix_agent_after_malfunction(self, agent: EnvAgent):
""" """
Updates agent malfunction variables and fixes broken agents Updates agent malfunction variables and fixes broken agents
...@@ -411,7 +409,7 @@ class RailEnv(Environment): ...@@ -411,7 +409,7 @@ class RailEnv(Environment):
""" """
# Ignore agents that are OK # Ignore agents that are OK
if self._is_ok(agent): if self._is_agent_ok(agent):
return return
# Reduce number of malfunction steps left # Reduce number of malfunction steps left
...@@ -425,7 +423,7 @@ class RailEnv(Environment): ...@@ -425,7 +423,7 @@ class RailEnv(Environment):
agent.moving = agent.malfunction_data['moving_before_malfunction'] agent.moving = agent.malfunction_data['moving_before_malfunction']
return return
def _break_agent(self, rate, agent): def _break_agent(self, rate: float, agent) -> bool:
""" """
Malfunction generator that breaks agents at a given rate. Malfunction generator that breaks agents at a given rate.
...@@ -437,13 +435,12 @@ class RailEnv(Environment): ...@@ -437,13 +435,12 @@ class RailEnv(Environment):
if agent.malfunction_data['malfunction'] < 1: if agent.malfunction_data['malfunction'] < 1:
if self.np_random.rand() < self._malfunction_prob(rate): if self.np_random.rand() < self._malfunction_prob(rate):
num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
self.max_number_of_steps_broken + 1) + 1 self.max_number_of_steps_broken + 1) + 1
agent.malfunction_data['malfunction'] = num_broken_steps agent.malfunction_data['malfunction'] = num_broken_steps
agent.malfunction_data['moving_before_malfunction'] = agent.moving agent.malfunction_data['moving_before_malfunction'] = agent.moving
agent.malfunction_data['nr_malfunctions'] += 1 agent.malfunction_data['nr_malfunctions'] += 1
return return
def step(self, action_dict_: Dict[int, RailEnvActions]): def step(self, action_dict_: Dict[int, RailEnvActions]):
""" """
Updates rewards for the agents at a step. Updates rewards for the agents at a step.
...@@ -483,8 +480,6 @@ class RailEnv(Environment): ...@@ -483,8 +480,6 @@ class RailEnv(Environment):
} }
have_all_agents_ended = True # boolean flag to check if all agents are done have_all_agents_ended = True # boolean flag to check if all agents are done
for i_agent, agent in enumerate(self.agents): for i_agent, agent in enumerate(self.agents):
# Reset the step rewards # Reset the step rewards
self.rewards_dict[i_agent] = 0 self.rewards_dict[i_agent] = 0
...@@ -504,8 +499,8 @@ class RailEnv(Environment): ...@@ -504,8 +499,8 @@ class RailEnv(Environment):
info_dict["speed"][i_agent] = agent.speed_data['speed'] info_dict["speed"][i_agent] = agent.speed_data['speed']
info_dict["status"][i_agent] = agent.status info_dict["status"][i_agent] = agent.status
# Fix agents that finished their malfunction such that they can perfom an action in the next step # Fix agents that finished their malfunction such that they can perform an action in the next step
self._fix_agent(agent) self._fix_agent_after_malfunction(agent)
# Check for end of episode + set global reward to all rewards! # Check for end of episode + set global reward to all rewards!
if have_all_agents_ended: if have_all_agents_ended:
...@@ -957,7 +952,7 @@ class RailEnv(Environment): ...@@ -957,7 +952,7 @@ class RailEnv(Environment):
load_data = read_binary(package, resource) load_data = read_binary(package, resource)
self.set_full_state_msg(load_data) self.set_full_state_msg(load_data)
def _exp_distirbution_synced(self, rate): def _exp_distirbution_synced(self, rate: float) -> float:
""" """
Generates sample from exponential distribution Generates sample from exponential distribution
We need this to guarantee synchronity between different instances with same seed. We need this to guarantee synchronity between different instances with same seed.
...@@ -968,9 +963,9 @@ class RailEnv(Environment): ...@@ -968,9 +963,9 @@ class RailEnv(Environment):
x = - np.log(1 - u) * rate x = - np.log(1 - u) * rate
return x return x
def _malfunction_prob(self, rate): def _malfunction_prob(self, rate: float) -> float:
""" """
Probability that an agent break given the number of agents an the probability of a sinlge agent to break Probability of a single agent to break. According to Poisson process with given rate
:param rate: :param rate:
:return: :return:
""" """
...@@ -979,36 +974,7 @@ class RailEnv(Environment): ...@@ -979,36 +974,7 @@ class RailEnv(Environment):
else: else:
return 1 - np.exp(- (1 / rate)) return 1 - np.exp(- (1 / rate))
def _draw_malfunctioning_agent(self, tries): def _is_agent_ok(self, agent: EnvAgent) -> bool:
"""
Function to determin what agent will be breaking.
It only looks at active and non-broken agents.
After a number of steps it gives up the search after breaking agents and ignores malfunciton
Parameters
----------
tries: How many times we tried to find an agent
Returns
-------
agent that is breaking
"""
# Select only from active agents
breaking_agent_idx = self.np_random.choice(self.active_agents)
breaking_agent = self.agents[breaking_agent_idx]
# We assume that at least half of the agents should still be working
if tries > 0.5 * len(self.active_agents):
return None
# If agent is already broken look for a new one
elif breaking_agent.malfunction_data['malfunction'] > 0:
return self._draw_malfunctioning_agent(tries + 1)
# Return agent to be broken
else:
return breaking_agent
def _is_ok(self, agent):
""" """
Check if an agent is ok, meaning it can move and is not malfuncitoinig Check if an agent is ok, meaning it can move and is not malfuncitoinig
Parameters Parameters
...@@ -1021,4 +987,3 @@ class RailEnv(Environment): ...@@ -1021,4 +987,3 @@ class RailEnv(Environment):
""" """
return agent.malfunction_data['malfunction'] < 1 return agent.malfunction_data['malfunction'] < 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment