Commit a0b39efa authored by pfrl_rainbow

remove (+2 squashed commits)

Squashed commits:
[caebe10] git lfs
[cc5ef9a] rainbow baseline (diamond dense)
parent 3ec3c1db
.gitattributes:
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
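+# track the PyTorch checkpoint (*.pt) and the joblib dump (*.joblib) as Git LFS pointers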
@@ -143,7 +143,6 @@ def _main(args):
     randomize_action = test and args.noisy_net_sigma is None
     wrapped_env = wrap_env(
         env=env, test=test,
-        env_id=args.env,
         monitor=args.monitor, outdir=args.outdir,
         frame_skip=args.frame_skip,
         gray_scale=args.gray_scale, frame_stack=args.frame_stack,
@@ -16,7 +16,6 @@ logger = getLogger(__name__)
 def wrap_env(
         env, test,
-        env_id,
         monitor, outdir,
         frame_skip,
         gray_scale, frame_stack,
@@ -18,13 +18,49 @@ import numpy as np
 import coloredlogs
 coloredlogs.install(logging.DEBUG)
+# our dependencies
+import joblib
+import sys
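+# make the submission-local 'mod' directory importable regardless of the working directory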
+sys.path.append(os.path.abspath(os.path.join(__file__, os.pardir, 'mod')))
+from dqn_family import get_agent
+from env_wrappers import wrap_env
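+# Rainbow agent hyperparameters, mirroring the training-time settings (GPU = -1 runs pfrl on the CPU)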
+GPU = -1
+ARCH = 'distributed_dueling'
+NOISY_NET_SIGMA = 0.5
+FINAL_EPSILON = 0.01
+FINAL_EXPLORATION_FRAMES = 10 ** 6
+LR = 0.0000625
+ADAM_EPS = 0.00015
+PRIORITIZED = True
+UPDATE_INTERVAL = 4
+REPLAY_CAPACITY = 300000
+NUM_STEP_RETURN = 10
+AGENT_TYPE = 'CategoricalDoubleDQN'
+GAMMA = 0.99
+REPLAY_START_SIZE = 5000
+TARGET_UPDATE_INTERVAL = 10000
+CLIP_DELTA = True
+BATCH_ACCUMULATOR = 'mean'
+FRAME_SKIP = 4
+GRAY_SCALE = False
+FRAME_STACK = 4
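+# noisy networks provide the exploration, so epsilon-greedy randomization is disabled whenever a sigma is set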
+RANDOMIZE_ACTION = NOISY_NET_SIGMA is None
+EVAL_EPSILON = 0.001
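+# training budget: 8,000,000 env frames / frame skip of 4 = 2,000,000 agent steps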
+maximum_frames = 8000000
+STEPS = maximum_frames // FRAME_SKIP
 # All the evaluations will be evaluated on MineRLObtainDiamondVectorObf-v0 environment
 MINERL_GYM_ENV = os.getenv('MINERL_GYM_ENV', 'MineRLObtainDiamondVectorObf-v0')
 MINERL_MAX_EVALUATION_EPISODES = int(os.getenv('MINERL_MAX_EVALUATION_EPISODES', 5))
 # Parallel testing/inference, **you can override** the value below based on compute
 # requirements, etc., to avoid OOM in this phase.
-EVALUATION_THREAD_COUNT = int(os.getenv('EPISODES_EVALUATION_THREAD_COUNT', 2))
+EVALUATION_THREAD_COUNT = int(os.getenv('EPISODES_EVALUATION_THREAD_COUNT', 1))
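+# this baseline evaluates on a single thread; main() asserts EVALUATION_THREAD_COUNT == 1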
 class EpisodeDone(Exception):
     pass
@@ -100,61 +136,37 @@ class MineRLAgentBase(abc.ABC):
 #######################
 # YOUR CODE GOES HERE #
 #######################
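+# Rainbow baseline: rebuild the pfrl agent with the training-time hyperparameters above, then load the trained snapshot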
+class MineRLRainbowBaselineAgent(MineRLAgentBase):
+    def __init__(self, env):
+        self.env = env
-class MineRLMatrixAgent(MineRLAgentBase):
-    """
-    An example random agent.
-    Note, you MUST subclass MineRLAgentBase.
-    """
-    def load_agent(self):
-        """In this example we make a random matrix which
-        we will use to multiply the state by to produce an action!
-        This is where you could load a neural network.
-        """
-        # Some helpful constants from the environment.
-        flat_video_obs_size = 64*64*3
-        obs_size = 64
-        ac_size = 64
-        self.matrix = np.random.random(size=(ac_size, flat_video_obs_size + obs_size))*2 - 1
-        self.flatten_obs = lambda obs: np.concatenate([obs['pov'].flatten()/255.0, obs['vector'].flatten()])
-        self.act = lambda flat_obs: {'vector': np.clip(self.matrix.dot(flat_obs), -1, 1)}
-    def run_agent_on_episode(self, single_episode_env : Episode):
-        """Runs the agent on a SINGLE episode.
-        Args:
-            single_episode_env (Episode): The episode on which to run the agent.
-        """
-        import torch
-        device = torch.device('cuda:0')
-        x = torch.randn(64, 1000, device=device, dtype=torch.float)
-        assert torch.cuda.is_available()
-        obs = single_episode_env.reset()
-        done = False
-        while not done:
-            obs, reward, done, _ = single_episode_env.step(self.act(self.flatten_obs(obs)))
-class MineRLRandomAgent(MineRLAgentBase):
-    """A random agent"""
     def load_agent(self):
-        pass # Nothing to do, this agent is a random agent.
+        self.agent = get_agent(
+            n_actions=self.env.action_space.n, arch=ARCH, n_input_channels=self.env.observation_space.shape[0],
+            noisy_net_sigma=NOISY_NET_SIGMA, final_epsilon=FINAL_EPSILON,
+            final_exploration_frames=FINAL_EXPLORATION_FRAMES, explorer_sample_func=self.env.action_space.sample,
+            lr=LR, adam_eps=ADAM_EPS,
+            prioritized=PRIORITIZED, steps=STEPS, update_interval=UPDATE_INTERVAL,
+            replay_capacity=REPLAY_CAPACITY, num_step_return=NUM_STEP_RETURN,
+            agent_type=AGENT_TYPE, gpu=GPU, gamma=GAMMA, replay_start_size=REPLAY_START_SIZE,
+            target_update_interval=TARGET_UPDATE_INTERVAL, clip_delta=CLIP_DELTA,
+            batch_accumulator=BATCH_ACCUMULATOR,
+        )
+        self.agent.load(os.path.abspath(os.path.join(__file__, os.pardir, 'train')))
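+        # agent.load() restores the snapshot that pfrl's agent.save() wrote into the submission's 'train' directory during training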
+    def run_agent_on_episode(self, single_episode_env: Episode):
+        with self.agent.eval_mode():
+            obs = single_episode_env.reset()
+            while True:
+                a = self.agent.act(obs)
+                obs, r, done, info = single_episode_env.step(a)
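+                # no explicit break: the Episode wrapper raises EpisodeDone when the episode ends, unwinding this loop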
-    def run_agent_on_episode(self, single_episode_env : Episode):
-        obs = single_episode_env.reset()
-        done = False
-        while not done:
-            random_act = single_episode_env.action_space.sample()
-            single_episode_env.step(random_act)
 #####################################################################
 # IMPORTANT: SET THIS VARIABLE WITH THE AGENT CLASS YOU ARE USING   #
 ######################################################################
-AGENT_TO_TEST = MineRLMatrixAgent # MineRLMatrixAgent, MineRLRandomAgent, YourAgentHere
+AGENT_TO_TEST = MineRLRainbowBaselineAgent # MineRLMatrixAgent, MineRLRandomAgent, YourAgentHere
@@ -162,34 +174,60 @@ AGENT_TO_TEST = MineRLMatrixAgent # MineRLMatrixAgent, MineRLRandomAgent, YourAgentHere
 # EVALUATION CODE #
 ####################
 def main():
-    agent = AGENT_TO_TEST()
+    # agent = AGENT_TO_TEST()
+    # assert isinstance(agent, MineRLAgentBase)
+    # agent.load_agent()
+    #
+    # assert MINERL_MAX_EVALUATION_EPISODES > 0
+    # assert EVALUATION_THREAD_COUNT > 0
+    #
+    # # Create the parallel envs (sequentially to prevent issues!)
+    # envs = [gym.make(MINERL_GYM_ENV) for _ in range(EVALUATION_THREAD_COUNT)]
+    # episodes_per_thread = [MINERL_MAX_EVALUATION_EPISODES // EVALUATION_THREAD_COUNT for _ in range(EVALUATION_THREAD_COUNT)]
+    # episodes_per_thread[-1] += MINERL_MAX_EVALUATION_EPISODES - EVALUATION_THREAD_COUNT * (MINERL_MAX_EVALUATION_EPISODES // EVALUATION_THREAD_COUNT)
+    # # A simple function to evaluate on episodes!
+    # def evaluate(i, env):
+    #     print("[{}] Starting evaluator.".format(i))
+    #     for i in range(episodes_per_thread[i]):
+    #         try:
+    #             agent.run_agent_on_episode(Episode(env))
+    #         except EpisodeDone:
+    #             print("[{}] Episode complete".format(i))
+    #             pass
+    #
+    # evaluator_threads = [threading.Thread(target=evaluate, args=(i, envs[i])) for i in range(EVALUATION_THREAD_COUNT)]
+    # for thread in evaluator_threads:
+    #     thread.start()
+    #
+    # # wait for the evaluation to finish
+    # for thread in evaluator_threads:
+    #     thread.join()
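+    # single-threaded replacement for the stock multi-threaded evaluator above: build one wrapped env, load the agent, run the episodes sequentially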
+    assert MINERL_MAX_EVALUATION_EPISODES > 0
+    assert EVALUATION_THREAD_COUNT == 1
+    kmeans = joblib.load(os.path.abspath(os.path.join(__file__, os.pardir, 'train', 'kmeans.joblib')))
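+    # the k-means centroids learned at training time become the discrete action set (wrap_env's action_choices)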
+    core_env = gym.make(MINERL_GYM_ENV)
+    env = wrap_env(
+        env=core_env, test=True, monitor=False, outdir=None,
+        frame_skip=FRAME_SKIP, gray_scale=GRAY_SCALE, frame_stack=FRAME_STACK,
+        randomize_action=RANDOMIZE_ACTION, eval_epsilon=EVAL_EPSILON,
+        action_choices=kmeans.cluster_centers_,
+    )
+    agent = AGENT_TO_TEST(env)
     assert isinstance(agent, MineRLAgentBase)
     agent.load_agent()
-    assert MINERL_MAX_EVALUATION_EPISODES > 0
-    assert EVALUATION_THREAD_COUNT > 0
-    # Create the parallel envs (sequentially to prevent issues!)
-    envs = [gym.make(MINERL_GYM_ENV) for _ in range(EVALUATION_THREAD_COUNT)]
-    episodes_per_thread = [MINERL_MAX_EVALUATION_EPISODES // EVALUATION_THREAD_COUNT for _ in range(EVALUATION_THREAD_COUNT)]
-    episodes_per_thread[-1] += MINERL_MAX_EVALUATION_EPISODES - EVALUATION_THREAD_COUNT * (MINERL_MAX_EVALUATION_EPISODES // EVALUATION_THREAD_COUNT)
-    # A simple function to evaluate on episodes!
-    def evaluate(i, env):
+    for i in range(MINERL_MAX_EVALUATION_EPISODES):
-        print("[{}] Starting evaluator.".format(i))
-        for i in range(episodes_per_thread[i]):
-            try:
-                agent.run_agent_on_episode(Episode(env))
-            except EpisodeDone:
-                print("[{}] Episode complete".format(i))
-                pass
-    evaluator_threads = [threading.Thread(target=evaluate, args=(i, envs[i])) for i in range(EVALUATION_THREAD_COUNT)]
-    for thread in evaluator_threads:
-        thread.start()
-    # wait for the evaluation to finish
-    for thread in evaluator_threads:
-        thread.join()
+        try:
+            agent.run_agent_on_episode(Episode(env))
+        except EpisodeDone:
+            print("[{}] Episode complete".format(i))
+            pass
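+    # each iteration wraps the shared env in a fresh Episode; EpisodeDone signals normal episode termination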
 if __name__ == "__main__":
     main()