From a165ac00c1cbc94219574c893df5e942cba6ef2e Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Tue, 8 Oct 2019 09:04:27 -0400
Subject: [PATCH] fixed first tests in malfunction test

---
 examples/flatland_2_0_example.py   |  2 +-
 flatland/envs/distance_map.py      |  1 +
 flatland/envs/rail_env.py          |  5 ++---
 tests/test_flatland_malfunction.py | 17 ++++++-----------
 4 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/examples/flatland_2_0_example.py b/examples/flatland_2_0_example.py
index b9ace9f5..5ece03e9 100644
--- a/examples/flatland_2_0_example.py
+++ b/examples/flatland_2_0_example.py
@@ -37,7 +37,7 @@ env = RailEnv(width=100,
                                                    seed=14,  # Random seed
                                                    grid_mode=False,
                                                    max_rails_between_cities=2,
-                                                   max_rails_in_city=6,
+                                                   max_rails_in_city=8,
                                                    ),
               schedule_generator=sparse_schedule_generator(speed_ration_map),
               number_of_agents=100,
diff --git a/flatland/envs/distance_map.py b/flatland/envs/distance_map.py
index 2bc1a511..c6e73b0b 100644
--- a/flatland/envs/distance_map.py
+++ b/flatland/envs/distance_map.py
@@ -55,6 +55,7 @@ class DistanceMap:
         self.env_width = rail.width
 
     def _compute(self, agents: List[EnvAgent], rail: GridTransitionMap):
+        print("computing distance map")
         self.agents_previous_computation = self.agents
         self.distance_map = np.inf * np.ones(shape=(len(agents),
                                                     self.env_height,
diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 464d34bd..df0b8848 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -9,7 +9,6 @@ from typing import List, NamedTuple, Optional, Dict
 import msgpack
 import msgpack_numpy as m
 import numpy as np
-
 from gym.utils import seeding
 
 from flatland.core.env import Environment
@@ -187,7 +186,7 @@ class RailEnv(Environment):
         self.distance_map = DistanceMap(self.agents, self.height, self.width)
 
         self.action_space = [1]
-        
+
         # Stochastic train malfunctioning parameters
         if stochastic_data is not None:
             prop_malfunction = stochastic_data['prop_malfunction']
@@ -466,7 +465,7 @@ class RailEnv(Environment):
             return
 
         # Is the agent at the beginning of the cell? Then, it can take an action.
-        # As long as the agent is malfunctioning or stopped at the beginning of the cell, 
+        # As long as the agent is malfunctioning or stopped at the beginning of the cell,
         # different actions may be taken!
         if agent.speed_data['position_fraction'] == 0.0:
             # No action has been supplied for this agent -> set DO_NOTHING as default
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index e87bd93d..c72fc519 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -70,14 +70,6 @@ def test_malfunction_process():
                        'malfunction_rate': 1000,
                        'min_duration': 3,
                        'max_duration': 3}
-    # random.seed(0)
-    # np.random.seed(0)
-
-    stochastic_data = {'prop_malfunction': 1.,  # Percentage of defective agents
-                       'malfunction_rate': 70,  # Rate of malfunction occurence
-                       'min_duration': 2,  # Minimal duration of malfunction
-                       'max_duration': 5  # Max duration of malfunction
-                       }
 
     rail, rail_map = make_simple_rail2()
 
@@ -90,8 +82,8 @@ def test_malfunction_process():
                   obs_builder_object=SingleAgentNavigationObs()
                   )
     # reset to initialize agents_static
-    obs = env.reset(False, False, True, random_seed=0)
-
+    obs, info = env.reset(False, False, True, random_seed=0)
+    print(env.agents[0].malfunction_data)
     # Check that a initial duration for malfunction was assigned
     assert env.agents[0].malfunction_data['next_malfunction'] > 0
     for agent in env.agents:
@@ -100,6 +92,9 @@ def test_malfunction_process():
     agent_halts = 0
     total_down_time = 0
     agent_old_position = env.agents[0].position
+
+    # Move target to unreachable position in order to not interfere with test
+    env.agents[0].target = (0, 0)
     for step in range(100):
         actions = {}
 
@@ -157,6 +152,7 @@ def test_malfunction_process_statistically():
                   )
     # reset to initialize agents_static
     env.reset(False, False, False, random_seed=0)
+
     env.agents[0].target = (0, 0)
     nb_malfunction = 0
     for step in range(20):
@@ -166,7 +162,6 @@ def test_malfunction_process_statistically():
             action_dict[agent.handle] = RailEnvActions(np.random.randint(4))
 
         env.step(action_dict)
-
     # check that generation of malfunctions works as expected
     assert env.agents[0].malfunction_data["nr_malfunctions"] == 4
 
-- 
GitLab