diff --git a/torch_training/multi_agent_training.py b/torch_training/multi_agent_training.py
index 48227047e0e398a80cee2c41b0aa223ae3aaa9b1..ba30d4620588846e77eecdb794fa72d3e1088e83 100644
--- a/torch_training/multi_agent_training.py
+++ b/torch_training/multi_agent_training.py
@@ -175,6 +175,10 @@ def main(argv):
 
         # Action
         for a in range(env.get_num_agents()):
+            if env.agents[a].speed_data['position_fraction'] == 0.:
+                register_action_state[a] = True
+            else:
+                register_action_state[a] = False
             action = agent.act(agent_obs[a], eps=eps)
             action_prob[action] += 1
             action_dict.update({a: action})
@@ -192,7 +196,7 @@ def main(argv):
                 final_obs[a] = agent_obs[a].copy()
                 final_obs_next[a] = agent_next_obs[a].copy()
                 final_action_dict.update({a: action_dict[a]})
-            if not done[a]:
+            if not done[a] and register_action_state[a]:
                 agent.step(agent_obs[a], action_dict[a], all_rewards[a], agent_next_obs[a], done[a])
             score += all_rewards[a] / env.get_num_agents()
 