diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 4f82c523474f6880a8595fe8c9dc3e4565e4902b..ac80feb8c84519cf006009ac134915c40dda9ca1 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -115,9 +115,6 @@ def main(argv):
         # Reset environment
         obs, info = env.reset(True, True)
 
-        register_action_state = np.zeros(env.get_num_agents(), dtype=bool)
-        final_obs = agent_obs.copy()
-        final_obs_next = agent_next_obs.copy()
 
         # Build agent specific observations
         for a in range(env.get_num_agents()):
@@ -155,8 +152,11 @@ def main(argv):
 
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
-                if (agent_obs_buffer[a] is not None and register_action_state[a] and env.agents[a].status != 3) or \
+                if (agent_obs_buffer[a] is not None and info['action_required'][a] and env.agents[a].status != 3) or \
                         env.agents[a].status == 2:
+                    if all_rewards[a] < -1.:
+                        print("bad")
+
                     agent_delayed_next = agent_obs[a].copy()
                     agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a], agent_delayed_next, done[a])
 
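
For context, the sketch below isolates the gating condition this diff introduces: a transition is pushed to the learner only when the agent has a buffered observation, the environment itself flags `info['action_required'][a]` (replacing the hand-rolled `register_action_state` array that the diff deletes), and the agent has not been removed (status != 3), or when the agent has just reached DONE (status == 2) so the terminal transition is still stored. `DummyAgent`, `update_replay`, the status constants, and all argument shapes are illustrative stand-ins, not the real flatland-rl or torch API.

```python
import numpy as np

# Stand-ins for the agent status values the diff compares against
# as raw integers (2 = done, 3 = done and removed from the grid).
STATUS_DONE = 2
STATUS_DONE_REMOVED = 3


class DummyAgent:
    """Minimal stand-in for the DQN agent: records transitions instead of learning."""

    def __init__(self):
        self.replay = []

    def step(self, obs, action, reward, next_obs, done):
        self.replay.append((obs, action, reward, next_obs, done))


def update_replay(agent, n_agents, info, status, agent_obs,
                  agent_obs_buffer, agent_action_buffer, all_rewards, done):
    """Mirror of the diff's gating: train on agent `a` only if it has a
    buffered observation, the env says an action was required, and the
    agent has not been removed -- or if it just finished, so the terminal
    transition is not lost."""
    for a in range(n_agents):
        if (agent_obs_buffer[a] is not None and info['action_required'][a]
                and status[a] != STATUS_DONE_REMOVED) or status[a] == STATUS_DONE:
            agent.step(agent_obs_buffer[a], agent_action_buffer[a],
                       all_rewards[a], agent_obs[a].copy(), done[a])


# Tiny usage example with two agents: agent 0 must act, agent 1 just finished.
agent = DummyAgent()
obs = [np.zeros(4), np.ones(4)]
update_replay(agent, 2,
              info={'action_required': [True, False]},
              status=[0, STATUS_DONE],
              agent_obs=obs,
              agent_obs_buffer=obs,
              agent_action_buffer=[1, 0],
              all_rewards=[-0.5, 1.0],
              done=[False, True])
assert len(agent.replay) == 2  # both stored: one active step, one terminal
```

Trusting `info['action_required']` from the environment removes a source of drift: the manually maintained `register_action_state` flags could fall out of sync with what the environment actually expected of each agent.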