From 4efc903c533821c9741b9a9d9ed973f787cf90b9 Mon Sep 17 00:00:00 2001
From: MLErik <baerenjesus@gmail.com>
Date: Tue, 8 Oct 2019 09:34:42 -0400
Subject: [PATCH] added seeding test

---
 tests/test_flatland_malfunction.py |  2 +-
 tests/test_random_seeding.py       | 47 ++++++++++++++++++++++++++++++
 2 files changed, 48 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_random_seeding.py

diff --git a/tests/test_flatland_malfunction.py b/tests/test_flatland_malfunction.py
index c72fc519..a5a46923 100644
--- a/tests/test_flatland_malfunction.py
+++ b/tests/test_flatland_malfunction.py
@@ -151,7 +151,7 @@ def test_malfunction_process_statistically():
                   obs_builder_object=SingleAgentNavigationObs()
                   )
     # reset to initialize agents_static
-    env.reset(False, False, False, random_seed=0)
+    env.reset(True, True, False, random_seed=0)
 
     env.agents[0].target = (0, 0)
     nb_malfunction = 0
diff --git a/tests/test_random_seeding.py b/tests/test_random_seeding.py
new file mode 100644
index 00000000..67c02e86
--- /dev/null
+++ b/tests/test_random_seeding.py
@@ -0,0 +1,47 @@
+import random
+
+import numpy as np
+
+from flatland.core.grid.grid4_utils import get_new_position
+from flatland.envs.rail_env import RailEnv, RailEnvActions
+from flatland.envs.rail_generators import rail_from_grid_transition_map
+from flatland.envs.schedule_generators import random_schedule_generator
+from flatland.utils.simple_rail import make_simple_rail2
+
+
+def test_random_seeding():
+    # Set fixed malfunction duration for this test
+    stochastic_data = {'prop_malfunction': 1.,
+                       'malfunction_rate': 1000,
+                       'min_duration': 3,
+                       'max_duration': 3}
+
+    rail, rail_map = make_simple_rail2()
+
+    env = RailEnv(width=25,
+                  height=30,
+                  rail_generator=rail_from_grid_transition_map(rail),
+                  schedule_generator=random_schedule_generator(),
+                  number_of_agents=1,
+                  stochastic_data=stochastic_data,  # Malfunction data generator
+                  )
+    # reset to initialize agents_static
+    obs, info = env.reset(True, True, False, random_seed=0)
+    env.agents[0].target = (0, 0)
+    assert env.agents[0].initial_position == (3, 3)
+    # Move target to unreachable position in order to not interfere with test
+    for idx in range(2):
+        env.reset(True, True, False, random_seed=0)
+        # Test generation print
+        # print("assert env.agents[0].initial_position == {}".format(env.agents[0].initial_position))
+        env.agents[0].target = (0, 0)
+        assert env.agents[0].initial_position == (3, 3)
+        for step in range(3):
+            actions = {}
+
+            for i in range(len(obs)):
+                actions[i] = np.random.randint(4)
+            env.step(actions)
+        assert env.agents[0].position == (3, 9)
+        # Test generation print
+        # print("assert  env.agents[0].position == {}".format(env.agents[0].position))
-- 
GitLab