diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 82e4923931e78848e9afe07df00c9dab8d79abb5..c85aaafcc3d2bfafbe407cd616044ea01e973541 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -367,8 +367,6 @@ class RailEnv(Environment): for i_agent in range(self.get_num_agents()): self.set_agent_active(i_agent) - - for agent in self.agents: # Induce malfunctions self._break_agent(self.mean_malfunction_rate, agent) @@ -377,7 +375,7 @@ class RailEnv(Environment): agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING # Fix agents that finished their malfunction - self._fix_agent(agent) + self._fix_agent_after_malfunction(agent) self.num_resets += 1 self._elapsed_steps = 0 @@ -401,7 +399,7 @@ class RailEnv(Environment): observation_dict: Dict = self._get_observations() return observation_dict, info_dict - def _fix_agent(self, agent): + def _fix_agent_after_malfunction(self, agent: EnvAgent): """ Updates agent malfunction variables and fixes broken agents @@ -411,7 +409,7 @@ class RailEnv(Environment): """ # Ignore agents that are OK - if self._is_ok(agent): + if self._is_agent_ok(agent): return # Reduce number of malfunction steps left @@ -425,7 +423,7 @@ class RailEnv(Environment): agent.moving = agent.malfunction_data['moving_before_malfunction'] return - def _break_agent(self, rate, agent): + def _break_agent(self, rate: float, agent) -> bool: """ Malfunction generator that breaks agents at a given rate. @@ -437,13 +435,12 @@ class RailEnv(Environment): if agent.malfunction_data['malfunction'] < 1: if self.np_random.rand() < self._malfunction_prob(rate): num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken, - self.max_number_of_steps_broken + 1) + 1 + self.max_number_of_steps_broken + 1) + 1 agent.malfunction_data['malfunction'] = num_broken_steps agent.malfunction_data['moving_before_malfunction'] = agent.moving agent.malfunction_data['nr_malfunctions'] += 1 return - def step(self, action_dict_: Dict[int, RailEnvActions]): """ Updates rewards for the agents at a step. @@ -483,8 +480,6 @@ class RailEnv(Environment): } have_all_agents_ended = True # boolean flag to check if all agents are done - - for i_agent, agent in enumerate(self.agents): # Reset the step rewards self.rewards_dict[i_agent] = 0 @@ -504,8 +499,8 @@ class RailEnv(Environment): info_dict["speed"][i_agent] = agent.speed_data['speed'] info_dict["status"][i_agent] = agent.status - # Fix agents that finished their malfunction such that they can perfom an action in the next step - self._fix_agent(agent) + # Fix agents that finished their malfunction such that they can perform an action in the next step + self._fix_agent_after_malfunction(agent) # Check for end of episode + set global reward to all rewards! if have_all_agents_ended: @@ -957,7 +952,7 @@ class RailEnv(Environment): load_data = read_binary(package, resource) self.set_full_state_msg(load_data) - def _exp_distirbution_synced(self, rate): + def _exp_distirbution_synced(self, rate: float) -> float: """ Generates sample from exponential distribution We need this to guarantee synchronity between different instances with same seed. @@ -968,9 +963,9 @@ class RailEnv(Environment): x = - np.log(1 - u) * rate return x - def _malfunction_prob(self, rate): + def _malfunction_prob(self, rate: float) -> float: """ - Probability that an agent break given the number of agents an the probability of a sinlge agent to break + Probability of a single agent to break. According to Poisson process with given rate :param rate: :return: """ @@ -979,36 +974,7 @@ class RailEnv(Environment): else: return 1 - np.exp(- (1 / rate)) - def _draw_malfunctioning_agent(self, tries): - """ - Function to determin what agent will be breaking. - It only looks at active and non-broken agents. - After a number of steps it gives up the search after breaking agents and ignores malfunciton - - Parameters - ---------- - tries: How many times we tried to find an agent - - Returns - ------- - agent that is breaking - """ - # Select only from active agents - breaking_agent_idx = self.np_random.choice(self.active_agents) - breaking_agent = self.agents[breaking_agent_idx] - # We assume that at least half of the agents should still be working - if tries > 0.5 * len(self.active_agents): - return None - - # If agent is already broken look for a new one - elif breaking_agent.malfunction_data['malfunction'] > 0: - return self._draw_malfunctioning_agent(tries + 1) - - # Return agent to be broken - else: - return breaking_agent - - def _is_ok(self, agent): + def _is_agent_ok(self, agent: EnvAgent) -> bool: """ Check if an agent is ok, meaning it can move and is not malfuncitoinig Parameters @@ -1021,4 +987,3 @@ class RailEnv(Environment): """ return agent.malfunction_data['malfunction'] < 1 -