From 446e6987cb0c83db6fac2e6570be0d0e80525cf4 Mon Sep 17 00:00:00 2001
From: u214892 <u214892@sbb.ch>
Date: Thu, 11 Jul 2019 09:55:50 +0200
Subject: [PATCH] #42 run baselines in ci

---
 torch_training/bla.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/torch_training/bla.py b/torch_training/bla.py
index f0e7f59..f4f7131 100644
--- a/torch_training/bla.py
+++ b/torch_training/bla.py
@@ -157,18 +157,18 @@ def main(argv):
             action_prob[action] += 1
             action_dict.update({a: action})
         # Environment step
-        #
-        # next_obs, all_rewards, done, _ = env.step(action_dict)
-        # # print(all_rewards,action)
-        # obs_original = next_obs.copy()
-        # for a in range(env.get_num_agents()):
-        #     data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
-        #                                             current_depth=0)
-        #     data = norm_obs_clip(data)
-        #     distance = norm_obs_clip(distance)
-        #     agent_data = np.clip(agent_data, -1, 1)
-        #     next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
-        # time_obs.append(next_obs)
+
+        next_obs, all_rewards, done, _ = env.step(action_dict)
+        # print(all_rewards,action)
+        obs_original = next_obs.copy()
+        for a in range(env.get_num_agents()):
+            data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                    current_depth=0)
+            data = norm_obs_clip(data)
+            distance = norm_obs_clip(distance)
+            agent_data = np.clip(agent_data, -1, 1)
+            next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+        time_obs.append(next_obs)
 #
 #         # Update replay buffer and train agent
 #         for a in range(env.get_num_agents()):
--
GitLab
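
For context, the hunk above re-enables the per-agent observation preprocessing that runs after `env.step()`: each agent's tree observation is split into node data, distance, and agent-data parts, the first two are normalised and clipped, and the concatenated vector is appended to the `time_obs` buffer. The sketch below is a minimal, self-contained illustration of that flow; `fake_split_tree` and `fake_norm_obs_clip` are simplified stand-ins for the repo's `split_tree` and `norm_obs_clip` helpers (whose import path is not shown in the patch), and the observation shapes and split points are assumptions.

```python
import numpy as np

def fake_split_tree(tree, current_depth=0):
    """Stand-in for the repo's split_tree: slice a flat tree observation
    into (data, distance, agent_data). The equal-thirds split is assumed."""
    tree = np.asarray(tree, dtype=float)
    third = len(tree) // 3
    return tree[:third], tree[third:2 * third], tree[2 * third:]

def fake_norm_obs_clip(x, clip_min=-1.0, clip_max=1.0):
    """Stand-in for norm_obs_clip: scale by the largest finite value, then clip."""
    x = np.asarray(x, dtype=float)
    finite = x[np.isfinite(x)]
    max_val = finite.max() if finite.size and finite.max() > 0 else 1.0
    return np.clip(x / max_val, clip_min, clip_max)

def preprocess(next_obs):
    """Mirror of the re-enabled per-agent block: split, normalise, clip, concatenate."""
    processed = {}
    for a, raw in next_obs.items():
        data, distance, agent_data = fake_split_tree(np.array(raw), current_depth=0)
        data = fake_norm_obs_clip(data)
        distance = fake_norm_obs_clip(distance)
        agent_data = np.clip(agent_data, -1, 1)
        processed[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
    return processed

# Example with two agents and dummy 9-element tree observations.
dummy_obs = {0: np.random.uniform(0, 10, 9), 1: np.random.uniform(0, 10, 9)}
time_obs = [preprocess(dummy_obs)]  # appended once per step in the training loop
print(time_obs[0][0].shape)  # (9,)
```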