From ed1f9cc4699c9fc1a66f8d6e33b8e4693dbee07e Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sun, 1 Sep 2019 18:58:04 -0400
Subject: [PATCH] fixing learning issues

---
 torch_training/training_navigation.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 4f99b48..3ed4840 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -125,6 +125,7 @@ def main(argv):
         # Build agent specific observations
         for a in range(env.get_num_agents()):
             agent_obs[a] = normalize_observation(obs[a], observation_radius=10)
+            agent_obs_buffer[a] = agent_obs[a].copy()

         # Reset score and done
         score = 0
@@ -136,11 +137,13 @@ def main(argv):
             for a in range(env.get_num_agents()):
                 if env.agents[a].speed_data['position_fraction'] < 0.001:
                     register_action_state[a] = True
+                    action = agent.act(agent_obs[a], eps=eps)
+                    action_prob[action] += 1
+                    if step == 0:
+                        agent_action_buffer[a] = action
                 else:
                     register_action_state[a] = False
-
-                action = agent.act(agent_obs[a], eps=eps)
-                action_prob[action] += 1
+                    action = 0
                 action_dict.update({a: action})

             # Environment step
@@ -150,6 +153,7 @@ def main(argv):
             for a in range(env.get_num_agents()):
                 agent_next_obs[a] = normalize_observation(next_obs[a], observation_radius=10)
                 cummulated_reward[a] += all_rewards[a]
+
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
                 if done[a]:
--
GitLab