Commit 49fafc4f authored by Eric Hambro
Browse files

Add collection of returns.

parent 26267fce
......@@ -35,19 +35,27 @@ def run_batched_rollout(batched_env, agent):
episode_count = 0
pbar = tqdm(total=NUM_ASSESSMENTS)
all_returns = []
returns = [0.0 for _ in range(num_envs)]
# The evaluator will automatically stop after the episodes based on the development/test phase
while episode_count < NUM_ASSESSMENTS:
actions = agent.batched_step(observations, rewards, dones, infos)
observations, rewards, dones, infos = batched_env.batch_step(actions)
for i, r in enumerate(rewards):
returns[i] += r
for done_idx in np.where(dones)[0]:
observations[done_idx] = batched_env.single_env_reset(done_idx)
if episodes[done_idx] > 0:
all_returns.append(returns[done_idx])
returns[done_idx] = 0.0
episodes[done_idx] -= 1
episode_count += 1
pbar.update(1)
return all_returns
if __name__ == "__main__":
submission_env_make_fn = SubmissionConfig.submission_env_make_fn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment