diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 18664437ebef0dbf4261cfdb3ba692dd5fab7505..0f9d1e5da7cfb35ca16f4b243f9eb0013f3e824b 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -3,6 +3,7 @@ from collections import deque import numpy as np import torch +from flatland.envs.malfunction_generators import malfunction_from_params from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv @@ -36,12 +37,12 @@ n_agents = 10 observation_builder = TreeObsForRailEnv(max_depth=2) # Use a the malfunction generator to break agents from time to time -stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents - 'malfunction_rate': 30, # Rate of malfunction occurence - 'min_duration': 3, # Minimal duration of malfunction - 'max_duration': 20 # Max duration of malfunction +stochastic_data = {'malfunction_rate': 8000, # Rate of malfunction occurence of single agent + 'min_duration': 15, # Minimal duration of malfunction + 'max_duration': 50 # Max duration of malfunction } + # Custom observation builder TreeObservation = TreeObsForRailEnv(max_depth=2, predictor=ShortestPathPredictorForRailEnv()) @@ -61,7 +62,7 @@ env = RailEnv(width=x_dim, max_rails_in_city=2), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=n_agents, - stochastic_data=stochastic_data, # Malfunction data generator + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=TreeObservation) env.reset(True, True) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 2e20c63b293b355afec2be33cbd9acca209039d4..93c7eff6ed7ea82fbe8cb4b16cbb1102b36f0b38 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -2,9 +2,11 @@ import getopt import random import sys from collections import deque - # make sure the root path is in system path from pathlib import Path + +from flatland.envs.malfunction_generators import malfunction_from_params + base_dir = Path(__file__).resolve().parent.parent sys.path.append(str(base_dir)) @@ -41,10 +43,9 @@ def main(argv): # Use a the malfunction generator to break agents from time to time - stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents - 'malfunction_rate': 30, # Rate of malfunction occurence - 'min_duration': 3, # Minimal duration of malfunction - 'max_duration': 20 # Max duration of malfunction + stochastic_data = {'malfunction_rate': 8000, # Rate of malfunction occurence of single agent + 'min_duration': 15, # Minimal duration of malfunction + 'max_duration': 50 # Max duration of malfunction } # Custom observation builder @@ -66,7 +67,7 @@ def main(argv): max_rails_in_city=3), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=n_agents, - stochastic_data=stochastic_data, # Malfunction data generator + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=TreeObservation) # After training we want to render the results so we also load a renderer diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index dd81bead6bb26d40fb73987055ffb20e781f74eb..b32e4dfbcf16bf298292bdc5f890ce043c52e14a 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -3,6 +3,7 @@ from collections import deque import numpy as np import torch +from flatland.envs.malfunction_generators import malfunction_from_params from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.envs.rail_generators import sparse_rail_generator @@ -37,10 +38,9 @@ min_dist = 5 observation_builder = TreeObsForRailEnv(max_depth=2) # Use a the malfunction generator to break agents from time to time -stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents - 'malfunction_rate': 30, # Rate of malfunction occurence - 'min_duration': 3, # Minimal duration of malfunction - 'max_duration': 20 # Max duration of malfunction +stochastic_data = {'malfunction_rate': 80, # Rate of malfunction occurence of single agent + 'min_duration': 15, # Minimal duration of malfunction + 'max_duration': 50 # Max duration of malfunction } # Custom observation builder @@ -59,10 +59,10 @@ env = RailEnv(width=x_dim, seed=1, # Random seed grid_mode=False, max_rails_between_cities=2, - max_rails_in_city=2), + max_rails_in_city=4), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=n_agents, - stochastic_data=stochastic_data, # Malfunction data generator + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), obs_builder_object=TreeObservation) env.reset() diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index daac62dbd26d190a2341d6dedc5ab0cfef1a2dff..fe4905f91cb84964ca27f8dded888a19a2146e6c 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -2,9 +2,11 @@ import getopt import random import sys from collections import deque - # make sure the root path is in system path from pathlib import Path + +from flatland.envs.malfunction_generators import malfunction_from_params + base_dir = Path(__file__).resolve().parent.parent sys.path.append(str(base_dir)) @@ -40,10 +42,9 @@ def main(argv): # Use a the malfunction generator to break agents from time to time - stochastic_data = {'prop_malfunction': 0.0, # Percentage of defective agents - 'malfunction_rate': 30, # Rate of malfunction occurence - 'min_duration': 3, # Minimal duration of malfunction - 'max_duration': 20 # Max duration of malfunction + stochastic_data = {'malfunction_rate': 8000, # Rate of malfunction occurence of single agent + 'min_duration': 15, # Minimal duration of malfunction + 'max_duration': 50 # Max duration of malfunction } # Custom observation builder @@ -65,7 +66,8 @@ def main(argv): max_rails_in_city=3), schedule_generator=sparse_schedule_generator(speed_ration_map), number_of_agents=n_agents, - stochastic_data=stochastic_data, # Malfunction data generator + malfunction_generator_and_process_data=malfunction_from_params(stochastic_data), + # Malfunction data generator obs_builder_object=TreeObservation) # After training we want to render the results so we also load a renderer