From ccd2d08bcf51db081c60edd52bb9a3b94f015ed3 Mon Sep 17 00:00:00 2001
From: Erik Nygren <erik.nygren@sbb.ch>
Date: Sat, 5 Oct 2019 11:31:40 -0400
Subject: [PATCH] Gate replay-buffer updates on info['action_required'] and drop stale tracking state

---
 torch_training/training_navigation.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torch_training/training_navigation.py b/torch_training/training_navigation.py
index 4f82c52..ac80feb 100644
--- a/torch_training/training_navigation.py
+++ b/torch_training/training_navigation.py
@@ -115,9 +115,6 @@ def main(argv):
 
         # Reset environment
         obs, info = env.reset(True, True)
-        register_action_state = np.zeros(env.get_num_agents(), dtype=bool)
-        final_obs = agent_obs.copy()
-        final_obs_next = agent_next_obs.copy()
 
         # Build agent specific observations
         for a in range(env.get_num_agents()):
@@ -155,8 +152,11 @@ def main(argv):
 
             # Update replay buffer and train agent
             for a in range(env.get_num_agents()):
-                if (agent_obs_buffer[a] is not None and register_action_state[a] and env.agents[a].status != 3) or \
+                if (agent_obs_buffer[a] is not None and info['action_required'][a] and env.agents[a].status != 3) or \
                         env.agents[a].status == 2:
+                    if all_rewards[a] < -1.:
+                        print("Unexpectedly low reward {} for agent {}".format(all_rewards[a], a))
+
                     agent_delayed_next = agent_obs[a].copy()
                     agent.step(agent_obs_buffer[a], agent_action_buffer[a], all_rewards[a],
                                agent_delayed_next, done[a])
-- 
GitLab
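
Note on the change above: the removed `register_action_state` array duplicated bookkeeping that the Flatland environment already exposes through the `info['action_required']` dict returned by `env.reset()` and `env.step()`, so the patch gates replay-buffer updates on that flag instead. The numeric status checks compare against what, in Flatland's `RailAgentStatus` enum, should be DONE (2) and DONE_REMOVED (3). Below is a minimal, self-contained sketch of that gating rule; the `should_store_transition` helper and the inlined enum are illustrative stand-ins, not code from this repository.

from enum import IntEnum

class RailAgentStatus(IntEnum):
    # Values as defined in flatland-rl, inlined here so the sketch
    # runs without the dependency.
    READY_TO_DEPART = 0
    ACTIVE = 1
    DONE = 2
    DONE_REMOVED = 3

def should_store_transition(has_buffered_obs, action_required, status):
    # Store a transition when the agent actually had to pick an action
    # this step and is still on the grid, or when it has just reached
    # its target (terminal transition for agents in state DONE).
    return ((has_buffered_obs and action_required
             and status != RailAgentStatus.DONE_REMOVED)
            or status == RailAgentStatus.DONE)

# An active agent that was asked for an action stores a sample; one
# already removed from the grid does not; a freshly finished agent
# still records its terminal transition even without an action.
assert should_store_transition(True, True, RailAgentStatus.ACTIVE)
assert not should_store_transition(True, True, RailAgentStatus.DONE_REMOVED)
assert should_store_transition(True, False, RailAgentStatus.DONE)

Relying on the flag the environment already maintains removes one place where local bookkeeping could drift out of sync with the environment state between resets and steps.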