From a7ddd74b3336b6d934fa4fb5ab76b965558c5268 Mon Sep 17 00:00:00 2001 From: Erik Nygren <erik.nygren@sbb.ch> Date: Thu, 18 Jul 2019 12:07:41 -0400 Subject: [PATCH] added new utility function to normalize tree observation --- sequential_agent/run_test.py | 4 ++-- torch_training/multi_agent_inference.py | 31 ++++++++----------------- torch_training/multi_agent_training.py | 17 +++----------- utils/observation_utils.py | 10 ++++++++ 4 files changed, 25 insertions(+), 37 deletions(-) diff --git a/sequential_agent/run_test.py b/sequential_agent/run_test.py index 6e9f7c2..970d6aa 100644 --- a/sequential_agent/run_test.py +++ b/sequential_agent/run_test.py @@ -37,9 +37,9 @@ tree_depth = 1 observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv()) env_renderer = RenderTool(env, gl="PILSVG", ) handle = env.get_agent_handles() -n_trials = 10 +n_trials = 1 max_steps = 3 * (env.height + env.width) -record_images = False +record_images = True agent = OrderedAgent() action_dict = dict() diff --git a/torch_training/multi_agent_inference.py b/torch_training/multi_agent_inference.py index 003b18a..e399126 100644 --- a/torch_training/multi_agent_inference.py +++ b/torch_training/multi_agent_inference.py @@ -12,11 +12,11 @@ from importlib_resources import path import torch_training.Nets from torch_training.dueling_double_dqn import Agent -from utils.observation_utils import norm_obs_clip, split_tree +from utils.observation_utils import normalize_observation random.seed(3) np.random.seed(2) - +""" file_name = "./railway/complex_scene.pkl" env = RailEnv(width=10, height=20, @@ -27,9 +27,9 @@ y_dim = env.height """ -x_dim = 50 #np.random.randint(8, 20) -y_dim = 50 #np.random.randint(8, 20) -n_agents = 20 # np.random.randint(3, 8) +x_dim = 10 # np.random.randint(8, 20) +y_dim = 10 # np.random.randint(8, 20) +n_agents = 5 # np.random.randint(3, 8) n_goals = n_agents + np.random.randint(0, 3) min_dist = int(0.75 * min(x_dim, y_dim)) @@ -41,7 
+41,7 @@ env = RailEnv(width=x_dim, obs_builder_object=TreeObsForRailEnv(max_depth=3, predictor=ShortestPathPredictorForRailEnv()), number_of_agents=n_agents) env.reset(True, True) -""" + tree_depth = 3 observation_helper = TreeObsForRailEnv(max_depth=tree_depth, predictor=ShortestPathPredictorForRailEnv()) env_renderer = RenderTool(env, gl="PILSVG", ) @@ -53,7 +53,7 @@ for i in range(tree_depth + 1): state_size = num_features_per_node * nr_nodes action_size = 5 -n_trials = 1 +n_trials = 10 observation_radius = 10 max_steps = int(3 * (env.height + env.width)) eps = 1. @@ -73,7 +73,7 @@ agent = Agent(state_size, action_size, "FC", 0) with path(torch_training.Nets, "avoid_checkpoint52800.pth") as file_in: agent.qnetwork_local.load_state_dict(torch.load(file_in)) -record_images = True +record_images = False frame_step = 0 for trials in range(1, n_trials + 1): @@ -84,12 +84,7 @@ for trials in range(1, n_trials + 1): env_renderer.reset() for a in range(env.get_num_agents()): - data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=num_features_per_node, - current_depth=0) - data = norm_obs_clip(data, fixed_radius=observation_radius) - distance = norm_obs_clip(distance) - agent_data = np.clip(agent_data, -1, 1) - agent_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data)) + agent_obs[a] = normalize_observation(obs[a], observation_radius=10) # Run episode for step in range(max_steps): @@ -108,13 +103,7 @@ for trials in range(1, n_trials + 1): next_obs, all_rewards, done, _ = env.step(action_dict) for a in range(env.get_num_agents()): - data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), - num_features_per_node=num_features_per_node, - current_depth=0) - data = norm_obs_clip(data, fixed_radius=observation_radius) - distance = norm_obs_clip(distance) - agent_data = np.clip(agent_data, -1, 1) - agent_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data)) + agent_obs[a] = 
normalize_observation(next_obs[a], observation_radius=10) if done['__all__']: break diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py index 476066a..7659b2d 100644 --- a/torch_training/multi_agent_training.py +++ b/torch_training/multi_agent_training.py @@ -17,7 +17,7 @@ from importlib_resources import path # Import Torch and utility functions to normalize observation import torch_training.Nets from torch_training.dueling_double_dqn import Agent -from utils.observation_utils import norm_obs_clip, split_tree +from utils.observation_utils import normalize_observation def main(argv): @@ -131,13 +131,7 @@ def main(argv): # Build agent specific observations for a in range(env.get_num_agents()): - data, distance, agent_data = split_tree(tree=np.array(obs[a]), num_features_per_node=num_features_per_node, - current_depth=0) - data = norm_obs_clip(data, fixed_radius=observation_radius) - distance = norm_obs_clip(distance) - agent_data = np.clip(agent_data, -1, 1) - agent_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data)) - + agent_obs[a] = normalize_observation(obs[a], observation_radius=10) score = 0 env_done = 0 @@ -155,12 +149,7 @@ def main(argv): # Build agent specific observations and normalize for a in range(env.get_num_agents()): - data, distance, agent_data = split_tree(tree=np.array(next_obs[a]), - num_features_per_node=num_features_per_node, current_depth=0) - data = norm_obs_clip(data, fixed_radius=observation_radius) - distance = norm_obs_clip(distance) - agent_data = np.clip(agent_data, -1, 1) - agent_next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data)) + agent_next_obs[a] = normalize_observation(next_obs[a], observation_radius=10) # Update replay buffer and train agent for a in range(env.get_num_agents()): diff --git a/utils/observation_utils.py b/utils/observation_utils.py index c5f0d5d..b3dd5ae 100644 --- a/utils/observation_utils.py +++ 
b/utils/observation_utils.py @@ -97,3 +97,13 @@ def split_tree(tree, num_features_per_node, current_depth=0): distance_data.extend(tmp_distance_data) agent_data.extend(tmp_agent_data) return tree_data, distance_data, agent_data + + +def normalize_observation(observation, num_features_per_node=9, observation_radius=0): + data, distance, agent_data = split_tree(tree=np.array(observation), num_features_per_node=num_features_per_node, + current_depth=0) + data = norm_obs_clip(data, fixed_radius=observation_radius) + distance = norm_obs_clip(distance) + agent_data = np.clip(agent_data, -1, 1) + normalized_obs = np.concatenate((np.concatenate((data, distance)), agent_data)) + return normalized_obs -- GitLab