diff --git a/torch_training/bla.py b/torch_training/bla.py
index f0e7f592a3946736586b503e89d7daf7cb026085..f4f7131ba73463192cf5a369d52e667bfc45fb2d 100644
--- a/torch_training/bla.py
+++ b/torch_training/bla.py
@@ -157,18 +157,18 @@ def main(argv):
             action_prob[action] += 1
             action_dict.update({a: action})
         # Environment step
-        #
-        # next_obs, all_rewards, done, _ = env.step(action_dict)
-        # # print(all_rewards,action)
-        # obs_original = next_obs.copy()
-        # for a in range(env.get_num_agents()):
-        #     data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
-        #                                             current_depth=0)
-        #     data = norm_obs_clip(data)
-        #     distance = norm_obs_clip(distance)
-        #     agent_data = np.clip(agent_data, -1, 1)
-        #     next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
-        # time_obs.append(next_obs)
+
+        next_obs, all_rewards, done, _ = env.step(action_dict)
+        # print(all_rewards,action)
+        obs_original = next_obs.copy()
+        for a in range(env.get_num_agents()):
+            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                    current_depth=0)
+            data = norm_obs_clip(data)
+            distance = norm_obs_clip(distance)
+            agent_data = np.clip(agent_data, -1, 1)
+            next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+        time_obs.append(next_obs)
         #
         # # Update replay buffer and train agent
         # for a in range(env.get_num_agents()):