Commit 1369276b authored by Eric Hambro's avatar Eric Hambro
Browse files

Fixup env counting to match AIcrowd logic.

parent 299d5af8
......@@ -27,10 +27,9 @@ def run_batched_rollout(batched_env, agent):
dones = [False for _ in range(num_envs)]
infos = [{} for _ in range(num_envs)]
# We assign each environment a fixed number of episodes at the start
envs_each = NUM_ASSESSMENTS // num_envs
remainders = NUM_ASSESSMENTS % num_envs
episodes = [envs_each + int(i < remainders) for i in range(num_envs)]
# We mark at the start of each episode if we are 'counting it'
active_envs = [i < NUM_ASSESSMENTS for i in range(num_envs)]
num_remaining = NUM_ASSESSMENTS - sum(active_envs)
episode_count = 0
pbar = tqdm(total=NUM_ASSESSMENTS)
......@@ -49,12 +48,17 @@ def run_batched_rollout(batched_env, agent):
for done_idx in np.where(dones)[0]:
observations[done_idx] = batched_env.single_env_reset(done_idx)
if episodes[done_idx] > 0:
if active_envs[done_idx]:
# We were 'counting' this episode
all_returns.append(returns[done_idx])
returns[done_idx] = 0.0
episodes[done_idx] -= 1
episode_count += 1
active_envs[done_idx] = (num_remaining > 0)
num_remaining -= 1
pbar.update(1)
returns[done_idx] = 0.0
return all_returns
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment