From 333ef2ca4990a5b9bfe5bac205019bac5c9d8e6e Mon Sep 17 00:00:00 2001
From: NAKATA Keisuke
Date: Sun, 2 Aug 2020 17:16:13 +0900
Subject: [PATCH] copy minerl2020-playground/src/mod

---
 mod/cached_kmeans.py |  63 ++++++++
 mod/dqn_family.py    | 255 +++++++++++++++++++++++++++++++
 mod/env_wrappers.py  | 346 +++++++++++++++++++++++++++++++++++++++++++
 mod/q_functions.py   |  89 +++++++++++
 mod/utils.py         |  10 ++
 5 files changed, 763 insertions(+)
 create mode 100644 mod/cached_kmeans.py
 create mode 100644 mod/dqn_family.py
 create mode 100644 mod/env_wrappers.py
 create mode 100644 mod/q_functions.py
 create mode 100644 mod/utils.py

diff --git a/mod/cached_kmeans.py b/mod/cached_kmeans.py
new file mode 100644
index 0000000..ee9ea18
--- /dev/null
+++ b/mod/cached_kmeans.py
@@ -0,0 +1,63 @@
+from logging import getLogger
+import os
+
+import tqdm
+import numpy as np
+from sklearn.cluster import KMeans
+import joblib
+import minerl
+
+logger = getLogger(__name__)
+
+
+class _KMeansCacheNotFound(FileNotFoundError):
+    pass
+
+
+def cached_kmeans(cache_dir, env_id, n_clusters, random_state):
+    if cache_dir is None:  # ignore cache
+        logger.info('Load dataset & do kmeans')
+        kmeans = _do_kmeans(env_id=env_id, n_clusters=n_clusters, random_state=random_state)
+    else:
+        filepath = os.path.join(cache_dir, env_id, f'n_clusters_{n_clusters}', f'random_state_{random_state}', 'kmeans.joblib')
+        try:
+            kmeans = _load_kmeans_result_cache(filepath)
+            logger.info('found kmeans cache')
+        except _KMeansCacheNotFound:
+            logger.info('kmeans cache not found. Load dataset & do kmeans & save result as cache')
+            kmeans = _do_kmeans(env_id=env_id, n_clusters=n_clusters, random_state=random_state)
+            _save_kmeans_result_cache(kmeans, filepath)
+    return kmeans
+
+
+def _do_kmeans(env_id, n_clusters, random_state):
+    logger.debug('loading data...')
+    dat = minerl.data.make(env_id)
+    act_vectors = []
+    for _, act, _, _, _ in tqdm.tqdm(dat.batch_iter(batch_size=16, seq_len=32, num_epochs=1, preload_buffer_size=32, seed=random_state)):
+        act_vectors.append(act['vector'])
+    acts = np.concatenate(act_vectors).reshape(-1, 64)
+    logger.debug('loading data... done.')
+    logger.debug('executing kmeans...')
+    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(acts)
+    logger.debug('executing kmeans... done.')
+    return kmeans
+
+
+# def _describe_kmeans_result(kmeans):
+#     result = [(obf_a, minerl.herobraine.envs.MINERL_TREECHOP_OBF_V0.unwrap_action({'vector': obf_a})) for obf_a in kmeans.cluster_centers_]
+#     logger.debug(result)
+#     return result
+
+
+def _save_kmeans_result_cache(kmeans, filepath):
+    os.makedirs(os.path.dirname(filepath), exist_ok=True)
+    joblib.dump(kmeans, filepath)
+    logger.info(f'saved kmeans {filepath}')
+
+
+def _load_kmeans_result_cache(filepath):
+    if not os.path.exists(filepath):
+        raise _KMeansCacheNotFound
+    logger.debug(f'loading kmeans {filepath}')
+    return joblib.load(filepath)
diff --git a/mod/dqn_family.py b/mod/dqn_family.py
new file mode 100644
index 0000000..73fed26
--- /dev/null
+++ b/mod/dqn_family.py
@@ -0,0 +1,255 @@
+import os
+import logging
+import argparse
+
+import numpy as np
+import torch
+import minerl  # noqa: register MineRL envs as Gym envs.
+import gym + +import pfrl + + +# local modules +import sys +sys.path.append(os.path.abspath(os.path.join(__file__, os.pardir))) +import utils +from env_wrappers import wrap_env +from q_functions import parse_arch +from cached_kmeans import cached_kmeans + +logger = logging.getLogger(__name__) + + +def main(): + parser = argparse.ArgumentParser() + + env_choices = [ + # basic envs + 'MineRLTreechop-v0', + 'MineRLNavigate-v0', 'MineRLNavigateDense-v0', 'MineRLNavigateExtreme-v0', 'MineRLNavigateExtremeDense-v0', + 'MineRLObtainIronPickaxe-v0', 'MineRLObtainIronPickaxeDense-v0', + 'MineRLObtainDiamond-v0', 'MineRLObtainDiamondDense-v0', + # obfuscated envs + 'MineRLTreechopVectorObf-v0', + 'MineRLNavigateVectorObf-v0', 'MineRLNavigateExtremeVectorObf-v0', + # MineRL data pipeline fails for these envs: https://github.com/minerllabs/minerl/issues/364 + # 'MineRLNavigateDenseVectorObf-v0', 'MineRLNavigateExtremeDenseVectorObf-v0', + 'MineRLObtainDiamondVectorObf-v0', 'MineRLObtainDiamondDenseVectorObf-v0', + 'MineRLObtainIronPickaxeVectorObf-v0', 'MineRLObtainIronPickaxeDenseVectorObf-v0', + # for debugging + 'MineRLNavigateDenseFixed-v0', 'MineRLObtainTest-v0', + ] + parser.add_argument('--env', type=str, choices=env_choices, required=True, + help='MineRL environment identifier.') + + # meta settings + parser.add_argument('--outdir', type=str, default='results', + help='Directory path to save output files. If it does not exist, it will be created.') + parser.add_argument('--seed', type=int, default=0, help='Random seed [0, 2 ** 31)') + parser.add_argument('--gpu', type=int, default=0, help='GPU to use, set to -1 if no GPU.') + parser.add_argument('--demo', action='store_true', default=False) + parser.add_argument('--load', type=str, default=None) + parser.add_argument('--logging-level', type=int, default=20, help='Logging level. 10:DEBUG, 20:INFO etc.') + parser.add_argument('--eval-n-runs', type=int, default=3) + parser.add_argument('--monitor', action='store_true', default=False, + help='Monitor env. 
Videos and additional information are saved as output files during evaluation.')
+
+    # training scheme (agent)
+    parser.add_argument('--agent', type=str, default='DQN', choices=['DQN', 'DoubleDQN', 'PAL', 'CategoricalDoubleDQN'])
+
+    # network architecture
+    parser.add_argument('--arch', type=str, default='dueling', choices=['dueling', 'distributed_dueling'],
+                        help='Network architecture to use.')
+
+    # update rule settings
+    parser.add_argument('--update-interval', type=int, default=4, help='Frequency (in timesteps) of network updates.')
+    parser.add_argument('--frame-skip', type=int, default=None, help='Number of frames skipped (None to disable).')
+    parser.add_argument('--gamma', type=float, default=0.99, help='Discount rate.')
+    parser.add_argument('--no-clip-delta', dest='clip_delta', action='store_false')
+    parser.set_defaults(clip_delta=True)
+    parser.add_argument('--num-step-return', type=int, default=1)
+    parser.add_argument('--lr', type=float, default=2.5e-4, help='Learning rate.')
+    parser.add_argument('--adam-eps', type=float, default=1e-8, help='Epsilon for Adam.')
+    parser.add_argument('--batch-accumulator', type=str, default='sum', choices=['sum', 'mean'], help='Accumulator for batch loss.')
+
+    # observation conversion related settings
+    parser.add_argument('--gray-scale', action='store_true', default=False, help='Convert pov into a grayscale image.')
+    parser.add_argument('--frame-stack', type=int, default=None, help='Number of frames stacked (None to disable).')
+
+    # exploration related settings
+    parser.add_argument('--final-exploration-frames', type=int, default=10 ** 6,
+                        help='Timesteps after which we stop annealing the exploration rate.')
+    parser.add_argument('--final-epsilon', type=float, default=0.01, help='Final value of epsilon during training.')
+    parser.add_argument('--eval-epsilon', type=float, default=0.001, help='Exploration epsilon used during eval episodes.')
+    parser.add_argument('--noisy-net-sigma', type=float, default=None,
+                        help='NoisyNet explorer switch. This disables the following options: '
+                             '--final-exploration-frames, --final-epsilon, --eval-epsilon')
+
+    # experience replay buffer related settings
+    parser.add_argument('--replay-capacity', type=int, default=10 ** 6, help='Maximum capacity for replay buffer.')
+    parser.add_argument('--replay-start-size', type=int, default=5 * 10 ** 4,
+                        help='Minimum replay buffer size before performing gradient updates.')
+    parser.add_argument('--prioritized', action='store_true', default=False, help='Use prioritized experience replay.')
+
+    # target network related settings
+    parser.add_argument('--target-update-interval', type=int, default=3 * 10 ** 4,
+                        help='Frequency (in timesteps) at which the target network is updated.')
+
+    # K-means related settings
+    parser.add_argument('--kmeans-n-clusters', type=int, default=30, help='#clusters for K-means')
+
+    args = parser.parse_args()
+
+    args.outdir = pfrl.experiments.prepare_output_dir(args, args.outdir)
+
+    log_format = '%(levelname)-8s - %(asctime)s - [%(name)s %(funcName)s %(lineno)d] %(message)s'
+    logging.basicConfig(filename=os.path.join(args.outdir, 'log.txt'), format=log_format, level=args.logging_level)
+    console_handler = logging.StreamHandler()
+    console_handler.setLevel(args.logging_level)
+    console_handler.setFormatter(logging.Formatter(log_format))
+    logging.getLogger('').addHandler(console_handler)  # add handler to the root logger
+
+    logger.info('Output files will be saved in {}'.format(args.outdir))
+
+    utils.log_versions()
+
+    try:
+        _main(args)
+    except:  # noqa
+        logger.exception('execution failed.')
+        raise
+
+
+def _main(args):
+    os.environ['MALMO_MINECRAFT_OUTPUT_LOGDIR'] = args.outdir
+
+    # Set a random seed used in PFRL.
+    pfrl.utils.set_random_seed(args.seed)
+
+    # Set different random seeds for train and test envs.
+    train_seed = args.seed  # noqa: never used in this script
+    test_seed = 2 ** 31 - 1 - args.seed
+
+    # K-Means
+    kmeans = cached_kmeans(
+        cache_dir=os.environ.get('KMEANS_CACHE'),
+        env_id=args.env,
+        n_clusters=args.kmeans_n_clusters,
+        random_state=args.seed)
+
+    # create & wrap env
+    def wrap_env_partial(env, test):
+        randomize_action = test and args.noisy_net_sigma is None
+        wrapped_env = wrap_env(
+            env=env, test=test,
+            env_id=args.env,
+            monitor=args.monitor, outdir=args.outdir,
+            frame_skip=args.frame_skip,
+            gray_scale=args.gray_scale, frame_stack=args.frame_stack,
+            randomize_action=randomize_action, eval_epsilon=args.eval_epsilon,
+            action_choices=kmeans.cluster_centers_)
+        return wrapped_env
+    logger.info('The first `gym.make(MineRL*)` may take several minutes. Be patient!')
+    core_env = gym.make(args.env)
+    # training env
+    env = wrap_env_partial(env=core_env, test=False)
+    # env.seed(int(train_seed))  # TODO: not supported yet
+    # evaluation env
+    eval_env = wrap_env_partial(env=core_env, test=True)
+    # env.seed(int(test_seed))  # TODO: not supported yet (also requires `core_eval_env = gym.make(args.env)`)
+
+    # calculate corresponding `steps` and `eval_interval` according to frameskip
+    # 8,000,000 frames = 1333 episodes if we count an episode as 6000 frames,
+    # 8,000,000 frames = 1000 episodes if we count an episode as 8000 frames.
+    maximum_frames = 8000000
+    if args.frame_skip is None:
+        steps = maximum_frames
+        eval_interval = 6000 * 100  # (approx.) every 100 episode (counts "1 episode = 6000 steps")
+    else:
+        steps = maximum_frames // args.frame_skip
+        eval_interval = 6000 * 100 // args.frame_skip  # (approx.) 
every 100 episode (counts "1 episode = 6000 steps") + + agent = get_agent( + n_actions=env.action_space.n, arch=args.arch, n_input_channels=env.observation_space.shape[0], + noisy_net_sigma=args.noisy_net_sigma, final_epsilon=args.final_epsilon, + final_exploration_frames=args.final_exploration_frames, explorer_sample_func=env.action_space.sample, + lr=args.lr, adam_eps=args.adam_eps, + prioritized=args.prioritized, steps=steps, update_interval=args.update_interval, + replay_capacity=args.replay_capacity, num_step_return=args.num_step_return, + agent_type=args.agent, gpu=args.gpu, gamma=args.gamma, replay_start_size=args.replay_start_size, + target_update_interval=args.target_update_interval, clip_delta=args.clip_delta, + batch_accumulator=args.batch_accumulator + ) + + if args.load: + agent.load(args.load) + + # experiment + if args.demo: + eval_stats = pfrl.experiments.eval_performance(env=eval_env, agent=agent, n_steps=None, n_episodes=args.eval_n_runs) + logger.info('n_runs: {} mean: {} median: {} stdev {}'.format( + args.eval_n_runs, eval_stats['mean'], eval_stats['median'], eval_stats['stdev'])) + else: + pfrl.experiments.train_agent_with_evaluation( + agent=agent, env=env, steps=steps, + eval_n_steps=None, eval_n_episodes=args.eval_n_runs, eval_interval=eval_interval, + outdir=args.outdir, eval_env=eval_env, save_best_so_far_agent=True, + ) + + env.close() + eval_env.close() + + +def parse_agent(agent): + return {'DQN': pfrl.agents.DQN, + 'DoubleDQN': pfrl.agents.DoubleDQN, + 'PAL': pfrl.agents.PAL, + 'CategoricalDoubleDQN': pfrl.agents.CategoricalDoubleDQN}[agent] + + +def get_agent( + n_actions, arch, n_input_channels, + noisy_net_sigma, final_epsilon, final_exploration_frames, explorer_sample_func, + lr, adam_eps, + prioritized, steps, update_interval, replay_capacity, num_step_return, + agent_type, gpu, gamma, replay_start_size, target_update_interval, clip_delta, batch_accumulator +): + # Q function + q_func = parse_arch(arch, n_actions, n_input_channels=n_input_channels) + + # explorer + if noisy_net_sigma is not None: + pfrl.nn.to_factorized_noisy(q_func, sigma_scale=noisy_net_sigma) + # Turn off explorer + explorer = pfrl.explorers.Greedy() + else: + explorer = pfrl.explorers.LinearDecayEpsilonGreedy( + 1.0, final_epsilon, final_exploration_frames, explorer_sample_func) + + opt = torch.optim.Adam(q_func.parameters(), lr, eps=adam_eps) # NOTE: mirrors DQN implementation in MineRL paper + + # Select a replay buffer to use + if prioritized: + # Anneal beta from beta0 to 1 throughout training + betasteps = steps / update_interval + rbuf = pfrl.replay_buffers.PrioritizedReplayBuffer( + replay_capacity, alpha=0.5, beta0=0.4, betasteps=betasteps, num_steps=num_step_return) + else: + rbuf = pfrl.replay_buffers.ReplayBuffer(replay_capacity, num_step_return) + + # build agent + def phi(x): + # observation -> NN input + return np.asarray(x) + Agent = parse_agent(agent_type) + agent = Agent( + q_func, opt, rbuf, gpu=gpu, gamma=gamma, explorer=explorer, replay_start_size=replay_start_size, + target_update_interval=target_update_interval, clip_delta=clip_delta, update_interval=update_interval, + batch_accumulator=batch_accumulator, phi=phi) + + return agent + + +if __name__ == '__main__': + main() diff --git a/mod/env_wrappers.py b/mod/env_wrappers.py new file mode 100644 index 0000000..571ef64 --- /dev/null +++ b/mod/env_wrappers.py @@ -0,0 +1,346 @@ +import copy +from logging import getLogger +from collections import deque +import os + +import gym +import numpy as np +import cv2 + +from 
pfrl.wrappers import ContinuingTimeLimit, RandomizeAction, Monitor +from pfrl.wrappers.atari_wrappers import ScaledFloatFrame, LazyFrames + +cv2.ocl.setUseOpenCL(False) +logger = getLogger(__name__) + + +def wrap_env( + env, test, + env_id, + monitor, outdir, + frame_skip, + gray_scale, frame_stack, + randomize_action, eval_epsilon, + action_choices): + # wrap env: time limit... + if isinstance(env, gym.wrappers.TimeLimit): + logger.info('Detected `gym.wrappers.TimeLimit`! Unwrap it and re-wrap our own time limit.') + env = env.env + max_episode_steps = env.spec.max_episode_steps + env = ContinuingTimeLimit(env, max_episode_steps=max_episode_steps) + + # wrap env: observation... + # NOTE: wrapping order matters! + + if test and monitor: + env = Monitor( + env, os.path.join(outdir, env.spec.id, 'monitor'), + mode='evaluation' if test else 'training', video_callable=lambda episode_id: True) + if frame_skip is not None: + env = FrameSkip(env, skip=frame_skip) + if gray_scale: + env = GrayScaleWrapper(env, dict_space_key='pov') + env = ObtainPoVWrapper(env) + env = MoveAxisWrapper(env, source=-1, destination=0) # convert hwc -> chw as Pytorch requires. + env = ScaledFloatFrame(env) + if frame_stack is not None and frame_stack > 0: + env = FrameStack(env, frame_stack, channel_order='chw') + + env = ClusteredActionWrapper(env, clusters=action_choices) + + if randomize_action: + env = RandomizeAction(env, eval_epsilon) + + return env + + +class FrameSkip(gym.Wrapper): + """Return every `skip`-th frame and repeat given action during skip. + + Note that this wrapper does not "maximize" over the skipped frames. + """ + def __init__(self, env, skip=4): + super().__init__(env) + + self._skip = skip + + def step(self, action): + total_reward = 0.0 + for _ in range(self._skip): + obs, reward, done, info = self.env.step(action) + total_reward += reward + if done: + break + return obs, total_reward, done, info + + +class FrameStack(gym.Wrapper): + def __init__(self, env, k, channel_order='hwc', use_tuple=False): + """Stack k last frames. + + Returns lazy array, which is much more memory efficient. 
+ """ + gym.Wrapper.__init__(self, env) + self.k = k + self.observations = deque([], maxlen=k) + self.stack_axis = {'hwc': 2, 'chw': 0}[channel_order] + self.use_tuple = use_tuple + + if self.use_tuple: + pov_space = env.observation_space[0] + inv_space = env.observation_space[1] + else: + pov_space = env.observation_space + + low_pov = np.repeat(pov_space.low, k, axis=self.stack_axis) + high_pov = np.repeat(pov_space.high, k, axis=self.stack_axis) + pov_space = gym.spaces.Box(low=low_pov, high=high_pov, dtype=pov_space.dtype) + + if self.use_tuple: + low_inv = np.repeat(inv_space.low, k, axis=0) + high_inv = np.repeat(inv_space.high, k, axis=0) + inv_space = gym.spaces.Box(low=low_inv, high=high_inv, dtype=inv_space.dtype) + self.observation_space = gym.spaces.Tuple( + (pov_space, inv_space)) + else: + self.observation_space = pov_space + + def reset(self): + ob = self.env.reset() + for _ in range(self.k): + self.observations.append(ob) + return self._get_ob() + + def step(self, action): + ob, reward, done, info = self.env.step(action) + self.observations.append(ob) + return self._get_ob(), reward, done, info + + def _get_ob(self): + assert len(self.observations) == self.k + if self.use_tuple: + frames = [x[0] for x in self.observations] + inventory = [x[1] for x in self.observations] + return (LazyFrames(list(frames), stack_axis=self.stack_axis), + LazyFrames(list(inventory), stack_axis=0)) + else: + return LazyFrames(list(self.observations), stack_axis=self.stack_axis) + + +class ObtainPoVWrapper(gym.ObservationWrapper): + """Obtain 'pov' value (current game display) of the original observation.""" + def __init__(self, env): + super().__init__(env) + + self.observation_space = self.env.observation_space.spaces['pov'] + + def observation(self, observation): + return observation['pov'] + + +class UnifiedObservationWrapper(gym.ObservationWrapper): + """Take 'pov', 'compassAngle', 'inventory' and concatenate with scaling. + Each element of 'inventory' is converted to a square whose side length is region_size. + The color of each square is correlated to the reciprocal of (the number of the corresponding item + 1). 
+ """ + def __init__(self, env, region_size=8): + super().__init__(env) + + self._compass_angle_scale = 180 / 255 # NOTE: `ScaledFloatFrame` will scale the pixel values with 255.0 later + self.region_size = region_size + + pov_space = self.env.observation_space.spaces['pov'] + low_dict = {'pov': pov_space.low} + high_dict = {'pov': pov_space.high} + + if 'compassAngle' in self.env.observation_space.spaces: + compass_angle_space = self.env.observation_space.spaces['compassAngle'] + low_dict['compassAngle'] = compass_angle_space.low + high_dict['compassAngle'] = compass_angle_space.high + + if 'inventory' in self.env.observation_space.spaces: + inventory_space = self.env.observation_space.spaces['inventory'] + low_dict['inventory'] = {} + high_dict['inventory'] = {} + for key in inventory_space.spaces.keys(): + low_dict['inventory'][key] = inventory_space.spaces[key].low + high_dict['inventory'][key] = inventory_space.spaces[key].high + + low = self.observation(low_dict) + high = self.observation(high_dict) + + self.observation_space = gym.spaces.Box(low=low, high=high) + + def observation(self, observation): + obs = observation['pov'] + pov_dtype = obs.dtype + + if 'compassAngle' in observation: + compass_scaled = observation['compassAngle'] / self._compass_angle_scale + compass_channel = np.ones(shape=list(obs.shape[:-1]) + [1], dtype=pov_dtype) * compass_scaled + obs = np.concatenate([obs, compass_channel], axis=-1) + if 'inventory' in observation: + assert len(obs.shape[:-1]) == 2 + region_max_height = obs.shape[0] + region_max_width = obs.shape[1] + rs = self.region_size + if min(region_max_height, region_max_width) < rs: + raise ValueError("'region_size' is too large.") + num_element_width = region_max_width // rs + inventory_channel = np.zeros(shape=list(obs.shape[:-1]) + [1], dtype=pov_dtype) + for idx, key in enumerate(observation['inventory']): + item_scaled = np.clip(255 - 255 / (observation['inventory'][key] + 1), # Inversed + 0, 255) + item_channel = np.ones(shape=[rs, rs, 1], dtype=pov_dtype) * item_scaled + width_low = (idx % num_element_width) * rs + height_low = (idx // num_element_width) * rs + if height_low + rs > region_max_height: + raise ValueError("Too many elements on 'inventory'. Please decrease 'region_size' of each component") + inventory_channel[height_low:(height_low + rs), width_low:(width_low + rs), :] = item_channel + obs = np.concatenate([obs, inventory_channel], axis=-1) + return obs + + +class FullObservationSpaceWrapper(gym.ObservationWrapper): + """Returns as observation a tuple with the frames and a list of + compassAngle and inventory items. 
+ compassAngle is scaled to be in the interval [-1, 1] and inventory items + are scaled to be in the interval [0, 1] + """ + def __init__(self, env): + super().__init__(env) + + pov_space = self.env.observation_space.spaces['pov'] + + low_dict = {'pov': pov_space.low, 'inventory': {}} + high_dict = {'pov': pov_space.high, 'inventory': {}} + + for obs_name in self.env.observation_space.spaces['inventory'].spaces.keys(): + obs_space = self.env.observation_space.spaces['inventory'].spaces[obs_name] + low_dict['inventory'][obs_name] = obs_space.low + high_dict['inventory'][obs_name] = obs_space.high + + if 'compassAngle' in self.env.observation_space.spaces: + compass_angle_space = self.env.observation_space.spaces['compassAngle'] + low_dict['compassAngle'] = compass_angle_space.low + high_dict['compassAngle'] = compass_angle_space.high + + low = self.observation(low_dict) + high = self.observation(high_dict) + + pov_space = gym.spaces.Box(low=low[0], high=high[0]) + inventory_space = gym.spaces.Box(low=low[1], high=high[1]) + self.observation_space = gym.spaces.Tuple((pov_space, inventory_space)) + + def observation(self, observation): + frame = observation['pov'] + inventory = [] + + if 'compassAngle' in observation: + compass_scaled = observation['compassAngle'] / 180 + inventory.append(compass_scaled) + + for obs_name in observation['inventory'].keys(): + inventory.append(observation['inventory'][obs_name] / 2304) + + inventory = np.array(inventory) + return (frame, inventory) + + +class MoveAxisWrapper(gym.ObservationWrapper): + """Move axes of observation ndarrays.""" + def __init__(self, env, source, destination, use_tuple=False): + if use_tuple: + assert isinstance(env.observation_space[0], gym.spaces.Box) + else: + assert isinstance(env.observation_space, gym.spaces.Box) + super().__init__(env) + + self.source = source + self.destination = destination + self.use_tuple = use_tuple + + if self.use_tuple: + low = self.observation( + tuple([space.low for space in self.observation_space])) + high = self.observation( + tuple([space.high for space in self.observation_space])) + dtype = self.observation_space[0].dtype + pov_space = gym.spaces.Box(low=low[0], high=high[0], dtype=dtype) + inventory_space = self.observation_space[1] + self.observation_space = gym.spaces.Tuple( + (pov_space, inventory_space)) + else: + low = self.observation(self.observation_space.low) + high = self.observation(self.observation_space.high) + dtype = self.observation_space.dtype + self.observation_space = gym.spaces.Box( + low=low, high=high, dtype=dtype) + + def observation(self, observation): + if self.use_tuple: + new_observation = list(observation) + new_observation[0] = np.moveaxis( + observation[0], self.source, self.destination) + return tuple(new_observation) + else: + return np.moveaxis(observation, self.source, self.destination) + + +class GrayScaleWrapper(gym.ObservationWrapper): + def __init__(self, env, dict_space_key=None): + super().__init__(env) + + self._key = dict_space_key + + if self._key is None: + original_space = self.observation_space + else: + original_space = self.observation_space.spaces[self._key] + height, width = original_space.shape[0], original_space.shape[1] + + # sanity checks + ideal_image_space = gym.spaces.Box(low=0, high=255, shape=(height, width, 3), dtype=np.uint8) + if original_space != ideal_image_space: + raise ValueError('Image space should be {}, but given {}.'.format(ideal_image_space, original_space)) + if original_space.dtype != np.uint8: + raise ValueError('Image 
should be `np.uint8` typed, but given {}.'.format(original_space.dtype))
+
+        height, width = original_space.shape[0], original_space.shape[1]
+        new_space = gym.spaces.Box(low=0, high=255, shape=(height, width, 1), dtype=np.uint8)
+        if self._key is None:
+            self.observation_space = new_space
+        else:
+            new_space_dict = copy.deepcopy(self.observation_space)
+            new_space_dict.spaces[self._key] = new_space
+            self.observation_space = new_space_dict
+
+    def observation(self, obs):
+        if self._key is None:
+            frame = obs
+        else:
+            frame = obs[self._key]
+        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+        frame = np.expand_dims(frame, -1)
+        if self._key is None:
+            obs = frame
+        else:
+            obs[self._key] = frame
+        return obs
+
+
+class ClusteredActionWrapper(gym.ActionWrapper):
+    def __init__(self, env, clusters):
+        super().__init__(env)
+        self._clusters = clusters
+
+        self._np_random = np.random.RandomState()
+
+        self.action_space = gym.spaces.Discrete(len(clusters))
+
+    def action(self, action):
+        return {'vector': self._clusters[action]}
+
+    def seed(self, seed):
+        super().seed(seed)
+        self._np_random.seed(seed)
diff --git a/mod/q_functions.py b/mod/q_functions.py
new file mode 100644
index 0000000..85e662d
--- /dev/null
+++ b/mod/q_functions.py
@@ -0,0 +1,89 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from pfrl import action_value
+from pfrl.q_function import StateQFunction
+from pfrl.q_functions.dueling_dqn import constant_bias_initializer
+from pfrl.initializers import init_chainer_default
+
+
+def parse_arch(arch, n_actions, n_input_channels):
+    if arch == 'dueling':
+        # Conv2Ds of (channel, kernel, stride): [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
+        # return DuelingDQN(n_actions, n_input_channels=n_input_channels, hiddens=[256])
+        raise NotImplementedError('dueling')
+    elif arch == 'distributed_dueling':
+        n_atoms = 51
+        v_min = -10
+        v_max = 10
+        return DistributionalDuelingDQN(n_actions, n_atoms, v_min, v_max, n_input_channels=n_input_channels)
+    else:
+        raise RuntimeError('Unsupported architecture name: {}'.format(arch))
+
+
+class DistributionalDuelingDQN(nn.Module, StateQFunction):
+    """Distributional dueling fully-connected Q-function with discrete actions."""
+
+    def __init__(
+        self,
+        n_actions,
+        n_atoms,
+        v_min,
+        v_max,
+        n_input_channels=4,
+        activation=torch.relu,
+        bias=0.1,
+    ):
+        assert n_atoms >= 2
+        assert v_min < v_max
+
+        self.n_actions = n_actions
+        self.n_input_channels = n_input_channels
+        self.activation = activation
+        self.n_atoms = n_atoms
+
+        super().__init__()
+        self.z_values = torch.linspace(v_min, v_max, n_atoms, dtype=torch.float32)
+
+        self.conv_layers = nn.ModuleList(
+            [
+                nn.Conv2d(n_input_channels, 32, 8, stride=4),
+                nn.Conv2d(32, 64, 4, stride=2),
+                nn.Conv2d(64, 64, 3, stride=1),
+            ]
+        )
+
+        # NOTE: this is the only layer that had to change from the Atari-sized original:
+        # 64x64 MineRL frames give 64 x 4 x 4 = 1024 conv features instead of 3136 (84x84 frames).
+        # self.main_stream = nn.Linear(3136, 1024)
+        self.main_stream = nn.Linear(1024, 1024)
+        self.a_stream = nn.Linear(512, n_actions * n_atoms)
+        self.v_stream = nn.Linear(512, n_atoms)
+
+        self.apply(init_chainer_default)
+        self.conv_layers.apply(constant_bias_initializer(bias=bias))
+
+    def forward(self, x):
+        h = x
+        for l in self.conv_layers:
+            h = self.activation(l(h))
+
+        # Advantage
+        batch_size = x.shape[0]
+
+        h = self.activation(self.main_stream(h.view(batch_size, -1)))
+        h_a, h_v = torch.chunk(h, 2, dim=1)
+        ya = self.a_stream(h_a).reshape((batch_size, self.n_actions, self.n_atoms))
+
+        mean = ya.sum(dim=1, keepdim=True) / self.n_actions
+
+        ya, mean = torch.broadcast_tensors(ya, mean)
+        ya -= mean
+
+        # State value
+        ys = self.v_stream(h_v).reshape((batch_size, 1, self.n_atoms))
+        ya, ys = torch.broadcast_tensors(ya, ys)
+        q = F.softmax(ya + ys, dim=2)
+
+        self.z_values = self.z_values.to(x.device)
+        return action_value.DistributionalDiscreteActionValue(q, self.z_values)
diff --git a/mod/utils.py b/mod/utils.py
new file mode 100644
index 0000000..65d5958
--- /dev/null
+++ b/mod/utils.py
@@ -0,0 +1,10 @@
+import sys
+from pip._internal.operations import freeze
+from logging import getLogger
+
+logger = getLogger(__name__)
+
+
+def log_versions():
+    logger.info(sys.version)  # Python version
+    logger.info(','.join(freeze.freeze()))  # pip freeze
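
Note (not part of the patch): the modules above fit together roughly as in the sketch below, which mirrors what dqn_family._main() does before building the agent. It is a hypothetical smoke test: it assumes mod/ is on the Python path, that the MineRL dataset for the chosen env has been downloaded (cached_kmeans needs it when no cache is hit), and that a working MineRL/Malmo installation can launch the environment. frame_skip=4 and frame_stack=4 are illustrative values rather than the script defaults (which are None); n_clusters=30 and eval_epsilon=0.001 match the defaults in dqn_family.py.

    # Hypothetical wiring check (not part of the patch).
    import os

    import gym
    import minerl  # noqa: registers MineRL envs as Gym envs
    import numpy as np

    from cached_kmeans import cached_kmeans
    from env_wrappers import wrap_env

    env_id = 'MineRLObtainDiamondVectorObf-v0'

    # Discretize the 64-dim obfuscated action space into 30 k-means centroids,
    # re-using $KMEANS_CACHE when it is set and already populated.
    kmeans = cached_kmeans(cache_dir=os.environ.get('KMEANS_CACHE'),
                           env_id=env_id, n_clusters=30, random_state=0)

    env = wrap_env(
        env=gym.make(env_id), test=False, env_id=env_id,
        monitor=False, outdir='results',
        frame_skip=4, gray_scale=False, frame_stack=4,
        randomize_action=False, eval_epsilon=0.001,
        action_choices=kmeans.cluster_centers_)

    print(env.action_space)       # Discrete(30): one action per cluster centre
    obs = env.reset()
    print(np.asarray(obs).shape)  # (12, 64, 64): 4 stacked, scaled CHW RGB frames
    obs, reward, done, info = env.step(env.action_space.sample())
    env.close()

The agent therefore only ever sees the Discrete(30) space; ClusteredActionWrapper maps the chosen index back to the corresponding obfuscated action vector at step time.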
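
The comment in q_functions.py notes that main_stream was the only layer that needed to change relative to the Atari-sized network. A hypothetical shape check (again not part of the patch, assuming 4 stacked RGB frames, i.e. 12 input channels) makes the arithmetic concrete; n_actions=30, n_atoms=51 and v_min/v_max of -10/10 match the --kmeans-n-clusters default and the 'distributed_dueling' branch of parse_arch.

    # Hypothetical shape check (not part of the patch).
    import torch

    from q_functions import DistributionalDuelingDQN

    q_func = DistributionalDuelingDQN(n_actions=30, n_atoms=51, v_min=-10, v_max=10,
                                      n_input_channels=12)  # 4 stacked RGB frames
    x = torch.zeros(2, 12, 64, 64)  # (batch, channels, height, width)

    # Conv stack on a 64x64 input: 64 -> 15 -> 6 -> 4, so the flattened feature size is
    # 64 * 4 * 4 = 1024, hence main_stream = nn.Linear(1024, 1024) instead of Atari's 3136.
    av = q_func(x)
    print(av.q_values.shape)        # torch.Size([2, 30])
    print(av.greedy_actions.shape)  # torch.Size([2])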