From b8eab2fb6b1df23098cb6829ae66e8032d7a2280 Mon Sep 17 00:00:00 2001 From: u214892 <u214892@sbb.ch> Date: Tue, 27 Aug 2019 11:23:34 +0200 Subject: [PATCH] #141 different agent classes --- scoring/utils/misc_utils.py | 20 +++++++++---------- sequential_agent/run_test.py | 9 ++++++--- torch_training/multi_agent_inference.py | 7 ++++--- torch_training/multi_agent_training.py | 11 ++++++---- .../multi_agent_two_time_step_training.py | 11 ++++++---- torch_training/render_agent_behavior.py | 8 +++++--- torch_training/training_navigation.py | 4 +++- utils/misc_utils.py | 11 ++++++---- 8 files changed, 49 insertions(+), 32 deletions(-) diff --git a/scoring/utils/misc_utils.py b/scoring/utils/misc_utils.py index b15476d..924f216 100644 --- a/scoring/utils/misc_utils.py +++ b/scoring/utils/misc_utils.py @@ -1,19 +1,19 @@ import random import time -from collections import deque import numpy as np + +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import complex_rail_generator, rail_from_file -from flatland.envs.observations import GlobalObsForRailEnv, TreeObsForRailEnv +from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv -from flatland.utils.rendertools import RenderTool from flatland.envs.rail_env import RailEnv - -from utils.observation_utils import norm_obs_clip, split_tree, max_lt +from flatland.utils.rendertools import RenderTool # Time factor to test the max time allowed for an env. max_time_factor = 1 + def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='*'): """ Call in a loop to create terminal progress bar @@ -54,7 +54,6 @@ def run_test(parameters, agent, observation_builder=None, observation_wrapper=No random.seed(parameters[3]) np.random.seed(parameters[3]) - printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) for trial in range(nr_trials_per_test): # Reset the env @@ -73,7 +72,6 @@ def run_test(parameters, agent, observation_builder=None, observation_wrapper=No for a in range(env.get_num_agents()): obs[a] = observation_wrapper(obs[a]) - # Run episode trial_score = 0 max_steps = int(max_time_factor * (env.height + env.width)) @@ -115,6 +113,7 @@ def create_testfiles(parameters, test_nr=0, nr_trials_per_test=100): rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=parameters[3]), + agent_generator=complex_rail_generator_agents_placer(), obs_builder_object=TreeObsForRailEnv(max_depth=2), number_of_agents=parameters[2]) printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) @@ -151,6 +150,7 @@ def render_test(parameters, test_nr=0, nr_examples=5): env_renderer.close_window() return + def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3): # Parameter initialization features_per_node = 9 @@ -168,7 +168,6 @@ def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3): random.seed(parameters[3]) np.random.seed(parameters[3]) - printProgressBar(0, nr_trials_per_test, prefix='Progress:', suffix='Complete', length=20) for trial in range(nr_trials_per_test): # Reset the env @@ -177,7 +176,8 @@ def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3): env = RailEnv(width=3, height=3, rail_generator=rail_from_file(file_name), - obs_builder_object=TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv()), + obs_builder_object=TreeObsForRailEnv(max_depth=tree_depth, + predictor=ShortestPathPredictorForRailEnv()), number_of_agents=1, ) @@ -185,7 +185,7 @@ def run_test_sequential(parameters, agent, test_nr=0, tree_depth=3): done = env.dones # Run episode trial_score = 0 - max_steps = int(max_time_factor* (env.height + env.width)) + max_steps = int(max_time_factor * (env.height + env.width)) for step in range(max_steps): # Action diff --git a/sequential_agent/run_test.py b/sequential_agent/run_test.py index d0b9ce7..ee4e9ed 100644 --- a/sequential_agent/run_test.py +++ b/sequential_agent/run_test.py @@ -1,10 +1,12 @@ -from sequential_agent.simple_order_agent import OrderedAgent -from flatland.envs.generators import rail_from_file, complex_rail_generator +import numpy as np + +from flatland.envs.agent_generators import complex_rail_generator_agents_placer +from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool -import numpy as np +from sequential_agent.simple_order_agent import OrderedAgent np.random.seed(2) """ @@ -29,6 +31,7 @@ env = RailEnv(width=x_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), obs_builder_object=TreeObsForRailEnv(max_depth=1, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index b310e95..96f791d 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -3,14 +3,14 @@ from collections import deque import numpy as np import torch -from flatland.envs.generators import rail_from_file -from flatland.envs.rail_env import RailEnv -from flatland.utils.rendertools import RenderTool from importlib_resources import path from observation_builders.observations import TreeObsForRailEnv from predictors.predictions import ShortestPathPredictorForRailEnv import torch_training.Nets +from flatland.envs.generators import rail_from_file +from flatland.envs.rail_env import RailEnv +from flatland.utils.rendertools import RenderTool from torch_training.dueling_double_dqn import Agent from utils.observation_utils import normalize_observation @@ -41,6 +41,7 @@ env = RailEnv(width=x_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), obs_builder_object=observation_helper, number_of_agents=n_agents) env.reset(True, True) diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index c3c8c2b..ef0cae7 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -7,15 +7,16 @@ from collections import deque import matplotlib.pyplot as plt import numpy as np import torch +from importlib_resources import path + +# Import Torch and utility functions to normalize observation +import torch_training.Nets # Import Flatland/ Observations and Predictors +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv -from importlib_resources import path - -# Import Torch and utility functions to normalize observation -import torch_training.Nets from torch_training.dueling_double_dqn import Agent from utils.observation_utils import normalize_observation @@ -56,6 +57,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), obs_builder_object=observation_helper, number_of_agents=n_agents) env.reset(True, True) @@ -113,6 +115,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=2, min_dist=min_dist, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), obs_builder_object=observation_helper, number_of_agents=n_agents) diff --git a/torch_training/multi_agent_two_time_step_training.py b/torch_training/multi_agent_two_time_step_training.py index 1c3225c..f41495d 100644 --- a/torch_training/multi_agent_two_time_step_training.py +++ b/torch_training/multi_agent_two_time_step_training.py @@ -7,15 +7,16 @@ from collections import deque import matplotlib.pyplot as plt import numpy as np import torch +from importlib_resources import path + +# Import Torch and utility functions to normalize observation +import torch_training.Nets # Import Flatland/ Observations and Predictors +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv -from importlib_resources import path - -# Import Torch and utility functions to normalize observation -import torch_training.Nets from torch_training.dueling_double_dqn import Agent from utils.observation_utils import norm_obs_clip, split_tree @@ -52,6 +53,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), obs_builder_object=observation_helper, number_of_agents=n_agents) env.reset(True, True) @@ -109,6 +111,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) diff --git a/torch_training/render_agent_behavior.py b/torch_training/render_agent_behavior.py index 489501a..38bd009 100644 --- a/torch_training/render_agent_behavior.py +++ b/torch_training/render_agent_behavior.py @@ -3,14 +3,15 @@ from collections import deque import numpy as np import torch +from importlib_resources import path + +import torch_training.Nets +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.predictions import ShortestPathPredictorForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool -from importlib_resources import path - -import torch_training.Nets from torch_training.dueling_double_dqn import Agent from utils.observation_utils import norm_obs_clip, split_tree @@ -37,6 +38,7 @@ env = RailEnv(width=x_dim, rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 5cc7305..40e5a68 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -7,11 +7,12 @@ import matplotlib.pyplot as plt import numpy as np import torch from dueling_double_dqn import Agent + +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import TreeObsForRailEnv from flatland.envs.rail_env import RailEnv from flatland.utils.rendertools import RenderTool - from utils.observation_utils import norm_obs_clip, split_tree @@ -44,6 +45,7 @@ def main(argv): rail_generator=complex_rail_generator(nr_start_goal=n_goals, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=0), + agent_generator=complex_rail_generator_agents_placer(), obs_builder_object=observation_builder, number_of_agents=n_agents) env.reset(True, True) diff --git a/utils/misc_utils.py b/utils/misc_utils.py index 5b29c6b..6e23af2 100644 --- a/utils/misc_utils.py +++ b/utils/misc_utils.py @@ -3,15 +3,16 @@ import time from collections import deque import numpy as np +from line_profiler import LineProfiler + +from flatland.envs.agent_generators import complex_rail_generator_agents_placer from flatland.envs.generators import complex_rail_generator from flatland.envs.observations import GlobalObsForRailEnv from flatland.envs.rail_env import RailEnv -from line_profiler import LineProfiler - from utils.observation_utils import norm_obs_clip, split_tree -def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '*'): +def printProgressBar(iteration, total, prefix='', suffix='', decimals=1, length=100, fill='*'): """ Call in a loop to create terminal progress bar @params: @@ -31,13 +32,14 @@ def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, if iteration == total: print('') + class RandomAgent: def __init__(self, state_size, action_size): self.state_size = state_size self.action_size = action_size - def act(self, state, eps = 0): + def act(self, state, eps=0): """ :param state: input is the observation of the agent :return: returns an action @@ -87,6 +89,7 @@ def run_test(parameters, agent, test_nr=0, tree_depth=3): rail_generator=complex_rail_generator(nr_start_goal=nr_paths, nr_extra=5, min_dist=min_dist, max_dist=99999, seed=parameters[3]), + agent_generator=complex_rail_generator_agents_placer(), obs_builder_object=GlobalObsForRailEnv(), number_of_agents=parameters[2]) max_steps = int(3 * (env.height + env.width)) -- GitLab