Commit 74046a63 authored by pfrl_rainbow

enable thread

parent f27deeb8
...@@ -60,7 +60,7 @@ MINERL_MAX_EVALUATION_EPISODES = int(os.getenv('MINERL_MAX_EVALUATION_EPISODES', ...@@ -60,7 +60,7 @@ MINERL_MAX_EVALUATION_EPISODES = int(os.getenv('MINERL_MAX_EVALUATION_EPISODES',
# Parallel testing/inference, **you can override** below value based on compute # Parallel testing/inference, **you can override** below value based on compute
# requirements, etc to save OOM in this phase. # requirements, etc to save OOM in this phase.
EVALUATION_THREAD_COUNT = int(os.getenv('EPISODES_EVALUATION_THREAD_COUNT', 1)) EVALUATION_THREAD_COUNT = int(os.getenv('EPISODES_EVALUATION_THREAD_COUNT', 2))
class EpisodeDone(Exception): class EpisodeDone(Exception):
pass pass
...@@ -174,59 +174,46 @@ AGENT_TO_TEST = MineRLRainbowBaselineAgent # MineRLMatrixAgent, MineRLRandomAgen ...@@ -174,59 +174,46 @@ AGENT_TO_TEST = MineRLRainbowBaselineAgent # MineRLMatrixAgent, MineRLRandomAgen
# EVALUATION CODE # # EVALUATION CODE #
#################### ####################
def main(): def main():
# agent = AGENT_TO_TEST()
# assert isinstance(agent, MineRLAgentBase)
# agent.load_agent()
#
# assert MINERL_MAX_EVALUATION_EPISODES > 0
# assert EVALUATION_THREAD_COUNT > 0
#
# # Create the parallel envs (sequentially to prevent issues!)
# envs = [gym.make(MINERL_GYM_ENV) for _ in range(EVALUATION_THREAD_COUNT)]
# episodes_per_thread = [MINERL_MAX_EVALUATION_EPISODES // EVALUATION_THREAD_COUNT for _ in range(EVALUATION_THREAD_COUNT)]
# episodes_per_thread[-1] += MINERL_MAX_EVALUATION_EPISODES - EVALUATION_THREAD_COUNT *(MINERL_MAX_EVALUATION_EPISODES // EVALUATION_THREAD_COUNT)
# # A simple function to evaluate on episodes!
# def evaluate(i, env):
# print("[{}] Starting evaluator.".format(i))
# for i in range(episodes_per_thread[i]):
# try:
# agent.run_agent_on_episode(Episode(env))
# except EpisodeDone:
# print("[{}] Episode complete".format(i))
# pass
#
# evaluator_threads = [threading.Thread(target=evaluate, args=(i, envs[i])) for i in range(EVALUATION_THREAD_COUNT)]
# for thread in evaluator_threads:
# thread.start()
#
# # wait for the evaluation to finish
# for thread in evaluator_threads:
# thread.join()
assert MINERL_MAX_EVALUATION_EPISODES > 0 assert MINERL_MAX_EVALUATION_EPISODES > 0
assert EVALUATION_THREAD_COUNT == 1 assert EVALUATION_THREAD_COUNT > 0
# Create the parallel envs (sequentially to prevent issues!)
kmeans = joblib.load(os.path.abspath(os.path.join(__file__, os.pardir, 'train', 'kmeans.joblib'))) kmeans = joblib.load(os.path.abspath(os.path.join(__file__, os.pardir, 'train', 'kmeans.joblib')))
core_env = gym.make(MINERL_GYM_ENV) def wrapper(env):
env = wrap_env( return wrap_env(
env=core_env, test=True, monitor=False, outdir=None, env=env, test=True, monitor=False, outdir=None,
frame_skip=FRAME_SKIP, gray_scale=GRAY_SCALE, frame_stack=FRAME_STACK, frame_skip=FRAME_SKIP, gray_scale=GRAY_SCALE, frame_stack=FRAME_STACK,
randomize_action=RANDOMIZE_ACTION, eval_epsilon=EVAL_EPSILON, randomize_action=RANDOMIZE_ACTION, eval_epsilon=EVAL_EPSILON,
action_choices=kmeans.cluster_centers_, action_choices=kmeans.cluster_centers_,
) )
agent = AGENT_TO_TEST(env) envs = [wrapper(gym.make(MINERL_GYM_ENV)) for _ in range(EVALUATION_THREAD_COUNT)]
# envs = [gym.make(MINERL_GYM_ENV) for _ in range(EVALUATION_THREAD_COUNT)]
agent = AGENT_TO_TEST(envs[0])
# agent = AGENT_TO_TEST()
assert isinstance(agent, MineRLAgentBase) assert isinstance(agent, MineRLAgentBase)
agent.load_agent() agent.load_agent()
for i in range(MINERL_MAX_EVALUATION_EPISODES): episodes_per_thread = [MINERL_MAX_EVALUATION_EPISODES // EVALUATION_THREAD_COUNT for _ in range(EVALUATION_THREAD_COUNT)]
episodes_per_thread[-1] += MINERL_MAX_EVALUATION_EPISODES - EVALUATION_THREAD_COUNT *(MINERL_MAX_EVALUATION_EPISODES // EVALUATION_THREAD_COUNT)
# A simple function to evaluate on episodes!
def evaluate(i, env):
print("[{}] Starting evaluator.".format(i)) print("[{}] Starting evaluator.".format(i))
try: for i in range(episodes_per_thread[i]):
agent.run_agent_on_episode(Episode(env)) try:
except EpisodeDone: agent.run_agent_on_episode(Episode(env))
print("[{}] Episode complete".format(i)) except EpisodeDone:
pass print("[{}] Episode complete".format(i))
pass
evaluator_threads = [threading.Thread(target=evaluate, args=(i, envs[i])) for i in range(EVALUATION_THREAD_COUNT)]
for thread in evaluator_threads:
thread.start()
# wait for the evaluation to finish
for thread in evaluator_threads:
thread.join()
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment