From 35e0e64e20f7913a8003b598048c338f9708e2a9 Mon Sep 17 00:00:00 2001 From: Eric Hambro <eric.hambro@gmail.com> Date: Sun, 6 Jun 2021 08:51:18 -0700 Subject: [PATCH] Add ascension counter --- rollout.py | 4 +++- submission_config.py | 2 +- test_submission.py | 8 ++++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/rollout.py b/rollout.py index aac7720..b0b06bb 100644 --- a/rollout.py +++ b/rollout.py @@ -35,6 +35,7 @@ def run_batched_rollout(num_episodes, batched_env, agent): episode_count = 0 pbar = tqdm(total=num_episodes) + ascension_count = 0 all_returns = [] returns = [0.0 for _ in range(num_envs)] # The evaluator will automatically stop after the episodes based on the development/test phase @@ -55,11 +56,12 @@ def run_batched_rollout(num_episodes, batched_env, agent): active_envs[done_idx] = (num_remaining > 0) num_remaining -= 1 + ascension_count += int(infos[done_idx]["is_ascended"]) pbar.update(1) returns[done_idx] = 0.0 pbar.close() - return all_returns + return ascension_count, all_returns if __name__ == "__main__": # AIcrowd will cut the assessment early duing the dev phase diff --git a/submission_config.py b/submission_config.py index 9968e99..dc2e609 100644 --- a/submission_config.py +++ b/submission_config.py @@ -37,4 +37,4 @@ class TestEvaluationConfig: # Change this to locally check a different number of rollouts # The AIcrowd submission evaluator will not use this # It is only for your local evaluation - NUM_EPISODES = 64 + NUM_EPISODES = 512 diff --git a/test_submission.py b/test_submission.py index 7ab3495..4ca3cd3 100644 --- a/test_submission.py +++ b/test_submission.py @@ -25,8 +25,12 @@ def evaluate(): agent = Agent(num_envs, batched_env.num_actions) - scores = run_batched_rollout(num_episodes, batched_env, agent) - print(f"Median Score: {np.median(scores)}, Mean Score: {np.mean(scores)}") + ascensions, scores = run_batched_rollout(num_episodes, batched_env, agent) + print( + f"Ascensions: {ascensions} " + f"Median Score: {np.median(scores)}, " + f"Mean Score: {np.mean(scores)}" + ) if __name__ == "__main__": -- GitLab