Commit 333ef2ca authored by pfrl_rainbow

copy minerl2020-playground/src/mod

parent 0d88ef17
from logging import getLogger
import os
import tqdm
import numpy as np
from sklearn.cluster import KMeans
import joblib
import minerl
logger = getLogger(__name__)
class _KMeansCacheNotFound(FileNotFoundError):
pass
def cached_kmeans(cache_dir, env_id, n_clusters, random_state):
if cache_dir is None: # ignore cache
logger.info('Load dataset & do kmeans')
kmeans = _do_kmeans(env_id=env_id, n_clusters=n_clusters, random_state=random_state)
else:
filepath = os.path.join(cache_dir, env_id, f'n_clusters_{n_clusters}', f'random_state_{random_state}', 'kmeans.joblib')
try:
kmeans = _load_kmeans_result_cache(filepath)
logger.info('found kmeans cache')
except _KMeansCacheNotFound:
logger.info('kmeans cache not found. Load dataset & do kmeans & save result as cache')
kmeans = _do_kmeans(env_id=env_id, n_clusters=n_clusters, random_state=random_state)
_save_kmeans_result_cache(kmeans, filepath)
return kmeans
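# NOTE (illustration): `_do_kmeans` below streams demonstration data for `env_id` via
# `minerl.data.make`, collects the 64-dimensional obfuscated action vectors, and fits
# scikit-learn KMeans on them. The resulting `cluster_centers_` later serve as the
# discrete action set for the agent (see `ClusteredActionWrapper` in env_wrappers).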
def _do_kmeans(env_id, n_clusters, random_state):
logger.debug('loading data...')
dat = minerl.data.make(env_id)
act_vectors = []
for _, act, _, _, _ in tqdm.tqdm(dat.batch_iter(batch_size=16, seq_len=32, num_epochs=1, preload_buffer_size=32, seed=random_state)):
act_vectors.append(act['vector'])
acts = np.concatenate(act_vectors).reshape(-1, 64)
logger.debug('loading data... done.')
logger.debug('executing kmeans...')
kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(acts)
logger.debug('executing kmeans... done.')
return kmeans
# def _describe_kmeans_result(kmeans):
# result = [(obf_a, minerl.herobraine.envs.MINERL_TREECHOP_OBF_V0.unwrap_action({'vector': obf_a})) for obf_a in kmeans.cluster_centers_]
# logger.debug(result)
# return result
def _save_kmeans_result_cache(kmeans, filepath):
os.makedirs(os.path.dirname(filepath), exist_ok=True)
joblib.dump(kmeans, filepath)
logger.info(f'saved kmeans {filepath}')
def _load_kmeans_result_cache(filepath):
if not os.path.exists(filepath):
raise _KMeansCacheNotFound
logger.debug(f'loading kmeans {filepath}')
return joblib.load(filepath)
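# Usage sketch (illustrative values, not part of this module):
#
#     kmeans = cached_kmeans(
#         cache_dir=os.environ.get('KMEANS_CACHE'),  # None disables caching
#         env_id='MineRLObtainDiamondDenseVectorObf-v0',
#         n_clusters=30,
#         random_state=0)
#     action_choices = kmeans.cluster_centers_  # shape: (n_clusters, 64)
#
# With caching enabled, the result is stored under
# <cache_dir>/<env_id>/n_clusters_<n>/random_state_<seed>/kmeans.joblib.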
import os
import logging
import argparse
import numpy as np
import torch
import minerl # noqa: register MineRL envs as Gym envs.
import gym
import pfrl
# local modules
import sys
sys.path.append(os.path.abspath(os.path.join(__file__, os.pardir)))
import utils
from env_wrappers import wrap_env
from q_functions import parse_arch
from cached_kmeans import cached_kmeans
logger = logging.getLogger(__name__)
def main():
parser = argparse.ArgumentParser()
env_choices = [
# basic envs
'MineRLTreechop-v0',
'MineRLNavigate-v0', 'MineRLNavigateDense-v0', 'MineRLNavigateExtreme-v0', 'MineRLNavigateExtremeDense-v0',
'MineRLObtainIronPickaxe-v0', 'MineRLObtainIronPickaxeDense-v0',
'MineRLObtainDiamond-v0', 'MineRLObtainDiamondDense-v0',
# obfuscated envs
'MineRLTreechopVectorObf-v0',
'MineRLNavigateVectorObf-v0', 'MineRLNavigateExtremeVectorObf-v0',
# MineRL data pipeline fails for these envs: https://github.com/minerllabs/minerl/issues/364
# 'MineRLNavigateDenseVectorObf-v0', 'MineRLNavigateExtremeDenseVectorObf-v0',
'MineRLObtainDiamondVectorObf-v0', 'MineRLObtainDiamondDenseVectorObf-v0',
'MineRLObtainIronPickaxeVectorObf-v0', 'MineRLObtainIronPickaxeDenseVectorObf-v0',
# for debugging
'MineRLNavigateDenseFixed-v0', 'MineRLObtainTest-v0',
]
parser.add_argument('--env', type=str, choices=env_choices, required=True,
help='MineRL environment identifier.')
# meta settings
parser.add_argument('--outdir', type=str, default='results',
help='Directory path to save output files. If it does not exist, it will be created.')
parser.add_argument('--seed', type=int, default=0, help='Random seed [0, 2 ** 31)')
parser.add_argument('--gpu', type=int, default=0, help='GPU to use, set to -1 if no GPU.')
parser.add_argument('--demo', action='store_true', default=False)
parser.add_argument('--load', type=str, default=None)
parser.add_argument('--logging-level', type=int, default=20, help='Logging level. 10:DEBUG, 20:INFO etc.')
parser.add_argument('--eval-n-runs', type=int, default=3)
parser.add_argument('--monitor', action='store_true', default=False,
help='Monitor env. Videos and additional information are saved as output files during evaluation.')
# training scheme (agent)
parser.add_argument('--agent', type=str, default='DQN', choices=['DQN', 'DoubleDQN', 'PAL', 'CategoricalDoubleDQN'])
# network architecture
parser.add_argument('--arch', type=str, default='dueling', choices=['dueling', 'distributed_dueling'],
help='Network architecture to use.')
# update rule settings
parser.add_argument('--update-interval', type=int, default=4, help='Frequency (in timesteps) of network updates.')
parser.add_argument('--frame-skip', type=int, default=None, help='Number of frames skipped (None to disable).')
parser.add_argument('--gamma', type=float, default=0.99, help='Discount rate.')
parser.add_argument('--no-clip-delta', dest='clip_delta', action='store_false')
parser.set_defaults(clip_delta=True)
parser.add_argument('--num-step-return', type=int, default=1)
parser.add_argument('--lr', type=float, default=2.5e-4, help='Learning rate.')
parser.add_argument('--adam-eps', type=float, default=1e-8, help='Epsilon for Adam.')
parser.add_argument('--batch-accumulator', type=str, default='sum', choices=['sum', 'mean'], help='accumulator for batch loss.')
# observation conversion related settings
parser.add_argument('--gray-scale', action='store_true', default=False, help='Convert pov into gray scaled image.')
parser.add_argument('--frame-stack', type=int, default=None, help='Number of frames stacked (None to disable).')
# exploration related settings
parser.add_argument('--final-exploration-frames', type=int, default=10 ** 6,
help='Timesteps after which we stop annealing exploration rate')
parser.add_argument('--final-epsilon', type=float, default=0.01, help='Final value of epsilon during training.')
parser.add_argument('--eval-epsilon', type=float, default=0.001, help='Exploration epsilon used during eval episodes.')
parser.add_argument('--noisy-net-sigma', type=float, default=None,
help='NoisyNet explorer switch. This disables the following options: '
'--final-exploration-frames, --final-epsilon, --eval-epsilon')
# experience replay buffer related settings
parser.add_argument('--replay-capacity', type=int, default=10 ** 6, help='Maximum capacity for replay buffer.')
parser.add_argument('--replay-start-size', type=int, default=5 * 10 ** 4,
help='Minimum replay buffer size before performing gradient updates.')
parser.add_argument('--prioritized', action='store_true', default=False, help='Use prioritized experience replay.')
# target network related settings
parser.add_argument('--target-update-interval', type=int, default=3 * 10 ** 4,
help='Frequency (in timesteps) at which the target network is updated.')
# K-means related settings
parser.add_argument('--kmeans-n-clusters', type=int, default=30, help='#clusters for K-means')
args = parser.parse_args()
args.outdir = pfrl.experiments.prepare_output_dir(args, args.outdir)
log_format = '%(levelname)-8s - %(asctime)s - [%(name)s %(funcName)s %(lineno)d] %(message)s'
logging.basicConfig(filename=os.path.join(args.outdir, 'log.txt'), format=log_format, level=args.logging_level)
console_handler = logging.StreamHandler()
console_handler.setLevel(args.logging_level)
console_handler.setFormatter(logging.Formatter(log_format))
logging.getLogger('').addHandler(console_handler) # add handler to the root logger
logger.info('Output files will be saved in {}'.format(args.outdir))
utils.log_versions()
try:
_main(args)
except: # noqa
logger.exception('execution failed.')
raise
def _main(args):
os.environ['MALMO_MINECRAFT_OUTPUT_LOGDIR'] = args.outdir
# Set a random seed used in PFRL.
pfrl.utils.set_random_seed(args.seed)
# Set different random seeds for train and test envs.
train_seed = args.seed # noqa: never used in this script
test_seed = 2 ** 31 - 1 - args.seed
# K-Means
kmeans = cached_kmeans(
cache_dir=os.environ.get('KMEANS_CACHE'),
env_id=args.env,
n_clusters=args.kmeans_n_clusters,
random_state=args.seed)
# create & wrap env
def wrap_env_partial(env, test):
randomize_action = test and args.noisy_net_sigma is None
wrapped_env = wrap_env(
env=env, test=test,
env_id=args.env,
monitor=args.monitor, outdir=args.outdir,
frame_skip=args.frame_skip,
gray_scale=args.gray_scale, frame_stack=args.frame_stack,
randomize_action=randomize_action, eval_epsilon=args.eval_epsilon,
action_choices=kmeans.cluster_centers_)
return wrapped_env
logger.info('The first `gym.make(MineRL*)` may take several minutes. Be patient!')
core_env = gym.make(args.env)
# training env
env = wrap_env_partial(env=core_env, test=False)
# env.seed(int(train_seed)) # TODO: not supported yet
# evaluation env
eval_env = wrap_env_partial(env=core_env, test=True)
# env.seed(int(test_seed)) # TODO: not supported yet (also requires `core_eval_env = gym.make(args.env)`)
# calculate corresponding `steps` and `eval_interval` according to frameskip
# 8,000,000 frames = 1333 episodes if we count an episode as 6000 frames,
# 8,000,000 frames = 1000 episodes if we count an episode as 8000 frames.
maximum_frames = 8000000
if args.frame_skip is None:
steps = maximum_frames
eval_interval = 6000 * 100 # (approx.) every 100 episodes (counting 1 episode = 6000 steps)
else:
steps = maximum_frames // args.frame_skip
eval_interval = 6000 * 100 // args.frame_skip # (approx.) every 100 episodes (counting 1 episode = 6000 steps)
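# Worked example: with the default --frame-skip of None, steps = 8,000,000 and
# evaluation runs every 600,000 steps; with --frame-skip 4 (an illustrative choice),
# steps = 8,000,000 // 4 = 2,000,000 and eval_interval = 600,000 // 4 = 150,000.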
agent = get_agent(
n_actions=env.action_space.n, arch=args.arch, n_input_channels=env.observation_space.shape[0],
noisy_net_sigma=args.noisy_net_sigma, final_epsilon=args.final_epsilon,
final_exploration_frames=args.final_exploration_frames, explorer_sample_func=env.action_space.sample,
lr=args.lr, adam_eps=args.adam_eps,
prioritized=args.prioritized, steps=steps, update_interval=args.update_interval,
replay_capacity=args.replay_capacity, num_step_return=args.num_step_return,
agent_type=args.agent, gpu=args.gpu, gamma=args.gamma, replay_start_size=args.replay_start_size,
target_update_interval=args.target_update_interval, clip_delta=args.clip_delta,
batch_accumulator=args.batch_accumulator
)
if args.load:
agent.load(args.load)
# experiment
if args.demo:
eval_stats = pfrl.experiments.eval_performance(env=eval_env, agent=agent, n_steps=None, n_episodes=args.eval_n_runs)
logger.info('n_runs: {} mean: {} median: {} stdev: {}'.format(
args.eval_n_runs, eval_stats['mean'], eval_stats['median'], eval_stats['stdev']))
else:
pfrl.experiments.train_agent_with_evaluation(
agent=agent, env=env, steps=steps,
eval_n_steps=None, eval_n_episodes=args.eval_n_runs, eval_interval=eval_interval,
outdir=args.outdir, eval_env=eval_env, save_best_so_far_agent=True,
)
env.close()
eval_env.close()
def parse_agent(agent):
return {'DQN': pfrl.agents.DQN,
'DoubleDQN': pfrl.agents.DoubleDQN,
'PAL': pfrl.agents.PAL,
'CategoricalDoubleDQN': pfrl.agents.CategoricalDoubleDQN}[agent]
def get_agent(
n_actions, arch, n_input_channels,
noisy_net_sigma, final_epsilon, final_exploration_frames, explorer_sample_func,
lr, adam_eps,
prioritized, steps, update_interval, replay_capacity, num_step_return,
agent_type, gpu, gamma, replay_start_size, target_update_interval, clip_delta, batch_accumulator
):
# Q function
q_func = parse_arch(arch, n_actions, n_input_channels=n_input_channels)
# explorer
if noisy_net_sigma is not None:
pfrl.nn.to_factorized_noisy(q_func, sigma_scale=noisy_net_sigma)
# Turn off the explorer; the factorized-noisy layers provide exploration instead
explorer = pfrl.explorers.Greedy()
else:
explorer = pfrl.explorers.LinearDecayEpsilonGreedy(
1.0, final_epsilon, final_exploration_frames, explorer_sample_func)
opt = torch.optim.Adam(q_func.parameters(), lr, eps=adam_eps) # NOTE: mirrors the DQN implementation in the MineRL paper
# Select a replay buffer to use
if prioritized:
# Anneal beta from beta0 to 1 throughout training
betasteps = steps / update_interval
rbuf = pfrl.replay_buffers.PrioritizedReplayBuffer(
replay_capacity, alpha=0.5, beta0=0.4, betasteps=betasteps, num_steps=num_step_return)
else:
rbuf = pfrl.replay_buffers.ReplayBuffer(replay_capacity, num_step_return)
# build agent
def phi(x):
# observation -> NN input
return np.asarray(x)
Agent = parse_agent(agent_type)
agent = Agent(
q_func, opt, rbuf, gpu=gpu, gamma=gamma, explorer=explorer, replay_start_size=replay_start_size,
target_update_interval=target_update_interval, clip_delta=clip_delta, update_interval=update_interval,
batch_accumulator=batch_accumulator, phi=phi)
return agent
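# Example (illustrative numbers): with steps=2,000,000 and update_interval=4,
# prioritized replay anneals beta from beta0=0.4 to 1 over
# betasteps = 2,000,000 / 4 = 500,000 updates. With --noisy-net-sigma set, the
# epsilon-greedy explorer is replaced by a Greedy explorer and exploration comes
# from the noisy layers instead.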
if __name__ == '__main__':
main()
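# Example invocation (the script filename is hypothetical; adjust it to wherever
# this file is saved):
#
#     python train_dqn.py --env MineRLObtainDiamondDenseVectorObf-v0 --gpu 0 \
#         --noisy-net-sigma 0.5 --prioritized --agent CategoricalDoubleDQN \
#         --frame-skip 4 --frame-stack 4 --kmeans-n-clusters 30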
import copy
from logging import getLogger
from collections import deque
import os
import gym
import numpy as np
import cv2
from pfrl.wrappers import ContinuingTimeLimit, RandomizeAction, Monitor
from pfrl.wrappers.atari_wrappers import ScaledFloatFrame, LazyFrames
cv2.ocl.setUseOpenCL(False)
logger = getLogger(__name__)
def wrap_env(
env, test,
env_id,
monitor, outdir,
frame_skip,
gray_scale, frame_stack,
randomize_action, eval_epsilon,
action_choices):
# wrap env: time limit...
if isinstance(env, gym.wrappers.TimeLimit):
logger.info('Detected `gym.wrappers.TimeLimit`. Unwrapping it and re-wrapping with our own time limit.')
env = env.env
max_episode_steps = env.spec.max_episode_steps
env = ContinuingTimeLimit(env, max_episode_steps=max_episode_steps)
# wrap env: observation...
# NOTE: wrapping order matters!
if test and monitor:
env = Monitor(
env, os.path.join(outdir, env.spec.id, 'monitor'),
mode='evaluation' if test else 'training', video_callable=lambda episode_id: True)
if frame_skip is not None:
env = FrameSkip(env, skip=frame_skip)
if gray_scale:
env = GrayScaleWrapper(env, dict_space_key='pov')
env = ObtainPoVWrapper(env)
env = MoveAxisWrapper(env, source=-1, destination=0) # convert HWC -> CHW, as PyTorch expects
env = ScaledFloatFrame(env)
if frame_stack is not None and frame_stack > 0:
env = FrameStack(env, frame_stack, channel_order='chw')
env = ClusteredActionWrapper(env, clusters=action_choices)
if randomize_action:
env = RandomizeAction(env, eval_epsilon)
return env
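# After `wrap_env`, observations are scaled float frames in CHW order (optionally
# gray-scaled and frame-stacked), and the action space is
# `gym.spaces.Discrete(len(action_choices))`, whose indices map back to 64-dim
# obfuscated action vectors via `ClusteredActionWrapper`.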
class FrameSkip(gym.Wrapper):
"""Return every `skip`-th frame and repeat given action during skip.
Note that this wrapper does not "maximize" over the skipped frames.
"""
def __init__(self, env, skip=4):
super().__init__(env)
self._skip = skip
def step(self, action):
total_reward = 0.0
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
total_reward += reward
if done:
break
return obs, total_reward, done, info
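# Example: with skip=4 the agent selects an action once every 4 environment frames
# and receives the sum of the 4 per-frame rewards (the loop stops early on `done`).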
class FrameStack(gym.Wrapper):
def __init__(self, env, k, channel_order='hwc', use_tuple=False):
"""Stack k last frames.
Returns lazy array, which is much more memory efficient.
"""
gym.Wrapper.__init__(self, env)
self.k = k
self.observations = deque([], maxlen=k)
self.stack_axis = {'hwc': 2, 'chw': 0}[channel_order]
self.use_tuple = use_tuple
if self.use_tuple:
pov_space = env.observation_space[0]
inv_space = env.observation_space[1]
else:
pov_space = env.observation_space
low_pov = np.repeat(pov_space.low, k, axis=self.stack_axis)
high_pov = np.repeat(pov_space.high, k, axis=self.stack_axis)
pov_space = gym.spaces.Box(low=low_pov, high=high_pov, dtype=pov_space.dtype)
if self.use_tuple:
low_inv = np.repeat(inv_space.low, k, axis=0)
high_inv = np.repeat(inv_space.high, k, axis=0)
inv_space = gym.spaces.Box(low=low_inv, high=high_inv, dtype=inv_space.dtype)
self.observation_space = gym.spaces.Tuple(
(pov_space, inv_space))
else:
self.observation_space = pov_space
def reset(self):
ob = self.env.reset()
for _ in range(self.k):
self.observations.append(ob)
return self._get_ob()
def step(self, action):
ob, reward, done, info = self.env.step(action)
self.observations.append(ob)
return self._get_ob(), reward, done, info
def _get_ob(self):
assert len(self.observations) == self.k
if self.use_tuple:
frames = [x[0] for x in self.observations]
inventory = [x[1] for x in self.observations]
return (LazyFrames(list(frames), stack_axis=self.stack_axis),
LazyFrames(list(inventory), stack_axis=0))
else:
return LazyFrames(list(self.observations), stack_axis=self.stack_axis)
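# Example (illustrative shapes): with k=4, channel_order='chw' and gray-scaled
# 64x64 observations of shape (1, 64, 64), `LazyFrames` concatenates along axis 0
# to give observations of shape (4, 64, 64); RGB input of shape (3, 64, 64)
# yields (12, 64, 64).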
class ObtainPoVWrapper(gym.ObservationWrapper):
"""Obtain 'pov' value (current game display) of the original observation."""
def __init__(self, env):
super().__init__(env)
self.observation_space = self.env.observation_space.spaces['pov']
def observation(self, observation):
return observation['pov']
class UnifiedObservationWrapper(gym.ObservationWrapper):
"""Take 'pov', 'compassAngle', 'inventory' and concatenate with scaling.
Each element of 'inventory' is converted to a square whose side length is region_size.
The color of each square is correlated to the reciprocal of (the number of the corresponding item + 1).
"""
def __init__(self, env, region_size=8):
super().__init__(env)
self._compass_angle_scale = 180 / 255 # NOTE: `ScaledFloatFrame` will divide the pixel values by 255.0 later
self.region_size = region_size
pov_space = self.env.observation_space.spaces['pov']
low_dict = {'pov': pov_space.low}
high_dict = {'pov': pov_space.high}
if 'compassAngle' in self.env.observation_space.spaces:
compass_angle_space = self.env.observation_space.spaces['compassAngle']
low_dict['compassAngle'] = compass_angle_space.low
high_dict['compassAngle'] = compass_angle_space.high
if 'inventory' in self.env.observation_space.spaces:
inventory_space = self.env.observation_space.spaces['inventory']
low_dict['inventory'] = {}
high_dict['inventory'] = {}
for key in inventory_space.spaces.keys():
low_dict['inventory'][key] = inventory_space.spaces[key].low
high_dict['inventory'][key] = inventory_space.spaces[key].high
low = self.observation(low_dict)
high = self.observation(high_dict)
self.observation_space = gym.spaces.Box(low=low, high=high)
def observation(self, observation):
obs = observation['pov']
pov_dtype = obs.dtype
if 'compassAngle' in observation:
compass_scaled = observation['compassAngle'] / self._compass_angle_scale
compass_channel = np.ones(shape=list(obs.shape[:-1]) + [1], dtype=pov_dtype) * compass_scaled
obs = np.concatenate([obs, compass_channel], axis=-1)
if 'inventory' in observation:
assert len(obs.shape[:-1]) == 2
region_max_height = obs.shape[0]
region_max_width = obs.shape[1]
rs = self.region_size
if min(region_max_height, region_max_width) < rs:
raise ValueError("'region_size' is too large.")
num_element_width = region_max_width // rs
inventory_channel = np.zeros(shape=list(obs.shape[:-1]) + [1], dtype=pov_dtype)
for idx, key in enumerate(observation['inventory']):
item_scaled = np.clip(255 - 255 / (observation['inventory'][key] + 1), # inverted: more items -> brighter
0, 255)
item_channel = np.ones(shape=[rs, rs, 1], dtype=pov_dtype) * item_scaled
width_low = (idx % num_element_width) * rs
height_low = (idx // num_element_width) * rs
if height_low + rs > region_max_height:
raise ValueError("Too many elements in 'inventory'. Please decrease 'region_size'.")
inventory_channel[height_low:(height_low + rs), width_low:(width_low + rs), :] = item_channel
obs = np.concatenate([obs, inventory_channel], axis=-1)
return obs
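# Layout example (illustrative): with a 64x64 pov and region_size=8, each row of the
# appended inventory channel holds 64 // 8 = 8 item squares, so up to 8 * 8 = 64
# inventory entries fit before the "Too many elements" error is raised.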
class FullObservationSpaceWrapper(gym.ObservationWrapper):
"""Returns as observation a tuple with the frames and a list of
compassAngle and inventory items.
compassAngle is scaled to be in the interval [-1, 1] and inventory items
are scaled to be in the interval [0, 1]
"""
def __init__(self, env):
super().__init__(env)
pov_space = self.env.observation_space.spaces['pov']
low_dict = {'pov': pov_space.low, 'inventory': {}}
high_dict = {'pov': pov_space.high, 'inventory': {}}
for obs_name in self.env.observation_space.spaces['inventory'].spaces.keys():
obs_space = self.env.observation_space.spaces['inventory'].spaces[obs_name]
low_dict['inventory'][obs_name] = obs_space.low
high_dict['inventory'][obs_name] = obs_space.high
if 'compassAngle' in self.env.observation_space.spaces:
compass_angle_space = self.env.observation_space.spaces['compassAngle']
low_dict['compassAngle'] = compass_angle_space.low
high_dict['compassAngle'] = compass_angle_space.high
low = self.observation(low_dict)
high = self.observation(high_dict)
pov_space = gym.spaces.Box(low=low[0], high=high[0])
inventory_space = gym.spaces.Box(low=low[1], high=high[1])
self.observation_space = gym.spaces.Tuple((pov_space, inventory_space))
def observation(self, observation):
frame = observation['pov']
inventory = []
if 'compassAngle' in observation:
compass_scaled = observation['compassAngle'] / 180
inventory.append(compass_scaled)
for obs_name in observation['inventory'].keys():
inventory.append(observation['inventory'][obs_name] / 2304)
inventory = np.array(inventory)
return (frame, inventory)
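# NOTE (assumption): the division by 2304 above appears to normalize inventory
# counts into roughly [0, 1]; 2304 is presumably the largest count expected here
# (it is not documented in this file).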
class MoveAxisWrapper(gym.ObservationWrapper):
"""Move axes of observation ndarrays."""
def __init__(self, env, source, destination, use_tuple=False):
if use_tuple:
assert isinstance(env.observation_space[0], gym.spaces.Box)
else:
assert isinstance(env.observation_space, gym.spaces.Box)
super().__init__(env)
self.source = source
self.destination = destination
self.use_tuple = use_tuple
if self.use_tuple:
low = self.observation(
tuple([space.low for space in self.observation_space]))
high = self.observation(
tuple([space.high for space in self.observation_space]))
dtype = self.observation_space[0].dtype
pov_space = gym.spaces.Box(low=low[0], high=high[0], dtype=dtype)
inventory_space = self.observation_space[1]
self.observation_space = gym.spaces.Tuple(
(pov_space, inventory_space))
else:
low = self.observation(self.observation_space.low)
high = self.observation(self.observation_space.high)
dtype = self.observation_space.dtype
self.observation_space = gym.spaces.Box(
low=low, high=high, dtype=dtype)
def observation(self, observation):
if self.use_tuple:
new_observation = list(observation)
new_observation[0] = np.moveaxis(
observation[0], self.source, self.destination)
return tuple(new_observation)
else:
return np.moveaxis(observation, self.source, self.destination)
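# Example: MoveAxisWrapper(env, source=-1, destination=0) turns an HWC frame of
# shape (64, 64, 3) into a CHW frame of shape (3, 64, 64), matching PyTorch's
# channel-first layout.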
class GrayScaleWrapper(gym.ObservationWrapper):
def __init__(self, env, dict_space_key=None):
super().__init__(env)
self._key = dict_space_key
if self._key is None:
original_space = self.observation_space
else:
original_space = self.observation_space.spaces[self._key]
height, width = original_space.shape[0], original_space.shape[1]
# sanity checks
ideal_image_space = gym.spaces.Box(low=0, high=255, shape=(height, width, 3), dtype=np.uint8)
if original_space != ideal_image_space:
raise ValueError('Image space should be {}, but {} was given.'.format(ideal_image_space, original_space))
if original_space.dtype != np.uint8:
raise ValueError('Image dtype should be `np.uint8`, but {} was given.'.format(original_space.dtype))
new_space = gym.spaces.Box(low=0, high=255, shape=(height, width, 1), dtype=np.uint8)
if self._key is None:
self.observation_space = new_space
else:
new_space_dict = copy.deepcopy(self.observation_space)
new_space_dict.spaces[self._key] = new_space
self.observation_space = new_space_dict
def observation(self, obs):
if self._key is None:
frame = obs
else:
frame = obs[self._key]
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = np.expand_dims(frame, -1)
if self._key is None:
obs = frame
else:
obs[self._key] = frame
return obs
class ClusteredActionWrapper(gym.ActionWrapper):
def __init__(self, env, clusters):
super().__init__(env)
self._clusters = clusters
self._np_random = np.random.RandomState()
self.action_space = gym.spaces.Discrete(len(clusters))
def action(self, action):
return {'vector': self._clusters[action]}
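# Usage sketch (illustrative): `clusters` is typically `kmeans.cluster_centers_`
# with shape (n_clusters, 64). The agent then acts in Discrete(n_clusters), and each
# discrete index is translated back into a MineRL obfuscated action:
#
#     env = ClusteredActionWrapper(env, clusters=kmeans.cluster_centers_)
#     obs = env.reset()
#     obs, reward, done, info = env.step(env.action_space.sample())  # index -> {'vector': ...}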