Commit 1bc3ec17 authored by pfrl_rainbow

refine main mod

parent ef3f99dd
......@@ -115,42 +115,106 @@ def main():
utils.log_versions()
try:
_main(args)
dqn_family(
# meta settings
env_id=args.env,
outdir=args.outdir,
seed=args.seed,
gpu=args.gpu,
demo=args.demo,
load=args.load,
eval_n_runs=args.eval_n_runs,
monitor=args.monitor,
# hyper params
agent_type=args.agent,
arch=args.arch,
update_interval=args.update_interval,
frame_skip=args.frame_skip,
gamma=args.gamma,
clip_delta=args.clip_delta,
num_step_return=args.num_step_return,
lr=args.lr,
adam_eps=args.adam_eps,
batch_accumulator=args.batch_accumulator,
gray_scale=args.gray_scale,
frame_stack=args.frame_stack,
final_exploration_frames=args.final_exploration_frames,
final_epsilon=args.final_epsilon,
eval_epsilon=args.eval_epsilon,
noisy_net_sigma=args.noisy_net_sigma,
replay_capacity=args.replay_capacity,
replay_start_size=args.replay_start_size,
prioritized=args.prioritized,
target_update_interval=args.target_update_interval,
kmeans_n_clusters=args.kmeans_n_clusters,
)
except: # noqa
logger.exception('execution failed.')
raise
def _main(args):
os.environ['MALMO_MINECRAFT_OUTPUT_LOGDIR'] = args.outdir
def dqn_family(
# meta settings
env_id,
outdir,
seed,
gpu,
demo,
load,
eval_n_runs,
monitor,
# hyper params
agent_type,
arch,
update_interval,
frame_skip,
gamma,
clip_delta,
num_step_return,
lr,
adam_eps,
batch_accumulator,
gray_scale,
frame_stack,
final_exploration_frames,
final_epsilon,
eval_epsilon,
noisy_net_sigma,
replay_capacity,
replay_start_size,
prioritized,
target_update_interval,
kmeans_n_clusters,
):
os.environ['MALMO_MINECRAFT_OUTPUT_LOGDIR'] = outdir
# Set a random seed used in ChainerRL.
pfrl.utils.set_random_seed(args.seed)
# Set a random seed used in PFRL.
pfrl.utils.set_random_seed(seed)
# Set different random seeds for train and test envs.
train_seed = args.seed # noqa: never used in this script
test_seed = 2 ** 31 - 1 - args.seed
train_seed = seed # noqa: never used in this script
test_seed = 2 ** 31 - 1 - seed
# K-Means
kmeans = cached_kmeans(
cache_dir=os.environ.get('KMEANS_CACHE'),
env_id=args.env,
n_clusters=args.kmeans_n_clusters,
random_state=args.seed)
env_id=env_id,
n_clusters=kmeans_n_clusters,
random_state=seed)
# create & wrap env
def wrap_env_partial(env, test):
randomize_action = test and args.noisy_net_sigma is None
randomize_action = test and noisy_net_sigma is None
wrapped_env = wrap_env(
env=env, test=test,
monitor=args.monitor, outdir=args.outdir,
frame_skip=args.frame_skip,
gray_scale=args.gray_scale, frame_stack=args.frame_stack,
randomize_action=randomize_action, eval_epsilon=args.eval_epsilon,
monitor=monitor, outdir=outdir,
frame_skip=frame_skip,
gray_scale=gray_scale, frame_stack=frame_stack,
randomize_action=randomize_action, eval_epsilon=eval_epsilon,
action_choices=kmeans.cluster_centers_)
return wrapped_env
logger.info('The first `gym.make(MineRL*)` may take several minutes. Be patient!')
core_env = gym.make(args.env)
core_env = gym.make(env_id)
# training env
env = wrap_env_partial(env=core_env, test=False)
# env.seed(int(train_seed)) # TODO: not supported yet
......@@ -162,38 +226,38 @@ def _main(args):
# 8,000,000 frames = 1333 episodes if we count an episode as 6000 frames,
# 8,000,000 frames = 1000 episodes if we count an episode as 8000 frames.
maximum_frames = 8000000
if args.frame_skip is None:
if frame_skip is None:
steps = maximum_frames
eval_interval = 6000 * 100 # (approx.) every 100 episode (counts "1 episode = 6000 steps")
else:
steps = maximum_frames // args.frame_skip
eval_interval = 6000 * 100 // args.frame_skip # (approx.) every 100 episode (counts "1 episode = 6000 steps")
steps = maximum_frames // frame_skip
eval_interval = 6000 * 100 // frame_skip # (approx.) every 100 episode (counts "1 episode = 6000 steps")
agent = get_agent(
n_actions=env.action_space.n, arch=args.arch, n_input_channels=env.observation_space.shape[0],
noisy_net_sigma=args.noisy_net_sigma, final_epsilon=args.final_epsilon,
final_exploration_frames=args.final_exploration_frames, explorer_sample_func=env.action_space.sample,
lr=args.lr, adam_eps=args.adam_eps,
prioritized=args.prioritized, steps=steps, update_interval=args.update_interval,
replay_capacity=args.replay_capacity, num_step_return=args.num_step_return,
agent_type=args.agent, gpu=args.gpu, gamma=args.gamma, replay_start_size=args.replay_start_size,
target_update_interval=args.target_update_interval, clip_delta=args.clip_delta,
batch_accumulator=args.batch_accumulator
n_actions=env.action_space.n, arch=arch, n_input_channels=env.observation_space.shape[0],
noisy_net_sigma=noisy_net_sigma, final_epsilon=final_epsilon,
final_exploration_frames=final_exploration_frames, explorer_sample_func=env.action_space.sample,
lr=lr, adam_eps=adam_eps,
prioritized=prioritized, steps=steps, update_interval=update_interval,
replay_capacity=replay_capacity, num_step_return=num_step_return,
agent_type=agent_type, gpu=gpu, gamma=gamma, replay_start_size=replay_start_size,
target_update_interval=target_update_interval, clip_delta=clip_delta,
batch_accumulator=batch_accumulator,
)
if args.load:
agent.load(args.load)
if load:
agent.load(load)
# experiment
if args.demo:
eval_stats = pfrl.experiments.eval_performance(env=eval_env, agent=agent, n_steps=None, n_episodes=args.eval_n_runs)
if demo:
eval_stats = pfrl.experiments.eval_performance(env=eval_env, agent=agent, n_steps=None, n_episodes=eval_n_runs)
logger.info('n_runs: {} mean: {} median: {} stdev {}'.format(
args.eval_n_runs, eval_stats['mean'], eval_stats['median'], eval_stats['stdev']))
eval_n_runs, eval_stats['mean'], eval_stats['median'], eval_stats['stdev']))
else:
pfrl.experiments.train_agent_with_evaluation(
agent=agent, env=env, steps=steps,
eval_n_steps=None, eval_n_episodes=args.eval_n_runs, eval_interval=eval_interval,
outdir=args.outdir, eval_env=eval_env, save_best_so_far_agent=True,
eval_n_steps=None, eval_n_episodes=eval_n_runs, eval_interval=eval_interval,
outdir=outdir, eval_env=eval_env, save_best_so_far_agent=True,
)
env.close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment