diff --git a/torch_training/bla.py b/torch_training/bla.py
index f76c4ab267ce14dab7181625437ddc004710ebd8..0f5d7597dcbf41def5c07159d2cd71a6546105ad 100644
--- a/torch_training/bla.py
+++ b/torch_training/bla.py
@@ -136,58 +136,58 @@ def main(argv):
         # env.obs_builder.util_print_obs_subtree(tree=obs[0], num_elements_per_node=5)
         for a in range(env.get_num_agents()):
             agent_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
-        #
-        # score = 0
-        # env_done = 0
-        # # Run episode
-        # for step in range(max_steps):
-        #     if demo:
-        #         env_renderer.renderEnv(show=True, show_observations=False)
-        #         # observation_helper.util_print_obs_subtree(obs_original[0])
-        #     if record_images:
-        #         env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(frame_step))
-        #         frame_step += 1
-        #     # print(step)
-        #     # Action
-        #     for a in range(env.get_num_agents()):
-        #         if demo:
-        #             eps = 0
-        #         # action = agent.act(np.array(obs[a]), eps=eps)
-        #         action = agent.act(agent_obs[a], eps=eps)
-        #         action_prob[action] += 1
-        #         action_dict.update({a: action})
-        #     # Environment step
-        #
-        #     next_obs, all_rewards, done, _ = env.step(action_dict)
-        #     # print(all_rewards,action)
-        #     obs_original = next_obs.copy()
-        #     for a in range(env.get_num_agents()):
-        #         data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
-        #                                                 current_depth=0)
-        #         data = norm_obs_clip(data)
-        #         distance = norm_obs_clip(distance)
-        #         agent_data = np.clip(agent_data, -1, 1)
-        #         next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
-        #     time_obs.append(next_obs)
-        #
-        #     # Update replay buffer and train agent
-        #     for a in range(env.get_num_agents()):
-        #         agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
-        #         if done[a]:
-        #             final_obs[a] = agent_obs[a].copy()
-        #             final_obs_next[a] = agent_next_obs[a].copy()
-        #             final_action_dict.update({a: action_dict[a]})
-        #         if not demo and not done[a]:
-        #             agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
-        #         score += all_rewards[a] / env.get_num_agents()
-        #
-        #     agent_obs = agent_next_obs.copy()
-        #     if done['__all__']:
-        #         env_done = 1
-        #         for a in range(env.get_num_agents()):
-        #             agent.step(final_obs[a], final_action_dict[a], all_rewards[a], final_obs_next[a], done[a])
-        #         break
-        #     # Epsilon decay
+
+        score = 0
+        env_done = 0
+        # Run episode
+        for step in range(max_steps):
+            if demo:
+                env_renderer.renderEnv(show=True, show_observations=False)
+                # observation_helper.util_print_obs_subtree(obs_original[0])
+            if record_images:
+                env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(frame_step))
+                frame_step += 1
+            # print(step)
+            # Action
+            for a in range(env.get_num_agents()):
+                if demo:
+                    eps = 0
+                # action = agent.act(np.array(obs[a]), eps=eps)
+                action = agent.act(agent_obs[a], eps=eps)
+                action_prob[action] += 1
+                action_dict.update({a: action})
+            # Environment step
+
+            next_obs, all_rewards, done, _ = env.step(action_dict)
+            # print(all_rewards,action)
+            obs_original = next_obs.copy()
+            for a in range(env.get_num_agents()):
+                data, distance, agent_data = split_tree(tree=np.array(next_obs[a]),
+                                                        current_depth=0)
+                data = norm_obs_clip(data)
+                distance = norm_obs_clip(distance)
+                agent_data = np.clip(agent_data, -1, 1)
+                next_obs[a] = np.concatenate((np.concatenate((data, distance)), agent_data))
+            time_obs.append(next_obs)
+
+            # Update replay buffer and train agent
+            for a in range(env.get_num_agents()):
+                agent_next_obs[a] = np.concatenate((time_obs[0][a], time_obs[1][a]))
+                if done[a]:
+                    final_obs[a] = agent_obs[a].copy()
+                    final_obs_next[a] = agent_next_obs[a].copy()
+                    final_action_dict.update({a: action_dict[a]})
+                if not demo and not done[a]:
+                    agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
+                score += all_rewards[a] / env.get_num_agents()
+
+            agent_obs = agent_next_obs.copy()
+            if done['__all__']:
+                env_done = 1
+                for a in range(env.get_num_agents()):
+                    agent.step(final_obs[a], final_action_dict[a], all_rewards[a], final_obs_next[a], done[a])
+                break
+            # Epsilon decay
         # eps = max(eps_end, eps_decay * eps) # decrease epsilon
         #
         # done_window.append(env_done)