From e2cede22a3dec5241a835c24184cb707d20cb1ba Mon Sep 17 00:00:00 2001
From: Erik Nygren <baerenjesus@gmail.com>
Date: Tue, 9 Jun 2020 09:35:27 +0000
Subject: [PATCH] Update multi_agent_training.py

---
 torch_training/multi_agent_training.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 0b5ccc7..68b0c40 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -44,10 +44,11 @@ def main(argv):
 
 
     # Use a the malfunction generator to break agents from time to time
-    stochastic_data = {'malfunction_rate': 8000,  # Rate of malfunction occurence of single agent
-                       'min_duration': 15,  # Minimal duration of malfunction
-                       'max_duration': 50  # Max duration of malfunction
-                       }
+    stochastic_data = MalfunctionParameters(malfunction_rate=10000,  # Rate of malfunction occurrence
+                                            min_duration=15,  # Minimal duration of malfunction
+                                            max_duration=50  # Max duration of malfunction
+                                            )
+
 
     # Custom observation builder
     TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30))
@@ -71,6 +72,8 @@ def main(argv):
                   malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                   obs_builder_object=TreeObservation)
 
+    # Reset env
+    env.reset(True,True)
     # After training we want to render the results so we also load a renderer
     env_renderer = RenderTool(env, gl="PILSVG", )
     # Given the depth of the tree observation and the number of features per node we get the following state_size
-- 
GitLab