From c9dc6e10d7d7880a13662dc1e9bcf3020f6f8004 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Thu, 31 Oct 2019 14:20:36 -0400 Subject: [PATCH] refactored how we call malfuncitons. Now we draw random numbers at each time step for each agent. will this be to expensive? If yes this can easily be refactored --- flatland/envs/rail_env.py | 76 +++++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 34 deletions(-) diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 58faab24..82e49239 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -367,15 +367,17 @@ class RailEnv(Environment): for i_agent in range(self.get_num_agents()): self.set_agent_active(i_agent) - # Induce malfunctions - self._malfunction(self.mean_malfunction_rate) + for agent in self.agents: + # Induce malfunctions + self._break_agent(self.mean_malfunction_rate, agent) + if agent.malfunction_data["malfunction"] > 0: agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING - # Fix agents that finished their malfunciton - self._fix_agents() + # Fix agents that finished their malfunction + self._fix_agent(agent) self.num_resets += 1 self._elapsed_steps = 0 @@ -399,42 +401,46 @@ class RailEnv(Environment): observation_dict: Dict = self._get_observations() return observation_dict, info_dict - def _fix_agents(self): + def _fix_agent(self, agent): """ Updates agent malfunction variables and fixes broken agents - """ - for agent in self.agents: - # Ignore agents that OK - if self._agent_is_ok(agent): - continue + Parameters + ---------- + agent + """ - # Reduce number of malfunction steps left - if agent.malfunction_data['malfunction'] > 1: - agent.malfunction_data['malfunction'] -= 1 - continue + # Ignore agents that are OK + if self._is_ok(agent): + return - # Restart agents at the end of their malfunction + # Reduce number of malfunction steps left + if agent.malfunction_data['malfunction'] > 1: agent.malfunction_data['malfunction'] -= 1 - if 'moving_before_malfunction' in agent.malfunction_data: - agent.moving = agent.malfunction_data['moving_before_malfunction'] - continue + return + # Restart agents at the end of their malfunction + agent.malfunction_data['malfunction'] -= 1 + if 'moving_before_malfunction' in agent.malfunction_data: + agent.moving = agent.malfunction_data['moving_before_malfunction'] + return - def _malfunction(self, rate): + def _break_agent(self, rate, agent): """ - Malfunction generator that breaks agents at a given rate. It does randomly chose agent to break during the run + Malfunction generator that breaks agents at a given rate. - """ - if self.np_random.rand() < self._malfunction_prob(rate, len(self.active_agents)): + Parameters + ---------- + agent - breaking_agent = self._draw_malfunctioning_agent(0) - if breaking_agent: + """ + if agent.malfunction_data['malfunction'] < 1: + if self.np_random.rand() < self._malfunction_prob(rate): num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, self.max_number_of_steps_broken + 1) + 1 - breaking_agent.malfunction_data['malfunction'] = num_broken_steps - breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving - breaking_agent.malfunction_data['nr_malfunctions'] += 1 + agent.malfunction_data['malfunction'] = num_broken_steps + agent.malfunction_data['moving_before_malfunction'] = agent.moving + agent.malfunction_data['nr_malfunctions'] += 1 return @@ -477,13 +483,15 @@ class RailEnv(Environment): } have_all_agents_ended = True # boolean flag to check if all agents are done - # Induce malfunctions - self._malfunction(self.mean_malfunction_rate) + for i_agent, agent in enumerate(self.agents): # Reset the step rewards self.rewards_dict[i_agent] = 0 + # Induce malfunction before we do a step, thus a broken agent can't move in this step + self._break_agent(self.mean_malfunction_rate, agent) + # Perform step on the agent self._step_agent(i_agent, action_dict_.get(i_agent)) @@ -496,8 +504,8 @@ class RailEnv(Environment): info_dict["speed"][i_agent] = agent.speed_data['speed'] info_dict["status"][i_agent] = agent.status - # Fix agents that finished their malfunction - self._fix_agents() + # Fix agents that finished their malfunction such that they can perfom an action in the next step + self._fix_agent(agent) # Check for end of episode + set global reward to all rewards! if have_all_agents_ended: @@ -960,7 +968,7 @@ class RailEnv(Environment): x = - np.log(1 - u) * rate return x - def _malfunction_prob(self, rate, n_agents): + def _malfunction_prob(self, rate): """ Probability that an agent break given the number of agents an the probability of a sinlge agent to break :param rate: @@ -969,7 +977,7 @@ class RailEnv(Environment): if rate <= 0: return 0. else: - return 1 - np.exp(- (1 / rate) * (n_agents)) + return 1 - np.exp(- (1 / rate)) def _draw_malfunctioning_agent(self, tries): """ @@ -1000,7 +1008,7 @@ class RailEnv(Environment): else: return breaking_agent - def _agent_is_ok(self, agent): + def _is_ok(self, agent): """ Check if an agent is ok, meaning it can move and is not malfuncitoinig Parameters -- GitLab