Skip to content
Snippets Groups Projects
Commit 1369276b authored by Eric Hambro
Browse files

Fixup env counting to match AIcrowd logic.

parent 299d5af8
No related branches found
No related tags found
No related merge requests found
......@@ -27,10 +27,9 @@ def run_batched_rollout(batched_env, agent):
dones = [False for _ in range(num_envs)]
infos = [{} for _ in range(num_envs)]
# We assign each environment a fixed number of episodes at the start
envs_each = NUM_ASSESSMENTS // num_envs
remainders = NUM_ASSESSMENTS % num_envs
episodes = [envs_each + int(i < remainders) for i in range(num_envs)]
# We mark at the start of each episode if we are 'counting it'
active_envs = [i < NUM_ASSESSMENTS for i in range(num_envs)]
num_remaining = NUM_ASSESSMENTS - sum(active_envs)
episode_count = 0
pbar = tqdm(total=NUM_ASSESSMENTS)
......@@ -49,12 +48,17 @@ def run_batched_rollout(batched_env, agent):
for done_idx in np.where(dones)[0]:
observations[done_idx] = batched_env.single_env_reset(done_idx)
if episodes[done_idx] > 0:
if active_envs[done_idx]:
# We were 'counting' this episode
all_returns.append(returns[done_idx])
returns[done_idx] = 0.0
episodes[done_idx] -= 1
episode_count += 1
active_envs[done_idx] = (num_remaining > 0)
num_remaining -= 1
pbar.update(1)
returns[done_idx] = 0.0
return all_returns
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment