diff --git a/rollout.py b/rollout.py index cf06647f733994bf2c6ee5d9f7cb9eaac53c1a6d..89899aecdfb24a4fad309114b2dcd007bbb1a43d 100644 --- a/rollout.py +++ b/rollout.py @@ -27,10 +27,9 @@ def run_batched_rollout(batched_env, agent): dones = [False for _ in range(num_envs)] infos = [{} for _ in range(num_envs)] - # We assign each environment a fixed number of episodes at the start - envs_each = NUM_ASSESSMENTS // num_envs - remainders = NUM_ASSESSMENTS % num_envs - episodes = [envs_each + int(i < remainders) for i in range(num_envs)] + # We mark at the start of each episode if we are 'counting it' + active_envs = [i < NUM_ASSESSMENTS for i in range(num_envs)] + num_remaining = NUM_ASSESSMENTS - sum(active_envs) episode_count = 0 pbar = tqdm(total=NUM_ASSESSMENTS) @@ -49,12 +48,17 @@ def run_batched_rollout(batched_env, agent): for done_idx in np.where(dones)[0]: observations[done_idx] = batched_env.single_env_reset(done_idx) - if episodes[done_idx] > 0: + if active_envs[done_idx]: + # We were 'counting' this episode all_returns.append(returns[done_idx]) - returns[done_idx] = 0.0 - episodes[done_idx] -= 1 episode_count += 1 + + active_envs[done_idx] = (num_remaining > 0) + num_remaining -= 1 + pbar.update(1) + + returns[done_idx] = 0.0 return all_returns if __name__ == "__main__":