diff --git a/rollout.py b/rollout.py
index cf06647f733994bf2c6ee5d9f7cb9eaac53c1a6d..89899aecdfb24a4fad309114b2dcd007bbb1a43d 100644
--- a/rollout.py
+++ b/rollout.py
@@ -27,10 +27,9 @@ def run_batched_rollout(batched_env, agent):
     dones = [False for _ in range(num_envs)]
     infos = [{} for _ in range(num_envs)]
 
-    # We assign each environment a fixed number of episodes at the start
-    envs_each = NUM_ASSESSMENTS // num_envs
-    remainders = NUM_ASSESSMENTS % num_envs
-    episodes = [envs_each + int(i < remainders) for i in range(num_envs)]
+    # We mark at the start of each episode if we are 'counting it'
+    active_envs = [i < NUM_ASSESSMENTS for i in range(num_envs)]
+    num_remaining = NUM_ASSESSMENTS - sum(active_envs)
     
     episode_count = 0
     pbar = tqdm(total=NUM_ASSESSMENTS)
@@ -49,12 +48,17 @@ def run_batched_rollout(batched_env, agent):
         for done_idx in np.where(dones)[0]:
             observations[done_idx] = batched_env.single_env_reset(done_idx)
  
-            if episodes[done_idx] > 0:
+            if active_envs[done_idx]:
+                # We were 'counting' this episode
                 all_returns.append(returns[done_idx])
-                returns[done_idx] = 0.0
-                episodes[done_idx] -= 1
                 episode_count += 1
+                
+                active_envs[done_idx] = (num_remaining > 0)
+                num_remaining -= 1
+                
                 pbar.update(1)
+            
+            returns[done_idx] = 0.0
     return all_returns
 
 if __name__ == "__main__":