diff --git a/torch_training/bla.py b/torch_training/bla.py
index f4f7131ba73463192cf5a369d52e667bfc45fb2d..2dfec096fd854b3d14a4698b67e5be667313c5d1 100644
--- a/torch_training/bla.py
+++ b/torch_training/bla.py
@@ -162,12 +162,13 @@ def main(argv):
             # print(all_rewards,action)
             obs_original = next_obs.copy()
             for a in range(env.get_num_agents()):
-                data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
-                                                        current_depth=0)
-                data = norm_obs_clip(data)
-                distance = norm_obs_clip(distance)
-                agent_data = np.clip(agent_data, -1, 1)
-                next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+                a = 5
+                # data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
+                #                                         current_depth=0)
+                # data = norm_obs_clip(data)
+                # distance = norm_obs_clip(distance)
+                # agent_data = np.clip(agent_data, -1, 1)
+                # next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
             time_obs.append(next_obs)
 
             # # # Update replay buffer and train agent