From c93e9d56267411149105622f505407e7640659ff Mon Sep 17 00:00:00 2001 From: MLErik <baerenjesus@gmail.com> Date: Fri, 7 Jun 2019 10:56:46 +0200 Subject: [PATCH] error introduced for christian to test --- torch_training/training_navigation.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py index 2a4af22..198f0ee 100644 --- a/torch_training/training_navigation.py +++ b/torch_training/training_navigation.py @@ -65,7 +65,7 @@ action_prob = [0] * 4 agent_obs = [None] * env.get_num_agents() agent_next_obs = [None] * env.get_num_agents() agent = Agent(state_size, action_size, "FC", 0) -agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint10400.pth')) +agent.qnetwork_local.load_state_dict(torch.load('./Nets/avoid_checkpoint15000.pth')) demo = True @@ -119,18 +119,18 @@ def norm_obs_clip(obs, clip_min=-1, clip_max=1): for trials in range(1, n_trials + 1): # Reset environment - obs = env.reset(False,False) - + obs = env.reset(False, False) + print(len(obs[0])) final_obs = obs.copy() final_obs_next = obs.copy() for a in range(env.get_num_agents()): - data, distance = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=5, current_depth=0) - + data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(obs[a]), num_features_per_node=7, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) - obs[a] = np.concatenate((data, distance)) + obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data)) + print(len(data) + len(distance) + len(agent_data), len(obs[a])) for i in range(2): time_obs.append(obs) # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5) @@ -156,11 +156,11 @@ for trials in range(1, n_trials + 1): # Environment step next_obs, all_rewards, done, _ = env.step(action_dict) for a in range(env.get_num_agents()): - data, distance = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=5, + data, distance, agent_data = env.obs_builder.split_tree(tree=np.array(next_obs[a]), num_features_per_node=7, current_depth=0) data = norm_obs_clip(data) distance = norm_obs_clip(distance) - next_obs[a] = np.concatenate((data, distance)) + next_obs[a] = np.concatenate((np.concatenate((data, distance)),agent_data)) time_obs.append(next_obs) -- GitLab