From ce1386648b4733a7fc6a5649b1bb53f45551d865 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Sat, 10 Aug 2019 14:54:21 -0400 Subject: [PATCH] updated poisson process for malfunction of agents --- examples/debugging_example_DELETE.py | 8 ++- examples/training_example.py | 8 +-- flatland/envs/agent_utils.py | 13 ++--- flatland/envs/rail_env.py | 74 ++++++++++++++-------------- 4 files changed, 53 insertions(+), 50 deletions(-) diff --git a/examples/debugging_example_DELETE.py b/examples/debugging_example_DELETE.py index 8df84833..56148f20 100644 --- a/examples/debugging_example_DELETE.py +++ b/examples/debugging_example_DELETE.py @@ -3,11 +3,8 @@ import time import numpy as np -from flatland.core.env_observation_builder import ObservationBuilder -from flatland.core.grid.grid_utils import coordinate_to_position -from flatland.envs.generators import random_rail_generator, complex_rail_generator +from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv -from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool @@ -77,7 +74,8 @@ for step in range(100): actions[0] = 4 # Halt obs, all_rewards, done, _ = env.step(actions) - print("Agent 0 broken-ness: ", env.agents[0].broken_data['broken']) + if env.agents[0].broken_data['broken'] > 0: + print("Agent 0 broken-ness: ", env.agents[0].broken_data['broken']) env_renderer.render_env(show=True, frames=True, show_observations=False) time.sleep(0.5) diff --git a/examples/training_example.py b/examples/training_example.py index cfed6c92..17484ad7 100644 --- a/examples/training_example.py +++ b/examples/training_example.py @@ -16,9 +16,9 @@ TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictor LocalGridObs = LocalObsForRailEnv(view_height=10, view_width=2, center=2) env = RailEnv(width=50, height=50, - 
rail_generator=complex_rail_generator(nr_start_goal=10, nr_extra=1, min_dist=8, max_dist=99999, seed=0), + rail_generator=complex_rail_generator(nr_start_goal=20, nr_extra=1, min_dist=8, max_dist=99999, seed=0), obs_builder_object=TreeObservation, - number_of_agents=5) + number_of_agents=20) env_renderer = RenderTool(env, gl="PILSVG", ) @@ -75,6 +75,7 @@ for trials in range(1, n_trials + 1): score = 0 # Run episode + mean_malfunction_interval = [] for step in range(100): # Chose an action for each agent in the environment for a in range(env.get_num_agents()): @@ -84,7 +85,7 @@ for trials in range(1, n_trials + 1): # Environment step which returns the observations for all agents, their corresponding # reward and whether their are done next_obs, all_rewards, done, _ = env.step(action_dict) - env_renderer.render_env(show=True, show_observations=True, show_predictions=True) + env_renderer.render_env(show=True, show_observations=False, show_predictions=True) # Update replay buffer and train agent for a in range(env.get_num_agents()): @@ -94,4 +95,5 @@ for trials in range(1, n_trials + 1): obs = next_obs.copy() if done['__all__']: break + print(np.mean(mean_malfunction_interval)) print('Episode Nr. {}\t Score = {}'.format(trials, score)) diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py index 2017d706..27a7a380 100644 --- a/flatland/envs/agent_utils.py +++ b/flatland/envs/agent_utils.py @@ -25,8 +25,8 @@ class EnvAgentStatic(object): # if broken>0, the agent's actions are ignored for 'broken' steps # number of time the agent had to stop, since the last time it broke down - broken_data = attrib( - default=Factory(lambda: dict({'broken': 0, 'number_of_halts': 0}))) + malfunction_data = attrib( + default=Factory(lambda: dict({'malfunction': 0, 'malfunction_rate': 0, 'next_malfunction': 0}))) @classmethod def from_lists(cls, positions, directions, targets, speeds=None): @@ -42,8 +42,9 @@ class EnvAgentStatic(object): # some as broken? 
broken_datas = [] for i in range(len(positions)): - broken_datas.append({'broken': 0, - 'number_of_halts': 0}) + broken_datas.append({'malfunction': 0, + 'malfunction_rate': 0, + 'next_malfunction': 0}) return list(starmap(EnvAgentStatic, zip(positions, directions, @@ -64,7 +65,7 @@ class EnvAgentStatic(object): if type(lTarget) is np.ndarray: lTarget = lTarget.tolist() - return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.broken_data] + return [lPos, int(self.direction), lTarget, int(self.moving), self.speed_data, self.malfunction_data] @attrs @@ -82,7 +83,7 @@ class EnvAgent(EnvAgentStatic): def to_list(self): return [ self.position, self.direction, self.target, self.handle, - self.old_direction, self.old_position, self.moving, self.speed_data, self.broken_data] + self.old_direction, self.old_position, self.moving, self.speed_data, self.malfunction_data] @classmethod def from_static(cls, oStatic): diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py index 0f03870b..cfd8dad4 100644 --- a/flatland/envs/rail_env.py +++ b/flatland/envs/rail_env.py @@ -75,16 +75,13 @@ class RailEnv(Environment): - stop_penalty = 0 # penalty for stopping a moving agent - start_penalty = 0 # penalty for starting a stopped agent - Stochastic breaking of trains: - Trains in RailEnv can break down if they are halted too often (either by their own choice or because an invalid + Stochastic malfunctioning of trains: + Trains in RailEnv can malfunction if they are halted too often (either by their own choice or because an invalid action or cell is selected. - Every time an agent stops, an agent has a certain probability of breaking. The probability is the product of 2 - distributions: the first distribution selects the average number of trains that will break during an episode - (e.g., max(1, 10% of the trains) ). The second distribution is a Poisson distribution with mean set to the average - number of stops at which a train breaks. 
- If a random number in [0,1] is lower than the product of the 2 distributions, the train breaks. - A broken train samples a random number of steps it will stay broken for, during which all its actions are ignored. + Every time an agent stops, an agent has a certain probability of malfunctioning. Malfunctions of trains follow a + Poisson process with a certain rate. Not all trains will be affected by malfunctions during episodes to keep + complexity manageable. TODO: currently, the parameters that control the stochasticity of the environment are hard-coded in init(). For Round 2, they will be passed to the constructor as arguments, to allow for more flexibility. @@ -160,20 +157,20 @@ class RailEnv(Environment): self.action_space = [1] self.observation_space = self.obs_builder.observation_space # updated on resets? + # Stochastic train malfunctioning parameters + self.proportion_malfunctioning_trains = 0.1 # percentage of malfunctioning trains + self.mean_malfunction_rate = 5 # Average malfunction in number of stops + + # Uniform distribution parameters for malfunction duration + self.min_number_of_steps_broken = 4 + self.max_number_of_steps_broken = 10 + + # Reset environment self.reset() self.num_resets = 0 # yes, set it to zero again! 
self.valid_positions = None - # Stochastic train breaking parameters - self.min_average_broken_trains = 1 - self.average_proportion_of_broken_trains = 0.1 # ~10% of the trains can be expected to break down in an episode - self.mean_number_halts_to_break = 3 - - # Uniform distribution - self.min_number_of_steps_broken = 4 - self.max_number_of_steps_broken = 8 - # no more agent_handles def get_agent_handles(self): return range(self.get_num_agents()) @@ -218,9 +215,12 @@ class RailEnv(Environment): for i_agent in range(self.get_num_agents()): agent = self.agents[i_agent] + + # A proportion of agents in the environment will receive a positive malfunction rate + if np.random.random() >= self.proportion_malfunctioning_trains: + agent.malfunction_data['malfunction_rate'] = self.mean_malfunction_rate agent.speed_data['position_fraction'] = 0.0 - agent.broken_data['broken'] = 0 - agent.broken_data['number_of_halts'] = 0 + agent.malfunction_data['malfunction'] = 0 self.num_resets += 1 self._elapsed_steps = 0 @@ -236,24 +236,26 @@ class RailEnv(Environment): return self._get_observations() def _agent_stopped(self, i_agent): - self.agents[i_agent].broken_data['number_of_halts'] += 1 + # Make sure agent is stopped + self.agents[i_agent].moving = False + + # Only agents that have a positive rate for malfunctions are considered + if self.agents[i_agent].malfunction_data['malfunction_rate'] > 0: - def poisson_pdf(x, mean): return np.power(mean, x) * np.exp(-mean) / np.prod(range(2, x)) + # Decrease counter for next event + self.agents[i_agent].malfunction_data['next_malfunction'] -= 1 - p1_prob_train_i_breaks = max(self.min_average_broken_trains / len(self.agents), - self.average_proportion_of_broken_trains) - p2_prob_train_breaks_at_halt_j = poisson_pdf(self.agents[i_agent].broken_data['number_of_halts'], - self.mean_number_halts_to_break) + # If counter has come to zero, set next malfunction time and duration of current malfunction - s1 = np.random.random() - s2 = 
np.random.random() + if self.agents[i_agent].malfunction_data['next_malfunction'] <= 0: + # Next malfunction in number of stops + self.agents[i_agent].malfunction_data['next_malfunction'] = int(np.random.exponential( + scale=self.agents[i_agent].malfunction_data['malfunction_rate'])) - if s1 * s2 <= p1_prob_train_i_breaks * p2_prob_train_breaks_at_halt_j: - # +1 because the counter is decreased at the beginning of step() - num_broken_steps = np.random.randint(self.min_number_of_steps_broken, self.max_number_of_steps_broken+1) + 1 - self.agents[i_agent].broken_data['broken'] = num_broken_steps - self.agents[i_agent].broken_data['number_of_halts'] = 0 + # Duration of current malfunction + num_broken_steps = np.random.randint(self.min_number_of_steps_broken, + self.max_number_of_steps_broken + 1) + 1 + self.agents[i_agent].malfunction_data['malfunction'] = num_broken_steps def step(self, action_dict_): self._elapsed_steps += 1 @@ -284,8 +286,8 @@ class RailEnv(Environment): agent.old_direction = agent.direction agent.old_position = agent.position - if agent.broken_data['broken'] > 0: - agent.broken_data['broken'] -= 1 + if agent.malfunction_data['malfunction'] > 0: + agent.malfunction_data['malfunction'] -= 1 if self.dones[i_agent]: # this agent has already completed... continue @@ -295,7 +297,7 @@ class RailEnv(Environment): action_dict[i_agent] = RailEnvActions.DO_NOTHING # The train is broken - if agent.broken_data['broken'] > 0: + if agent.malfunction_data['malfunction'] > 0: action_dict[i_agent] = RailEnvActions.DO_NOTHING if action_dict[i_agent] < 0 or action_dict[i_agent] > len(RailEnvActions): -- GitLab