Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
pfrl_rainbow
minerl2020_submission
Commits
74046a63
Commit
74046a63
authored
Aug 07, 2020
by
pfrl_rainbow
Browse files
enable thread
parent
f27deeb8
Changes
1
Hide whitespace changes
Inline
Side-by-side
test.py
View file @
74046a63
...
...
@@ -60,7 +60,7 @@ MINERL_MAX_EVALUATION_EPISODES = int(os.getenv('MINERL_MAX_EVALUATION_EPISODES',
# Parallel testing/inference, **you can override** below value based on compute
# requirements, etc to save OOM in this phase.
# Number of parallel evaluator threads. Overridable via the environment to
# trade evaluation speed against memory pressure (OOM risk) — see the note
# above about compute requirements.
# Fix: the diff residue assigned this twice (defaults 1 then 2); only the
# final, effective assignment is kept (commit "enable thread" set default 2).
EVALUATION_THREAD_COUNT = int(os.getenv('EPISODES_EVALUATION_THREAD_COUNT', 2))
class EpisodeDone(Exception):
    """Signals that an evaluation episode has finished.

    Raised from inside ``agent.run_agent_on_episode`` and caught by the
    evaluator loop, which logs episode completion and moves on.
    """
...
...
@@ -174,59 +174,46 @@ AGENT_TO_TEST = MineRLRainbowBaselineAgent # MineRLMatrixAgent, MineRLRandomAgen
# EVALUATION CODE #
####################
def main():
    """Evaluate the agent over MINERL_MAX_EVALUATION_EPISODES episodes,
    spread across EVALUATION_THREAD_COUNT parallel evaluator threads.

    Each thread owns one wrapped MineRL env; a single shared agent instance
    is run on every episode. Episode completion is signalled via the
    ``EpisodeDone`` exception.

    NOTE(review): reconstructed from a mangled diff view that interleaved the
    old single-env code path with the new threaded one; confirm against the
    original ``test.py`` before relying on exact behavior.
    """
    assert MINERL_MAX_EVALUATION_EPISODES > 0
    assert EVALUATION_THREAD_COUNT > 0

    # Discretized action set: k-means cluster centers saved next to the
    # training artifacts (presumably produced during training — verify path).
    kmeans = joblib.load(
        os.path.abspath(os.path.join(__file__, os.pardir, 'train', 'kmeans.joblib')))

    def wrapper(env):
        # Apply the same evaluation-time wrapping to every raw MineRL env.
        return wrap_env(
            env=env,
            test=True,
            monitor=False,
            outdir=None,
            frame_skip=FRAME_SKIP,
            gray_scale=GRAY_SCALE,
            frame_stack=FRAME_STACK,
            randomize_action=RANDOMIZE_ACTION,
            eval_epsilon=EVAL_EPSILON,
            action_choices=kmeans.cluster_centers_,
        )

    # Create the parallel envs (sequentially to prevent issues!)
    envs = [wrapper(gym.make(MINERL_GYM_ENV)) for _ in range(EVALUATION_THREAD_COUNT)]

    agent = AGENT_TO_TEST(envs[0])
    assert isinstance(agent, MineRLAgentBase)
    agent.load_agent()

    # Split the episode budget evenly; the last thread absorbs the remainder
    # so that exactly MINERL_MAX_EVALUATION_EPISODES episodes run in total.
    episodes_per_thread = [
        MINERL_MAX_EVALUATION_EPISODES // EVALUATION_THREAD_COUNT
        for _ in range(EVALUATION_THREAD_COUNT)
    ]
    episodes_per_thread[-1] += MINERL_MAX_EVALUATION_EPISODES - (
        EVALUATION_THREAD_COUNT
        * (MINERL_MAX_EVALUATION_EPISODES // EVALUATION_THREAD_COUNT))

    # A simple function to evaluate on episodes!
    def evaluate(thread_index, env):
        # Fix: the original reused ``i`` for both the thread index and the
        # episode loop variable, so completion logs reported the wrong
        # evaluator id after the first episode.
        print("[{}] Starting evaluator.".format(thread_index))
        for _ in range(episodes_per_thread[thread_index]):
            try:
                agent.run_agent_on_episode(Episode(env))
            except EpisodeDone:
                # Expected control-flow signal: the episode simply ended.
                print("[{}] Episode complete".format(thread_index))

    evaluator_threads = [
        threading.Thread(target=evaluate, args=(i, envs[i]))
        for i in range(EVALUATION_THREAD_COUNT)
    ]
    for thread in evaluator_threads:
        thread.start()

    # Wait for the evaluation to finish.
    for thread in evaluator_threads:
        thread.join()
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment