Skip to content
Snippets Groups Projects
Commit 35e0e64e authored by Eric Hambro's avatar Eric Hambro
Browse files

Add ascension counter

parent 88c416dc
No related branches found
No related tags found
No related merge requests found
......@@ -35,6 +35,7 @@ def run_batched_rollout(num_episodes, batched_env, agent):
episode_count = 0
pbar = tqdm(total=num_episodes)
ascension_count = 0
all_returns = []
returns = [0.0 for _ in range(num_envs)]
# The evaluator will automatically stop after the episodes based on the development/test phase
......@@ -55,11 +56,12 @@ def run_batched_rollout(num_episodes, batched_env, agent):
active_envs[done_idx] = (num_remaining > 0)
num_remaining -= 1
ascension_count += int(infos[done_idx]["is_ascended"])
pbar.update(1)
returns[done_idx] = 0.0
pbar.close()
return all_returns
return ascension_count, all_returns
if __name__ == "__main__":
# AIcrowd will cut the assessment early duing the dev phase
......
......@@ -37,4 +37,4 @@ class TestEvaluationConfig:
# Change this to locally check a different number of rollouts
# The AIcrowd submission evaluator will not use this
# It is only for your local evaluation
NUM_EPISODES = 64
NUM_EPISODES = 512
......@@ -25,8 +25,12 @@ def evaluate():
agent = Agent(num_envs, batched_env.num_actions)
scores = run_batched_rollout(num_episodes, batched_env, agent)
print(f"Median Score: {np.median(scores)}, Mean Score: {np.mean(scores)}")
ascensions, scores = run_batched_rollout(num_episodes, batched_env, agent)
print(
f"Ascensions: {ascensions} "
f"Median Score: {np.median(scores)}, "
f"Mean Score: {np.mean(scores)}"
)
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment