From c9dc6e10d7d7880a13662dc1e9bcf3020f6f8004 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Thu, 31 Oct 2019 14:20:36 -0400
Subject: [PATCH] refactored how we call malfuncitons. Now we draw random
 numbers at each time step for each agent. will this be to expensive? If yes
 this can easily be refactored

---
 flatland/envs/rail_env.py | 76 +++++++++++++++++++++------------------
 1 file changed, 42 insertions(+), 34 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 58faab24..82e49239 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -367,15 +367,17 @@ class RailEnv(Environment):
             for i_agent in range(self.get_num_agents()):
                 self.set_agent_active(i_agent)
 
-        # Induce malfunctions
-        self._malfunction(self.mean_malfunction_rate)
+
 
         for agent in self.agents:
+            # Induce malfunctions
+            self._break_agent(self.mean_malfunction_rate, agent)
+
             if agent.malfunction_data["malfunction"] > 0:
                 agent.speed_data['transition_action_on_cellexit'] = RailEnvActions.DO_NOTHING
 
-        # Fix agents that finished their malfunciton
-        self._fix_agents()
+            # Fix agents that finished their malfunction
+            self._fix_agent(agent)
 
         self.num_resets += 1
         self._elapsed_steps = 0
@@ -399,42 +401,46 @@ class RailEnv(Environment):
         observation_dict: Dict = self._get_observations()
         return observation_dict, info_dict
 
-    def _fix_agents(self):
+    def _fix_agent(self, agent):
         """
         Updates agent malfunction variables and fixes broken agents
-        """
-        for agent in self.agents:
 
-            # Ignore agents that OK
-            if self._agent_is_ok(agent):
-                continue
+        Parameters
+        ----------
+        agent
+        """
 
-            # Reduce number of malfunction steps left
-            if agent.malfunction_data['malfunction'] > 1:
-                agent.malfunction_data['malfunction'] -= 1
-                continue
+        # Ignore agents that are OK
+        if self._is_ok(agent):
+            return
 
-            # Restart agents at the end of their malfunction
+        # Reduce number of malfunction steps left
+        if agent.malfunction_data['malfunction'] > 1:
             agent.malfunction_data['malfunction'] -= 1
-            if 'moving_before_malfunction' in agent.malfunction_data:
-                agent.moving = agent.malfunction_data['moving_before_malfunction']
-                continue
+            return
 
+        # Restart agents at the end of their malfunction
+        agent.malfunction_data['malfunction'] -= 1
+        if 'moving_before_malfunction' in agent.malfunction_data:
+            agent.moving = agent.malfunction_data['moving_before_malfunction']
+            return
 
-    def _malfunction(self, rate):
+    def _break_agent(self, rate, agent):
         """
-        Malfunction generator that breaks agents at a given rate. It does randomly chose agent to break during the run
+        Malfunction generator that breaks agents at a given rate.
 
-        """
-        if self.np_random.rand() < self._malfunction_prob(rate, len(self.active_agents)):
+        Parameters
+        ----------
+        agent
 
-            breaking_agent = self._draw_malfunctioning_agent(0)
-            if breaking_agent:
+        """
+        if agent.malfunction_data['malfunction'] < 1:
+            if self.np_random.rand() < self._malfunction_prob(rate):
                 num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
                                                               self.max_number_of_steps_broken + 1) + 1
-                breaking_agent.malfunction_data['malfunction'] = num_broken_steps
-                breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving
-                breaking_agent.malfunction_data['nr_malfunctions'] += 1
+                agent.malfunction_data['malfunction'] = num_broken_steps
+                agent.malfunction_data['moving_before_malfunction'] = agent.moving
+                agent.malfunction_data['nr_malfunctions'] += 1
         return
 
 
@@ -477,13 +483,15 @@ class RailEnv(Environment):
         }
         have_all_agents_ended = True  # boolean flag to check if all agents are done
 
-        # Induce malfunctions
-        self._malfunction(self.mean_malfunction_rate)
+
 
         for i_agent, agent in enumerate(self.agents):
             # Reset the step rewards
             self.rewards_dict[i_agent] = 0
 
+            # Induce malfunction before we do a step, thus a broken agent can't move in this step
+            self._break_agent(self.mean_malfunction_rate, agent)
+
             # Perform step on the agent
             self._step_agent(i_agent, action_dict_.get(i_agent))
 
@@ -496,8 +504,8 @@ class RailEnv(Environment):
             info_dict["speed"][i_agent] = agent.speed_data['speed']
             info_dict["status"][i_agent] = agent.status
 
-        # Fix agents that finished their malfunction
-        self._fix_agents()
+            # Fix agents that finished their malfunction such that they can perfom an action in the next step
+            self._fix_agent(agent)
 
         # Check for end of episode + set global reward to all rewards!
         if have_all_agents_ended:
@@ -960,7 +968,7 @@ class RailEnv(Environment):
         x = - np.log(1 - u) * rate
         return x
 
-    def _malfunction_prob(self, rate, n_agents):
+    def _malfunction_prob(self, rate):
         """
         Probability that an agent break given the number of agents an the probability of a sinlge agent to break
         :param rate:
@@ -969,7 +977,7 @@ class RailEnv(Environment):
         if rate <= 0:
             return 0.
         else:
-            return 1 - np.exp(- (1 / rate) * (n_agents))
+            return 1 - np.exp(- (1 / rate))
 
     def _draw_malfunctioning_agent(self, tries):
         """
@@ -1000,7 +1008,7 @@ class RailEnv(Environment):
         else:
             return breaking_agent
 
-    def _agent_is_ok(self, agent):
+    def _is_ok(self, agent):
         """
         Check if an agent is ok, meaning it can move and is not malfuncitoinig
         Parameters
-- 
GitLab