From 722b0928d9d1f9d29cf36e2e004e74e3b6ccacb5 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Thu, 19 Sep 2019 16:43:29 +0200
Subject: [PATCH] #168 #163 multispeed and penalty testing

---
 flatland/envs/rail_env.py          |  2 --
 tests/test_flatland_malfunction.py |  8 ++++----
 tests/test_multi_speed.py          | 19 ++++---------------
 tests/test_utils.py                |  8 ++++++++
 4 files changed, 16 insertions(+), 21 deletions(-)

diff --git a/flatland/envs/rail_env.py b/flatland/envs/rail_env.py
index 0467fcd6..435571c8 100644
--- a/flatland/envs/rail_env.py
+++ b/flatland/envs/rail_env.py
@@ -360,8 +360,6 @@ class RailEnv(Environment):
 
         # Perform step on all agents
         for i_agent in range(self.get_num_agents()):
-            if self._elapsed_steps - 1 == 3 and i_agent == 0:
-                a = 5
             self._step_agent(i_agent, action_dict_.get(i_agent))
 
         # Check for end of episode + set global reward to all rewards!
diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index 473629e2..801c5062 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -9,7 +9,7 @@ from flatland.envs.observations import TreeObsForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, sparse_rail_generator
 from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
-from test_utils import Replay, ReplayConfig, run_replay_config
+from test_utils import Replay, ReplayConfig, run_replay_config, set_penalties_for_replay
 
 
 class SingleAgentNavigationObs(TreeObsForRailEnv):
@@ -187,7 +187,7 @@ def test_initial_malfunction():
                   number_of_agents=1,
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   )
-
+    set_penalties_for_replay(env)
     replay_config = ReplayConfig(
         replay=[
             Replay(
@@ -270,7 +270,7 @@ def test_initial_malfunction_stop_moving():
                   number_of_agents=1,
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   )
-
+    set_penalties_for_replay(env)
     replay_config = ReplayConfig(
         replay=[
             Replay(
@@ -363,7 +363,7 @@ def test_initial_malfunction_do_nothing():
                   number_of_agents=1,
                   stochastic_data=stochastic_data,  # Malfunction data generator
                   )
-
+    set_penalties_for_replay(env)
     replay_config = ReplayConfig(
         replay=[Replay(
             position=(28, 5),
diff --git a/tests/test_multi_speed.py b/tests/test_multi_speed.py
index f2fd3613..57c0430d 100644
--- a/tests/test_multi_speed.py
+++ b/tests/test_multi_speed.py
@@ -6,9 +6,8 @@ from flatland.envs.predictions import ShortestPathPredictorForRailEnv
 from flatland.envs.rail_env import RailEnv, RailEnvActions
 from flatland.envs.rail_generators import complex_rail_generator, rail_from_grid_transition_map
 from flatland.envs.schedule_generators import complex_schedule_generator, random_schedule_generator
-from flatland.utils.rendertools import RenderTool
 from flatland.utils.simple_rail import make_simple_rail
-from test_utils import ReplayConfig, Replay, run_replay_config
+from test_utils import ReplayConfig, Replay, run_replay_config, set_penalties_for_replay
 
 np.random.seed(1)
 
@@ -106,9 +105,7 @@ def test_multispeed_actions_no_malfunction_no_blocking(rendering=True):
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
 
-    # env.start_penalty = 13
-    # env.stop_penalty = 19
-
+    set_penalties_for_replay(env)
     test_config = ReplayConfig(
         replay=[
             Replay(
@@ -208,7 +205,7 @@ def test_multispeed_actions_no_malfunction_blocking(rendering=True):
                   number_of_agents=2,
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
-
+    set_penalties_for_replay(env)
     test_configs = [
         ReplayConfig(
             replay=[
@@ -391,15 +388,7 @@ def test_multispeed_actions_malfunction_no_blocking(rendering=True):
                   obs_builder_object=TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()),
                   )
 
-    # initialize agents_static
-    env.reset()
-
-    # reset to set agents from agents_static
-    env.reset(False, False)
-
-    if rendering:
-        renderer = RenderTool(env, gl="PILSVG")
-
+    set_penalties_for_replay(env)
     test_config = ReplayConfig(
         replay=[
             Replay(
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 88d669fa..5821cd05 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -26,6 +26,13 @@ class ReplayConfig(object):
     speed = attrib(type=float)
 
 
+# Replay helper: set non-zero step/stop/invalid-action penalties (NOTE(review): start_penalty not set - intended?)
+def set_penalties_for_replay(env: RailEnv):
+    env.step_penalty = 13
+    env.stop_penalty = 19
+    env.invalid_action_penalty = 29
+
+
 def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering: bool = False):
     """
     Runs the replay configs and checks assertions.
@@ -54,6 +61,7 @@ def run_replay_config(env: RailEnv, test_configs: List[ReplayConfig], rendering:
     info_dict = {
         'action_required': [True for _ in test_configs]
     }
+
     for step in range(len(test_configs[0].replay)):
         if step == 0:
             for a, test_config in enumerate(test_configs):
-- 
GitLab