diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index b32e4dfbcf16bf298292bdc5f890ce043c52e14a..82706a4d2e22df7b11f03207d0b7d6aac891a89c 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -38,10 +38,10 @@ min_dist = 5 observation_builder = TreeObsForRailEnv(max_depth=2) # Use a the malfunction generator to break agents from time to time -stochastic_data = {'malfunction_rate': 80, # Rate of malfunction occurence of single agent - 'min_duration': 15, # Minimal duration of malfunction - 'max_duration': 50 # Max duration of malfunction - } +stochastic_data = MalfunctionParameters(malfunction_rate=10000, # Rate of malfunction occurence + min_duration=15, # Minimal duration of malfunction + max_duration=50 # Max duration of malfunction + ) # Custom observation builder TreeObservation = TreeObsForRailEnv(max_depth=2) @@ -64,7 +64,7 @@ env = RailEnv(width=x_dim, number_of_agents=n_agents, malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=TreeObservation) -env.reset() +env.reset(True,True) env_renderer = RenderTool(env, gl="PILSVG", ) num_features_per_node = env.obs_builder.observation_dim