import copy
import collections
import time
import multiprocessing as mp
from logging import getLogger

import torch
import torch.nn.functional as F
import numpy as np

import pfrl
from pfrl import agent
from pfrl.utils.batch_states import batch_states
from pfrl.utils.contexts import evaluating
from pfrl.utils.copy_param import synchronize_parameters
from pfrl.replay_buffer import batch_experiences
from pfrl.replay_buffer import batch_recurrent_experiences
from pfrl.replay_buffer import ReplayUpdater
from pfrl.utils.recurrent import get_recurrent_state_at
from pfrl.utils.recurrent import mask_recurrent_state_at
from pfrl.utils.recurrent import one_step_forward
from pfrl.utils.recurrent import pack_and_forward
from pfrl.utils.recurrent import recurrent_state_as_numpy


def _mean_or_nan(xs):
    """Return the mean of a non-empty sequence, numpy.nan for an empty one."""
    return np.mean(xs) if xs else np.nan


def compute_value_loss(y, t, clip_delta=True, batch_accumulator="mean"):
    """Compute a loss for value prediction problem.

    Args:
        y (torch.Tensor): Predicted values.
        t (torch.Tensor): Target values.
        clip_delta (bool): Use the Huber loss function with delta=1 if set True.
        batch_accumulator (str): 'mean' or 'sum'. 'mean' will use the mean of
            the loss values in a batch. 'sum' will use the sum.
    Returns:
        (torch.Tensor) scalar loss
    """
    assert batch_accumulator in ("mean", "sum")
    y = y.reshape(-1, 1)
    t = t.reshape(-1, 1)
    if clip_delta:
        return F.smooth_l1_loss(y, t, reduction=batch_accumulator)
    else:
        return F.mse_loss(y, t, reduction=batch_accumulator) / 2


def compute_weighted_value_loss(
    y, t, weights, clip_delta=True, batch_accumulator="mean"
):
    """Compute a loss for value prediction problem.

    Args:
        y (torch.Tensor): Predicted values.
        t (torch.Tensor): Target values.
        weights (torch.Tensor): Weights for y, t.
        clip_delta (bool): Use the Huber loss function with delta=1 if set True.
        batch_accumulator (str): 'mean' will divide loss by batchsize
    Returns:
        (torch.Tensor) scalar loss
    """
    assert batch_accumulator in ("mean", "sum")
    y = y.reshape(-1, 1)
    t = t.reshape(-1, 1)
    if clip_delta:
        losses = F.smooth_l1_loss(y, t, reduction="none")
    else:
        losses = F.mse_loss(y, t, reduction="none") / 2
    losses = losses.reshape(-1,)
    weights = weights.to(losses.device)
    loss_sum = torch.sum(losses * weights)
    if batch_accumulator == "mean":
        loss = loss_sum / y.shape[0]
    elif batch_accumulator == "sum":
        loss = loss_sum
    return loss


def _batch_reset_recurrent_states_when_episodes_end(
    batch_done, batch_reset, recurrent_states
):
    """Reset recurrent states when episodes end.

    Args:
        batch_done (array-like of bool): True iff episodes are terminal.
        batch_reset (array-like of bool): True iff episodes will be reset.
        recurrent_states (object): Recurrent state.

    Returns:
        object: New recurrent states.
    """
    indices_that_ended = [
        i
        for i, (done, reset) in enumerate(zip(batch_done, batch_reset))
        if done or reset
    ]
    if indices_that_ended:
        return mask_recurrent_state_at(recurrent_states, indices_that_ended)
    else:
        return recurrent_states


def load_experiences_from_demonstrations(
    expert_dataset, batch_size, reward=1
):
    """Sample expert transitions and relabel them with a constant reward."""
    if expert_dataset is None:
        raise ValueError("Expert dataset must be provided.")
    ret = []
    for _ in range(batch_size):
        ob, act, _, next_ob, done = expert_dataset.sample()
        ret.append(
            [
                dict(
                    state=ob,
                    action=act,
                    reward=reward,
                    next_state=next_ob,
                    next_action=None,
                    is_state_terminal=done,
                )
            ]
        )
    return ret
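
# Illustrative sketch (not part of the original API): the structure returned by
# load_experiences_from_demonstrations mirrors ReplayBuffer.sample(), i.e. a
# list of one-step transition lists, so both can be fed to batch_experiences().
# `expert_dataset` is any object whose sample() method returns a 5-tuple
# (ob, act, _, next_ob, done); the original reward is discarded and replaced by
# the constant SQIL reward.
#
#   demo_experiences = load_experiences_from_demonstrations(expert_dataset, 2)
#   # demo_experiences == [
#   #     [{"state": ob, "action": act, "reward": 1, "next_state": next_ob,
#   #       "next_action": None, "is_state_terminal": done}],
#   #     ...
#   # ]
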
class RewardBasedSampler:
    """Sample demonstrations in proportion to the subtasks visited by the agent."""

    def __init__(
        self, expert_dataset, reward_boundaries, max_buffer_size=1000, reward=1
    ):
        self.expert_dataset = expert_dataset
        self.reward_boundaries = reward_boundaries
        self.replay_buffers = [
            pfrl.replay_buffers.ReplayBuffer(max_buffer_size, 1)
            for _ in range(len(reward_boundaries) + 1)
        ]
        self.reward_scale = reward
        # Fill replay buffers
        while True:
            update_needed = False
            for rbuf in self.replay_buffers:
                if len(rbuf) < max_buffer_size:
                    update_needed = True
            if update_needed:
                self._update_buffer(max_buffer_size)
            else:
                break

    def _policy_index(self, ob):
        # The last frame's (0, 0) element is assumed to encode the cumulative
        # reward so far; count how many (normalized) boundaries it has passed.
        cum_reward = np.array(ob)[-1, 0, 0]
        return np.sum(
            cum_reward > self.reward_boundaries / self.reward_boundaries[-1] - 1e-8
        )

    def _update_buffer(self, n):
        for _ in range(n):
            ob, act, _, next_ob, done = self.expert_dataset.sample()
            self.replay_buffers[self._policy_index(ob)].append(
                state=ob,
                action=act,
                reward=self.reward_scale,
                next_state=next_ob,
                next_action=None,
                is_state_terminal=done,
            )

    def sample(self, experiences):
        # Refresh the per-subtask buffers with fresh demonstration samples.
        self._update_buffer(len(experiences))
        # Match the subtask distribution of the agent's own experiences.
        n_samples = [0 for _ in range(len(self.reward_boundaries) + 1)]
        for frame in experiences:
            n_samples[self._policy_index(frame[0]["state"])] += 1
        ret = []
        for rbuf, n_sample in zip(self.replay_buffers, n_samples):
            samples = rbuf.sample(n_sample)
            for frame in samples:
                ret.append(frame)
        return ret


class SQIL(agent.AttributeSavingMixin, agent.BatchAgent):
    """Soft Q Imitation Learning (SQIL) agent, based on the DQN implementation.

    Args:
        q_function (StateQFunction): Q-function
        optimizer (Optimizer): Optimizer that is already setup
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        target_update_interval (int): Target model update interval in step
        clip_delta (bool): Clip delta if set True
        phi (callable): Feature extractor applied to observations
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetition of update
        batch_accumulator (str): 'mean' or 'sum'
        episodic_update_len (int or None): Subsequences of this length are used
            for update if set int and episodic_update=True
        logger (Logger): Logger used
        batch_states (callable): method which makes a batch of observations.
            default is `pfrl.utils.batch_states.batch_states`
        recurrent (bool): If set to True, `model` is assumed to implement
            `pfrl.nn.Recurrent` and is updated in a recurrent manner.
        expert_dataset: Dataset of expert demonstrations. Must provide a
            sample() method returning a (state, action, reward, next_state,
            done) tuple.
        reward_scale (float): Constant reward assigned to demonstration
            transitions.
        experience_lambda (float): Weight of the loss on the agent's own
            experiences relative to the demonstration loss.
        reward_boundaries (array-like or None): Cumulative-reward boundaries
            used to split demonstrations into per-subtask buffers (specific to
            options). If None, demonstrations are sampled uniformly.

    Changes from DQN:
        * remove recurrent support
        * add expert dataset
    """

    saved_attributes = ("model", "target_model", "optimizer")

    def __init__(
        self,
        q_function,
        optimizer,
        replay_buffer,
        gamma,
        explorer,
        gpu=None,
        replay_start_size=50000,
        minibatch_size=32,
        update_interval=1,
        target_update_interval=10000,
        clip_delta=True,
        phi=lambda x: x,
        target_update_method="hard",
        soft_update_tau=1e-2,
        n_times_update=1,
        batch_accumulator="mean",
        episodic_update_len=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        expert_dataset=None,
        reward_scale=1.0,
        experience_lambda=1.0,
        recurrent=False,
        reward_boundaries=None,  # specific to options
    ):
        self.expert_dataset = expert_dataset
        self.model = q_function

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.model.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.clip_delta = clip_delta
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.batch_accumulator = batch_accumulator
        assert batch_accumulator in ("mean", "sum")
        self.logger = logger
        self.batch_states = batch_states
        self.recurrent = recurrent
        if self.recurrent:
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=recurrent,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )
        self.minibatch_size = minibatch_size
        self.episodic_update_len = episodic_update_len
        self.replay_start_size = replay_start_size
        self.update_interval = update_interval
        assert (
            target_update_interval % update_interval == 0
        ), "target_update_interval should be a multiple of update_interval"

        # For imitation
        self.reward_scale = reward_scale
        self.experience_lambda = experience_lambda
        if reward_boundaries is not None:
            self.reward_based_sampler = RewardBasedSampler(
                self.expert_dataset,
                reward_boundaries,
                reward=reward_scale,
            )
        else:
            self.reward_based_sampler = None

        self.t = 0
        self.optim_t = 0  # Compensate pytorch optim not having `t`
        self._cumulative_steps = 0
        self.last_state = None
        self.last_action = None
        self.target_model = None
        self.sync_target_network()

        # Statistics
        self.q_record = collections.deque(maxlen=1000)
        self.loss_record = collections.deque(maxlen=100)

        # Recurrent states of the model
        self.train_recurrent_states = None
        self.train_prev_recurrent_states = None
        self.test_recurrent_states = None

        # Error checking
        if (
            self.replay_buffer.capacity is not None
            and self.replay_buffer.capacity < self.replay_updater.replay_start_size
        ):
            raise ValueError("Replay start size cannot exceed replay buffer capacity.")

    @property
    def cumulative_steps(self):
        # cumulative_steps counts the overall steps during the training.
        return self._cumulative_steps

    def _setup_actor_learner_training(self, n_actors, actor_update_interval):
        assert actor_update_interval > 0

        self.actor_update_interval = actor_update_interval
        self.update_counter = 0

        # Make a copy on shared memory and share among actors and the poller
        shared_model = copy.deepcopy(self.model).cpu()
        shared_model.share_memory()

        # Pipes are used for infrequent communication
        learner_pipes, actor_pipes = list(zip(*[mp.Pipe() for _ in range(n_actors)]))

        return (shared_model, learner_pipes, actor_pipes)

    def sync_target_network(self):
        """Synchronize target network with current network."""
        if self.target_model is None:
            self.target_model = copy.deepcopy(self.model)

            def flatten_parameters(mod):
                if isinstance(mod, torch.nn.RNNBase):
                    mod.flatten_parameters()

            # RNNBase.flatten_parameters must be called again after deep-copy.
            # See: https://discuss.pytorch.org/t/why-do-we-need-flatten-parameters-when-using-rnn-with-dataparallel/46506  # NOQA
            self.target_model.apply(flatten_parameters)
            # Set the target network to evaluation mode only.
            self.target_model.eval()
        else:
            synchronize_parameters(
                src=self.model,
                dst=self.target_model,
                method=self.target_update_method,
                tau=self.soft_update_tau,
            )

    def update(self, experiences, errors_out=None):
        """Update the model from experiences.

        Args:
            experiences (list): List of lists of dicts.
                For DQN, each dict must contain:
                  - state (object): State
                  - action (object): Action
                  - reward (float): Reward
                  - is_state_terminal (bool): True iff next state is terminal
                  - next_state (object): Next state
                  - weight (float, optional): Weight coefficient. It can be
                    used for importance sampling.
            errors_out (list or None): If set to a list, then TD-errors
                computed from the given experiences are appended to the list.

        Returns:
            None

        Changes from DQN:
            Also learns from demonstrations.
        """
        has_weight = "weight" in experiences[0][0]
        exp_batch = batch_experiences(
            experiences,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        if has_weight:
            exp_batch["weights"] = torch.tensor(
                [elem[0]["weight"] for elem in experiences],
                device=self.device,
                dtype=torch.float32,
            )
            if errors_out is None:
                errors_out = []
        if self.reward_based_sampler is not None:
            demo_experiences = self.reward_based_sampler.sample(experiences)
        else:
            demo_experiences = load_experiences_from_demonstrations(
                self.expert_dataset, self.replay_updater.batchsize, self.reward_scale
            )
        demo_batch = batch_experiences(
            demo_experiences,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        loss = self._compute_loss(exp_batch, demo_batch, errors_out=errors_out)
        if has_weight:
            self.replay_buffer.update_errors(errors_out)

        self.loss_record.append(float(loss.detach().cpu().numpy()))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.optim_t += 1

    def update_from_episodes(self, episodes, errors_out=None):
        assert errors_out is None, "Recurrent DQN does not support PrioritizedBuffer"
        episodes = sorted(episodes, key=len, reverse=True)
        exp_batch = batch_recurrent_experiences(
            episodes,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        demo_experiences = load_experiences_from_demonstrations(
            self.expert_dataset, self.replay_updater.batchsize, self.reward_scale
        )
        demo_batch = batch_experiences(
            demo_experiences,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        loss = self._compute_loss(exp_batch, demo_batch, errors_out=None)
        self.loss_record.append(float(loss.detach().cpu().numpy()))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.optim_t += 1

    def _compute_target_values(self, exp_batch):
        """Compute target values.

        Changes from DQN:
            Uses the soft Bellman backup, i.e. a log-sum-exp state value
            instead of the max.
        """
        batch_next_state = exp_batch["next_state"]

        target_next_qout = self.target_model(batch_next_state)
        # Soft state value V(s') = log sum_a exp Q(s', a), computed with the
        # max subtracted first for numerical stability.
        next_q_max = torch.broadcast_tensors(
            target_next_qout.q_values.max(dim=-1, keepdim=True)[0],
            target_next_qout.q_values,
        )[0]
        next_q_soft = (
            next_q_max[:, 0]
            + (target_next_qout.q_values - next_q_max).exp().sum(dim=-1).log()
        )

        batch_rewards = exp_batch["reward"]
        batch_terminal = exp_batch["is_state_terminal"]
        discount = exp_batch["discount"]

        # return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max
        return batch_rewards + discount * (1.0 - batch_terminal) * next_q_soft

    def _compute_y_and_t(self, exp_batch):
        batch_size = exp_batch["reward"].shape[0]

        # Compute Q-values for current states
        batch_state = exp_batch["state"]

        if self.recurrent:
            qout, _ = pack_and_forward(
                self.model,
                batch_state,
                exp_batch["recurrent_state"],
            )
        else:
            qout = self.model(batch_state)

        batch_actions = exp_batch["action"]
        batch_q = torch.reshape(qout.evaluate_actions(batch_actions), (batch_size, 1))

        with torch.no_grad():
            batch_q_target = torch.reshape(
                self._compute_target_values(exp_batch), (batch_size, 1)
            )

        return batch_q, batch_q_target

    def __compute_loss(self, exp_batch, errors_out):
        y, t = self._compute_y_and_t(exp_batch)

        self.q_record.extend(y.detach().cpu().numpy().ravel())

        if errors_out is not None:
            del errors_out[:]
            delta = torch.abs(y - t)
            if delta.ndim == 2:
                delta = torch.sum(delta, dim=1)
            delta = delta.detach().cpu().numpy()
            for e in delta:
                errors_out.append(e)

        if "weights" in exp_batch:
            return compute_weighted_value_loss(
                y,
                t,
                exp_batch["weights"],
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator,
            )
        else:
            return compute_value_loss(
                y,
                t,
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator,
            )

    def _compute_loss(self, exp_batch, demo_batch, errors_out=None):
        """Compute the Q-learning loss for a batch of experiences.

        Args:
            exp_batch (dict): A dict of batched arrays of transitions
            demo_batch (dict): A dict of batched arrays of demonstration
                transitions

        Returns:
            Computed loss from the minibatch of experiences

        Changes from DQN:
            Combines the loss on the agent's own experiences with the loss on
            demonstrations.
        """
        exp_loss = self.__compute_loss(exp_batch, errors_out=errors_out)
        demo_loss = self.__compute_loss(demo_batch, errors_out=None)
        return (exp_loss * self.experience_lambda + demo_loss) / 2

    def _evaluate_model_and_update_recurrent_states(self, batch_obs):
        batch_xs = self.batch_states(batch_obs, self.device, self.phi)
        if self.recurrent:
            if self.training:
                self.train_prev_recurrent_states = self.train_recurrent_states
                batch_av, self.train_recurrent_states = one_step_forward(
                    self.model, batch_xs, self.train_recurrent_states
                )
            else:
                batch_av, self.test_recurrent_states = one_step_forward(
                    self.model, batch_xs, self.test_recurrent_states
                )
        else:
            batch_av = self.model(batch_xs)
        return batch_av

    def batch_act(self, batch_obs):
        with torch.no_grad(), evaluating(self.model):
            batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs)
            batch_argmax = batch_av.greedy_actions.cpu().numpy()
        if self.training:
            batch_action = [
                self.explorer.select_action(
                    self.t,
                    lambda: batch_argmax[i],
                    action_value=batch_av[i : i + 1],
                )
                for i in range(len(batch_obs))
            ]
            self.batch_last_obs = list(batch_obs)
            self.batch_last_action = list(batch_action)
        else:
            # stochastic
            batch_action = [
                self.explorer.select_action(
                    self.t,
                    lambda: batch_argmax[i],
                    action_value=batch_av[i : i + 1],
                )
                for i in range(len(batch_obs))
            ]
            # deterministic
            # batch_action = batch_argmax
        return batch_action

    def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset):
        for i in range(len(batch_obs)):
            self.t += 1
            self._cumulative_steps += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                transition = {
                    "state": self.batch_last_obs[i],
                    "action": self.batch_last_action[i],
                    "reward": batch_reward[i],
                    "next_state": batch_obs[i],
                    "next_action": None,
                    "is_state_terminal": batch_done[i],
                }
                if self.recurrent:
                    transition["recurrent_state"] = recurrent_state_as_numpy(
                        get_recurrent_state_at(
                            self.train_prev_recurrent_states, i, detach=True
                        )
                    )
                    transition["next_recurrent_state"] = recurrent_state_as_numpy(
                        get_recurrent_state_at(
                            self.train_recurrent_states, i, detach=True
                        )
                    )
                self.replay_buffer.append(env_id=i, **transition)
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

        if self.recurrent:
            # Reset recurrent states when episodes end
            self.train_prev_recurrent_states = None
            self.train_recurrent_states = _batch_reset_recurrent_states_when_episodes_end(  # NOQA
                batch_done=batch_done,
                batch_reset=batch_reset,
                recurrent_states=self.train_recurrent_states,
            )

    def _batch_observe_eval(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.recurrent:
            # Reset recurrent states when episodes end
            self.test_recurrent_states = _batch_reset_recurrent_states_when_episodes_end(  # NOQA
                batch_done=batch_done,
                batch_reset=batch_reset,
                recurrent_states=self.test_recurrent_states,
            )

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.training:
            return self._batch_observe_train(
                batch_obs, batch_reward, batch_done, batch_reset
            )
        else:
            return self._batch_observe_eval(
                batch_obs, batch_reward, batch_done, batch_reset
            )

    def _can_start_replay(self):
        if len(self.replay_buffer) < self.replay_start_size:
            return False
        if self.recurrent and self.replay_buffer.n_episodes < self.minibatch_size:
            return False
        return True

    def _poll_pipe(self, actor_idx, pipe, replay_buffer_lock, exception_event):
        if pipe.closed:
            return
        try:
            while pipe.poll() and not exception_event.is_set():
                cmd, data = pipe.recv()
                if cmd == "get_statistics":
                    assert data is None
                    with replay_buffer_lock:
                        stats = self.get_statistics()
                    pipe.send(stats)
                elif cmd == "load":
                    self.load(data)
                    pipe.send(None)
                elif cmd == "save":
                    self.save(data)
                    pipe.send(None)
                elif cmd == "transition":
                    with replay_buffer_lock:
                        if "env_id" not in data:
                            data["env_id"] = actor_idx
                        self.replay_buffer.append(**data)
                        self._cumulative_steps += 1
                elif cmd == "stop_episode":
                    idx = actor_idx if data is None else data
                    with replay_buffer_lock:
                        self.replay_buffer.stop_current_episode(env_id=idx)
                        stats = self.get_statistics()
                    pipe.send(stats)
                else:
                    raise RuntimeError("Unknown command from actor: {}".format(cmd))
        except EOFError:
            pipe.close()
        except Exception:
            self.logger.exception("Poller loop failed. Exiting")
Exiting") exception_event.set() def _learner_loop( self, shared_model, pipes, replay_buffer_lock, stop_event, exception_event, n_updates=None, ): try: update_counter = 0 # To stop this loop, call stop_event.set() while not stop_event.is_set(): # Update model if possible if not self._can_start_replay(): continue if n_updates is not None: assert self.optim_t <= n_updates if self.optim_t == n_updates: stop_event.set() break if self.recurrent: with replay_buffer_lock: episodes = self.replay_buffer.sample_episodes( self.minibatch_size, self.episodic_update_len ) self.update_from_episodes(episodes) else: with replay_buffer_lock: transitions = self.replay_buffer.sample(self.minibatch_size) self.update(transitions) # Update the shared model. This can be expensive if GPU is used # since this is a DtoH copy, so it is updated only at regular # intervals. update_counter += 1 if update_counter % self.actor_update_interval == 0: self.update_counter += 1 shared_model.load_state_dict(self.model.state_dict()) # To keep the ratio of target updates to model updates, # here we calculate back the effective current timestep # from update_interval and number of updates so far. effective_timestep = self.optim_t * self.update_interval # We can safely assign self.t since in the learner # it isn't updated by any other method self.t = effective_timestep if effective_timestep % self.target_update_interval == 0: self.sync_target_network() except Exception: self.logger.exception("Learner loop failed. Exiting") exception_event.set() def _poller_loop( self, shared_model, pipes, replay_buffer_lock, stop_event, exception_event ): # To stop this loop, call stop_event.set() while not stop_event.is_set() and not exception_event.is_set(): time.sleep(1e-6) # Poll actors for messages for i, pipe in enumerate(pipes): self._poll_pipe(i, pipe, replay_buffer_lock, exception_event) def setup_actor_learner_training( self, n_actors, n_updates=None, actor_update_interval=8 ): (shared_model, learner_pipes, actor_pipes) = self._setup_actor_learner_training( n_actors, actor_update_interval ) exception_event = mp.Event() def make_actor(i): return pfrl.agents.StateQFunctionActor( pipe=actor_pipes[i], model=shared_model, explorer=self.explorer, phi=self.phi, batch_states=self.batch_states, logger=self.logger, recurrent=self.recurrent, ) replay_buffer_lock = mp.Lock() poller_stop_event = mp.Event() poller = pfrl.utils.StoppableThread( target=self._poller_loop, kwargs=dict( shared_model=shared_model, pipes=learner_pipes, replay_buffer_lock=replay_buffer_lock, stop_event=poller_stop_event, exception_event=exception_event, ), stop_event=poller_stop_event, ) learner_stop_event = mp.Event() learner = pfrl.utils.StoppableThread( target=self._learner_loop, kwargs=dict( shared_model=shared_model, pipes=learner_pipes, replay_buffer_lock=replay_buffer_lock, stop_event=learner_stop_event, n_updates=n_updates, exception_event=exception_event, ), stop_event=learner_stop_event, ) return make_actor, learner, poller, exception_event def stop_episode(self): if self.recurrent: self.test_recurrent_states = None def get_statistics(self): return [ ("average_q", _mean_or_nan(self.q_record)), ("average_loss", _mean_or_nan(self.loss_record)), ("cumulative_steps", self.cumulative_steps), ("n_updates", self.optim_t), ("rlen", len(self.replay_buffer)), ]