From d5b16a5210992b38360021ba9d5194535d9b75a3 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Tue, 29 Oct 2019 18:37:24 -0400
Subject: [PATCH] updating tests to new malfunction generation

---
 flatland/envs/agent_utils.py         |  1 +
 flatland/envs/rail_env.py            | 49 ++++++++++++----------------
 flatland/envs/schedule_generators.py | 15 ++++-----
 flatland/envs/schedule_utils.py      |  3 +-
 tests/test_flatland_malfunction.py   | 19 ++---------
 5 files changed, 31 insertions(+), 56 deletions(-)

diff --git a/flatland/envs/agent_utils.py b/flatland/envs/agent_utils.py
index ef2d4855..01ce2908 100644
--- a/flatland/envs/agent_utils.py
+++ b/flatland/envs/agent_utils.py
@@ -63,6 +63,7 @@ class EnvAgentStatic(object):
         for i in range(len(schedule.agent_positions)):
             malfunction_datas.append({'malfunction': 0,
                                       'nr_malfunctions': 0,
+                                      'moving_before_malfunction': False,
                                       'fixed': True})
 
         return list(starmap(EnvAgentStatic, zip(schedule.agent_positions,
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 3e17a4a4..53784a67 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -1,6 +1,7 @@
 """
 Definition of the RailEnv environment.
 """
+import random
 # TODO:  _ this is a global method --> utils or remove later
 from enum import IntEnum
 from typing import List, NamedTuple, Optional, Dict
@@ -8,7 +9,6 @@ from typing import List, NamedTuple, Optional, Dict
 import msgpack
 import msgpack_numpy as m
 import numpy as np
-import random
 from gym.utils import seeding
 
 from flatland.core.env import Environment
@@ -209,7 +209,6 @@ class RailEnv(Environment):
         # Uniform distribution parameters for malfunction duration
         self.min_number_of_steps_broken = malfunction_min_duration
         self.max_number_of_steps_broken = malfunction_max_duration
-        # Reset environment
 
         self.valid_positions = None
 
@@ -331,8 +330,8 @@ class RailEnv(Environment):
             if agents_hints and 'city_orientations' in agents_hints:
                 ratio_nr_agents_to_nr_cities = self.get_num_agents() / len(agents_hints['city_orientations'])
                 self._max_episode_steps = self.compute_max_episode_steps(
-                                                    width=self.width, height=self.height,
-                                                    ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities)
+                    width=self.width, height=self.height,
+                    ratio_nr_agents_to_nr_cities=ratio_nr_agents_to_nr_cities)
             else:
                 self._max_episode_steps = self.compute_max_episode_steps(width=self.width, height=self.height)
 
@@ -394,9 +393,6 @@ class RailEnv(Environment):
             self.agents[i_agent].moving = agent.malfunction_data['moving_before_malfunction']
         return False
 
-
-
-
     def _malfunction(self, rate) -> bool:
         """
         Malfunction generator that breaks agents at a given rate. It does randomly chose agent to break during the run
@@ -404,16 +400,13 @@ class RailEnv(Environment):
         """
         if np.random.random() < self._malfunction_prob(rate):
             breaking_agent = random.choice(self.agents)
-            while breaking_agent.status == RailAgentStatus.DONE_REMOVED:
-                breaking_agent = random.choice(self.agents)
-
-            num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
-                                                      self.max_number_of_steps_broken + 1)
-            breaking_agent.malfunction_data['malfunction'] = num_broken_steps
-            breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving
-            breaking_agent.malfunction_data['fixed'] = False
-
-
+            if breaking_agent.malfunction_data['malfunction'] < 1:
+                num_broken_steps = self.np_random.randint(self.min_number_of_steps_broken,
+                                                          self.max_number_of_steps_broken + 1)
+                breaking_agent.malfunction_data['malfunction'] = num_broken_steps
+                breaking_agent.malfunction_data['moving_before_malfunction'] = breaking_agent.moving
+                breaking_agent.malfunction_data['fixed'] = False
+                breaking_agent.malfunction_data['nr_malfunctions'] += 1
 
     def step(self, action_dict_: Dict[int, RailEnvActions]):
 
@@ -423,10 +416,10 @@ class RailEnv(Environment):
         if self.dones["__all__"]:
             self.rewards_dict = {}
             info_dict = {
-                "action_required" : {},
-                "malfunction" : {},
-                "speed" : {},
-                "status" : {},
+                "action_required": {},
+                "malfunction": {},
+                "speed": {},
+                "status": {},
             }
             for i_agent, agent in enumerate(self.agents):
                 self.rewards_dict[i_agent] = self.global_reward
@@ -440,12 +433,12 @@ class RailEnv(Environment):
         # Reset the step rewards
         self.rewards_dict = dict()
         info_dict = {
-            "action_required" : {},
-            "malfunction" : {},
-            "speed" : {},
-            "status" : {},
+            "action_required": {},
+            "malfunction": {},
+            "speed": {},
+            "status": {},
         }
-        have_all_agents_ended = True # boolean flag to check if all agents are done
+        have_all_agents_ended = True  # boolean flag to check if all agents are done
 
         # Evoke the malfunction generator
         self._malfunction(self.mean_malfunction_rate)
@@ -462,8 +455,8 @@ class RailEnv(Environment):
             # Build info dict
             info_dict["action_required"][i_agent] = \
                 (agent.status == RailAgentStatus.READY_TO_DEPART or (
-                agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
-                                                                        rtol=1e-03)))
+                    agent.status == RailAgentStatus.ACTIVE and np.isclose(agent.speed_data['position_fraction'], 0.0,
+                                                                          rtol=1e-03)))
             info_dict["malfunction"][i_agent] = agent.malfunction_data['malfunction']
             info_dict["speed"][i_agent] = agent.speed_data['speed']
             info_dict["status"][i_agent] = agent.status
diff --git a/flatland/envs/schedule_generators.py b/flatland/envs/schedule_generators.py
index cb8b1537..58a7be34 100644
--- a/flatland/envs/schedule_generators.py
+++ b/flatland/envs/schedule_generators.py
@@ -79,7 +79,7 @@ def complex_schedule_generator(speed_ratio_map: Mapping[float, float] = None, se
             speeds = [1.0] * len(agents_position)
 
         return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
-                        agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None)
+                        agent_targets=agents_target, agent_speeds=speeds)
 
     return generator
 
@@ -165,7 +165,7 @@ def sparse_schedule_generator(speed_ratio_map: Mapping[float, float] = None, see
             speeds = [1.0] * len(agents_position)
 
         return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
-                        agent_targets=agents_target, agent_speeds=speeds, agent_malfunction_rates=None)
+                        agent_targets=agents_target, agent_speeds=speeds)
 
     return generator
 
@@ -199,12 +199,12 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] =
                     valid_positions.append((r, c))
         if len(valid_positions) == 0:
             return Schedule(agent_positions=[], agent_directions=[],
-                            agent_targets=[], agent_speeds=[], agent_malfunction_rates=None)
+                            agent_targets=[], agent_speeds=[])
 
         if len(valid_positions) < num_agents:
             warnings.warn("schedule_generators: len(valid_positions) < num_agents")
             return Schedule(agent_positions=[], agent_directions=[],
-                            agent_targets=[], agent_speeds=[], agent_malfunction_rates=None)
+                            agent_targets=[], agent_speeds=[])
 
         agents_position_idx = [i for i in np.random.choice(len(valid_positions), num_agents, replace=False)]
         agents_position = [valid_positions[agents_position_idx[i]] for i in range(num_agents)]
@@ -263,7 +263,7 @@ def random_schedule_generator(speed_ratio_map: Optional[Mapping[float, float]] =
 
         agents_speed = speed_initialization_helper(num_agents, speed_ratio_map, seed=_runtime_seed)
         return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
-                        agent_targets=agents_target, agent_speeds=agents_speed, agent_malfunction_rates=None)
+                        agent_targets=agents_target, agent_speeds=agents_speed)
 
     return generator
 
@@ -304,12 +304,9 @@ def schedule_from_file(filename, load_from_package=None) -> ScheduleGenerator:
         agents_target = [a.target for a in agents_static]
         if len(data['agents_static'][0]) > 5:
             agents_speed = [a.speed_data['speed'] for a in agents_static]
-            agents_malfunction = [a.malfunction_data['malfunction_rate'] for a in agents_static]
         else:
             agents_speed = None
-            agents_malfunction = None
         return Schedule(agent_positions=agents_position, agent_directions=agents_direction,
-                        agent_targets=agents_target, agent_speeds=agents_speed,
-                        agent_malfunction_rates=agents_malfunction)
+                        agent_targets=agents_target, agent_speeds=agents_speed)
 
     return generator
diff --git a/flatland/envs/schedule_utils.py b/flatland/envs/schedule_utils.py
index e89f170d..c61d2f6b 100644
--- a/flatland/envs/schedule_utils.py
+++ b/flatland/envs/schedule_utils.py
@@ -6,5 +6,4 @@ from flatland.core.grid.grid_utils import IntVector2DArray
 Schedule = NamedTuple('Schedule', [('agent_positions', IntVector2DArray),
                                    ('agent_directions', List[Grid4TransitionsEnum]),
                                    ('agent_targets', IntVector2DArray),
-                                   ('agent_speeds', List[float]),
-                                   ('agent_malfunction_rates', List[int])])
+                                   ('agent_speeds', List[float])])
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index a46467ed..7eac117f 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -66,8 +66,7 @@ class SingleAgentNavigationObs(ObservationBuilder):
 
 def test_malfunction_process():
     # Set fixed malfunction duration for this test
-    stochastic_data = {'prop_malfunction': 1.,
-                       'malfunction_rate': 1000,
+    stochastic_data = {'malfunction_rate': 1,
                        'min_duration': 3,
                        'max_duration': 3}
 
@@ -84,11 +83,6 @@ def test_malfunction_process():
     # reset to initialize agents_static
     obs, info = env.reset(False, False, True, random_seed=10)
 
-    # Check that a initial duration for malfunction was assigned
-    assert env.agents[0].malfunction_data['next_malfunction'] > 0
-    for agent in env.agents:
-        agent.status = RailAgentStatus.ACTIVE
-
     agent_halts = 0
     total_down_time = 0
     agent_old_position = env.agents[0].position
@@ -101,12 +95,6 @@ def test_malfunction_process():
         for i in range(len(obs)):
             actions[i] = np.argmax(obs[i]) + 1
 
-        if step % 5 == 0:
-            # Stop the agent and set it to be malfunctioning
-            env.agents[0].malfunction_data['malfunction'] = -1
-            env.agents[0].malfunction_data['next_malfunction'] = 0
-            agent_halts += 1
-
         obs, all_rewards, done, _ = env.step(actions)
 
         if env.agents[0].malfunction_data['malfunction'] > 0:
@@ -122,12 +110,9 @@ def test_malfunction_process():
         total_down_time += env.agents[0].malfunction_data['malfunction']
 
     # Check that the appropriate number of malfunctions is achieved
-    assert env.agents[0].malfunction_data['nr_malfunctions'] == 20, "Actual {}".format(
+    assert env.agents[0].malfunction_data['nr_malfunctions'] == 30, "Actual {}".format(
         env.agents[0].malfunction_data['nr_malfunctions'])
 
-    # Check that 20 stops where performed
-    assert agent_halts == 20
-
     # Check that malfunctioning data was standing around
     assert total_down_time > 0
 
-- 
GitLab