diff --git a/torch_training/bla.py b/torch_training/bla.py index 4225d669ccdaa3ce24f7164852cee0436deb3781..f0e7f592a3946736586b503e89d7daf7cb026085 100644 --- a/torch_training/bla.py +++ b/torch_training/bla.py @@ -148,15 +148,15 @@ def main(argv): env_renderer.gl.saveImage("./Images/flatland_frame_{:04d}.bmp".format(frame_step)) frame_step += 1 # print(step) - # # Action - # for a in range(env.get_num_agents()): - # if demo: - # eps = 0 - # # action = agent.act(np.array(obs[a]), eps=eps) - # action = agent.act(agent_obs[a], eps=eps) - # action_prob[action] += 1 - # action_dict.update({a: action}) - # # Environment step + # Action + for a in range(env.get_num_agents()): + if demo: + eps = 0 + # action = agent.act(np.array(obs[a]), eps=eps) + action = agent.act(agent_obs[a], eps=eps) + action_prob[action] += 1 + action_dict.update({a: action}) + # Environment step # # next_obs, all_rewards, done, _ = env.step(action_dict) # # print(all_rewards,action)