From e2cede22a3dec5241a835c24184cb707d20cb1ba Mon Sep 17 00:00:00 2001 From: Erik Nygren <baerenjesus@gmail.com> Date: Tue, 9 Jun 2020 09:35:27 +0000 Subject: [PATCH] Update multi_agent_training.py --- torch_training/multi_agent_training.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 0b5ccc7..68b0c40 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -44,10 +44,11 @@ def main(argv): # Use a the malfunction generator to break agents from time to time - stochastic_data = {'malfunction_rate': 8000, # Rate of malfunction occurence of single agent - 'min_duration': 15, # Minimal duration of malfunction - 'max_duration': 50 # Max duration of malfunction - } + stochastic_data = MalfunctionParameters(malfunction_rate=10000, # Rate of malfunction occurrence + min_duration=15, # Minimal duration of malfunction + max_duration=50 # Max duration of malfunction + ) + # Custom observation builder TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv(30)) @@ -71,6 +72,8 @@ def main(argv): malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=TreeObservation) + # Reset env + env.reset(True,True) # After training we want to render the results so we also load a renderer env_renderer = RenderTool(env, gl="PILSVG", ) # Given the depth of the tree observation and the number of features per node we get the following state_size -- GitLab