From e1ce40085b77bf0c85658b94892b504cc3562932 Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Thu, 7 Nov 2019 15:38:29 -0500 Subject: [PATCH] updated files to new malfunction behavior --- torch_training/multi_agent_inference.py | 11 ++++++----- torch_training/multi_agent_training.py | 13 +++++++------ torch_training/render_agent_behavior.py | 12 ++++++------ torch_training/training_navigation.py | 14 ++++++++------ 4 files changed, 27 insertions(+), 23 deletions(-) diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 1866443..0f9d1e5 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -3,6 +3,7 @@ from collections import deque import numpy as np import torch +from flatland.envs.malfunction_generators import malfunction_from_params from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -36,12 +37,12 @@ n_agents = 10 observation_builder = TreeObsForRailEnv(max_depth=2) # Use a the malfunction generator to break agents from time to time -stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents - 'malfunction_rate': 30, # Rate of malfunction occurence - 'min_duration': 3, # Minimal duration of malfunction - 'max_duration': 20 # Max duration of malfunction +stochastic_data = {'malfunction_rate': 8000, # Rate of malfunction occurence of single agent + 'min_duration': 15, # Minimal duration of malfunction + 'max_duration': 50 # Max duration of malfunction } + # Custom observation builder TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) @@ -61,7 +62,7 @@ env = RailEnv(width=x_dim, max_rails_in_city=2), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=n_agents, - stochastic_data=stochastic_data, # Malfunction data generator + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=TreeObservation) env.reset(True, True) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 2e20c63..93c7eff 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -2,9 +2,11 @@ import getopt import random import sys from collections import deque - # make sure the root path is in system path from pathlib import Path + +from flatland.envs.malfunction_generators import malfunction_from_params + base_dir = Path(__file__).resolve().parent.parent sys.path.append(str(base_dir)) @@ -41,10 +43,9 @@ def main(argv): # Use a the malfunction generator to break agents from time to time - stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents - 'malfunction_rate': 30, # Rate of malfunction occurence - 'min_duration': 3, # Minimal duration of malfunction - 'max_duration': 20 # Max duration of malfunction + stochastic_data = {'malfunction_rate': 8000, # Rate of malfunction occurence of single agent + 'min_duration': 15, # Minimal duration of malfunction + 'max_duration': 50 # Max duration of malfunction } # Custom observation builder @@ -66,7 +67,7 @@ def main(argv): max_rails_in_city=3), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=n_agents, - stochastic_data=stochastic_data, # Malfunction data generator + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=TreeObservation) # After training we want to render the results so we also load a renderer diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index dd81bea..b32e4df 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -3,6 +3,7 @@ from collections import deque import numpy as np import torch +from flatland.envs.malfunction_generators import malfunction_from_params from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator @@ -37,10 +38,9 @@ min_dist = 5 observation_builder = TreeObsForRailEnv(max_depth=2) # Use a the malfunction generator to break agents from time to time -stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents - 'malfunction_rate': 30, # Rate of malfunction occurence - 'min_duration': 3, # Minimal duration of malfunction - 'max_duration': 20 # Max duration of malfunction +stochastic_data = {'malfunction_rate': 80, # Rate of malfunction occurence of single agent + 'min_duration': 15, # Minimal duration of malfunction + 'max_duration': 50 # Max duration of malfunction } # Custom observation builder @@ -59,10 +59,10 @@ env = RailEnv(width=x_dim, seed=1, # Random seed grid_mode=False, max_rails_between_cities=2, - max_rails_in_city=2), + max_rails_in_city=4), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=n_agents, - stochastic_data=stochastic_data, # Malfunction data generator + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=TreeObservation) env.reset() diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index daac62d..fe4905f 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -2,9 +2,11 @@ import getopt import random import sys from collections import deque - # make sure the root path is in system path from pathlib import Path + +from flatland.envs.malfunction_generators import malfunction_from_params + base_dir = Path(__file__).resolve().parent.parent sys.path.append(str(base_dir)) @@ -40,10 +42,9 @@ def main(argv): # Use a the malfunction generator to break agents from time to time - stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents - 'malfunction_rate': 30, # Rate of malfunction occurence - 'min_duration': 3, # Minimal duration of malfunction - 'max_duration': 20 # Max duration of malfunction + stochastic_data = {'malfunction_rate': 8000, # Rate of malfunction occurence of single agent + 'min_duration': 15, # Minimal duration of malfunction + 'max_duration': 50 # Max duration of malfunction } # Custom observation builder @@ -65,7 +66,8 @@ def main(argv): max_rails_in_city=3), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=n_agents, - stochastic_data=stochastic_data, # Malfunction data generator + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), + # Malfunction data generator obs_builder_object=TreeObservation) # After training we want to render the results so we also load a renderer -- GitLab