From 000c2c19b9660035c0e191b2239559eaf42ee859 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Fri, 25 Oct 2019 17:15:02 -0400
Subject: [PATCH] updated behavior of malfunction to reflect the concerns
 of participants

---
 flatland/envs/rail_env.py | 67 +++++++++++++++++----------------------
 1 file changed, 29 insertions(+), 38 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 7a936201..d5f41491 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -339,8 +339,8 @@ class RailEnv(Environment):
         if agents_hints and 'city_orientations' in agents_hints:
             ratio_nr_agents_to_nr_cities = self.get_num_agents() / len(agents_hints['city_orientations'])
             self._max_episode_steps = self.compute_max_episode_steps(
-                width=self.width, height=self.height,
-                ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities)
+                width=self.width, height=self.height,
+                ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities)
         else:
             self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
 
@@ -397,49 +397,40 @@ class RailEnv(Environment):
         """
         agent = self.agents[i_agent]
 
-        # Skip agents that cannot break
-        # TODO: Make a better malfunction model such that not always the same agents break.
+        # Ignore agents that don't have a positive malfunction rate
         if agent.malfunction_data['malfunction_rate'] < 1:
             return False
 
-        # If agent is currently working and next malfunction time is reached we set it to malfunctioning
+        # Update malfunctioning agents
+        if agent.malfunction_data['malfunction'] > 0:
+            agent.malfunction_data['malfunction'] -= 1
+            return True
+
+        # Restart fixed agents
+        if agent.malfunction_data['malfunction'] < 1 and agent.malfunction_data['next_malfunction'] > 0:
+            if 'moving_before_malfunction' in agent.malfunction_data:
+                self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction']
+            return False
+
+        # Break agents that have next_malfunction
         if 1 > agent.malfunction_data['malfunction'] and agent.malfunction_data['next_malfunction'] < 1:
             # Increase number of malfunctions
             agent.malfunction_data['nr_malfunctions'] += 1
-            # Next malfunction in number of steps
+            # Next malfunction in number of steps
             next_breakdown = int(
                 self._exp_distirbution_synced(rate=agent.malfunction_data['malfunction_rate']))
             agent.malfunction_data['next_malfunction'] = max(next_breakdown, 1)
 
-            # Duration of current malfunction
             num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
-                                                      self.max_number_of_steps_broken + 1)
+                                                      self.max_number_of_steps_broken + 1) + 1
             agent.malfunction_data['malfunction'] = num_broken_steps
 
-            # Remember current moving state of the agent
             agent.malfunction_data['moving_before_malfunction'] = agent.moving
 
             return True
-        else:
-            # The train was broken before...
-            if agent.malfunction_data['malfunction'] > 0:
-
-                # Last step of malfunction --> Agent starts moving again after getting fixed
-                if agent.malfunction_data['malfunction'] < 2:
-                    agent.malfunction_data['malfunction'] -= 1
-
-                    # restore moving state before malfunction without further penalty
-                    self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction']
-
-                else:
-                    agent.malfunction_data['malfunction'] -= 1
-
-                # Nothing left to do with broken agent
-                return True
 
         # Decrease counter for next event only if agent is currently not broken and agent has a malfunction rate
-        if agent.malfunction_data['malfunction_rate'] >= 1 and agent.malfunction_data['next_malfunction'] > 0 and \
-            agent.malfunction_data['malfunction'] < 1:
+        if agent.malfunction_data['next_malfunction'] > 0 and agent.malfunction_data['malfunction'] < 1:
             agent.malfunction_data['next_malfunction'] -= 1
 
         return False
@@ -459,10 +450,10 @@ class RailEnv(Environment):
         if self.dones["__all__"]:
             self.rewards_dict = {}
             info_dict = {
-                "action_required": {},
-                "malfunction": {},
-                "speed": {},
-                "status": {},
+                "action_required" : {},
+                "malfunction" : {},
+                "speed" : {},
+                "status" : {},
             }
             for i_agent, agent in enumerate(self.agents):
                 self.rewards_dict[i_agent] = self.global_reward
@@ -476,12 +467,12 @@ class RailEnv(Environment):
         # Reset the step rewards
         self.rewards_dict = dict()
         info_dict = {
-            "action_required": {},
-            "malfunction": {},
-            "speed": {},
-            "status": {},
+            "action_required" : {},
+            "malfunction" : {},
+            "speed" : {},
+            "status" : {},
         }
-        have_all_agents_ended = True # boolean flag to check if all agents are done
+        have_all_agents_ended = True  # boolean flag to check if all agents are done
         for i_agent, agent in enumerate(self.agents):
             # Reset the step rewards
             self.rewards_dict[i_agent] = 0
@@ -495,8 +486,8 @@ class RailEnv(Environment):
                 # Build info dict
                 info_dict["action_required"][i_agent] = \
                     (agent.status == RailAgentStatus.READY_TO_DEPART or (
-                        agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
-                                                                              rtol=1e-03)))
+                        agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
+                                                                              rtol=1e-03)))
                 info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
                 info_dict["speed"][i_agent] = agent.speed_data['speed']
                 info_dict["status"][i_agent] = agent.status
-- 
GitLab
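
For review purposes, here is a minimal standalone sketch of the malfunction state machine the patched method implements, so the new control flow can be exercised outside the environment. It is an illustration under stated assumptions, not flatland code: the plain-dict agent, the name update_malfunction, and the random.expovariate / random.randint draws are hypothetical stand-ins for flatland's agent objects, _exp_distirbution_synced and np_random; malfunction_rate is read as the mean number of steps between breakdowns, which is what the exponential draw in the patch suggests; and the next_malfunction countdown is folded into the waiting branch instead of sitting in a trailing block.

import random


def update_malfunction(agent, min_broken=2, max_broken=5):
    """Advance one agent's malfunction counters by a single env step.

    Returns True while the agent is malfunctioning, including the step
    on which it breaks down, mirroring the early-return layout above.
    """
    data = agent['malfunction_data']

    # Agents without a positive malfunction rate never break.
    if data['malfunction_rate'] < 1:
        return False

    # Currently broken: just count the remaining malfunction steps down.
    if data['malfunction'] > 0:
        data['malfunction'] -= 1
        return True

    # Fixed and waiting: restore the pre-malfunction moving state once,
    # then count down toward the next breakdown (condensed bookkeeping).
    if data['next_malfunction'] > 0:
        if 'moving_before_malfunction' in data:
            agent['moving'] = data.pop('moving_before_malfunction')
        data['next_malfunction'] -= 1
        return False

    # Due to break: draw the steps until the next event from an exponential
    # distribution, pick a fresh malfunction duration (the patch draws a
    # uniform duration and adds 1), and remember the moving state.
    data['nr_malfunctions'] += 1
    data['next_malfunction'] = max(1, int(random.expovariate(1.0 / data['malfunction_rate'])))
    data['malfunction'] = random.randint(min_broken, max_broken) + 1
    data['moving_before_malfunction'] = agent['moving']
    return True

A toy run, to watch a single agent cycle through waiting, breaking down, and recovering:

agent = {'moving': True,
         'malfunction_data': {'malfunction_rate': 30, 'malfunction': 0,
                              'next_malfunction': 2, 'nr_malfunctions': 0}}
for step in range(8):
    print(step, update_malfunction(agent), dict(agent['malfunction_data']))

Compared to the nested else block the patch deletes, the three early returns make the agent's possible states (broken, just fixed and waiting, due to break) readable at a glance, which appears to be the point of the rewrite.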