Skip to content
Snippets Groups Projects
Commit 1369276b authored by Eric Hambro
Browse files

Fixup env counting to match AIcrowd logic.

parent 299d5af8
No related branches found
No related tags found
No related merge requests found
...@@ -27,10 +27,9 @@ def run_batched_rollout(batched_env, agent): ...@@ -27,10 +27,9 @@ def run_batched_rollout(batched_env, agent):
dones = [False for _ in range(num_envs)] dones = [False for _ in range(num_envs)]
infos = [{} for _ in range(num_envs)] infos = [{} for _ in range(num_envs)]
# We assign each environment a fixed number of episodes at the start # We mark at the start of each episode if we are 'counting it'
envs_each = NUM_ASSESSMENTS // num_envs active_envs = [i < NUM_ASSESSMENTS for i in range(num_envs)]
remainders = NUM_ASSESSMENTS % num_envs num_remaining = NUM_ASSESSMENTS - sum(active_envs)
episodes = [envs_each + int(i < remainders) for i in range(num_envs)]
episode_count = 0 episode_count = 0
pbar = tqdm(total=NUM_ASSESSMENTS) pbar = tqdm(total=NUM_ASSESSMENTS)
...@@ -49,12 +48,17 @@ def run_batched_rollout(batched_env, agent): ...@@ -49,12 +48,17 @@ def run_batched_rollout(batched_env, agent):
for done_idx in np.where(dones)[0]: for done_idx in np.where(dones)[0]:
observations[done_idx] = batched_env.single_env_reset(done_idx) observations[done_idx] = batched_env.single_env_reset(done_idx)
if episodes[done_idx] > 0: if active_envs[done_idx]:
# We were 'counting' this episode
all_returns.append(returns[done_idx]) all_returns.append(returns[done_idx])
returns[done_idx] = 0.0
episodes[done_idx] -= 1
episode_count += 1 episode_count += 1
active_envs[done_idx] = (num_remaining > 0)
num_remaining -= 1
pbar.update(1) pbar.update(1)
returns[done_idx] = 0.0
return all_returns return all_returns
if __name__ == "__main__": if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment