Source code for pfrl.agents.ddpg

import collections
import copy
from logging import getLogger

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from pfrl.agent import AttributeSavingMixin, BatchAgent
from pfrl.replay_buffer import ReplayUpdater, batch_experiences
from pfrl.utils.batch_states import batch_states
from pfrl.utils.contexts import evaluating
from pfrl.utils.copy_param import synchronize_parameters


def _mean_or_nan(xs):
    """Return its mean a non-empty sequence, numpy.nan for a empty one."""
    return np.mean(xs) if xs else np.nan


class DDPG(AttributeSavingMixin, BatchAgent):
    """Deep Deterministic Policy Gradients.

    This can be used as SVG(0) by specifying a Gaussian policy instead
    of a deterministic policy.

    Args:
        policy (torch.nn.Module): Policy
        q_func (torch.nn.Module): Q-function
        actor_optimizer (Optimizer): Optimizer setup with the policy
        critic_optimizer (Optimizer): Optimizer setup with the Q-function
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): If the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        target_update_interval (int): Target model update interval in step
        phi (callable): Feature extractor applied to observations
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetition of update
        recurrent (bool): Use full episodes for update if set True
            (not yet implemented; must be False)
        episodic_update_len (int or None): Subsequences of this length are
            used for update if set int and recurrent=True
        logger (Logger): Logger used
        batch_states (callable): Method which makes a batch of observations.
            Default is `pfrl.utils.batch_states.batch_states`.
        burnin_action_func (callable or None): If not None, this callable
            object is used to select actions before the model is updated
            one or more times during training.
    """

    saved_attributes = ("model", "target_model", "actor_optimizer", "critic_optimizer")

    def __init__(
        self,
        policy,
        q_func,
        actor_optimizer,
        critic_optimizer,
        replay_buffer,
        gamma,
        explorer,
        gpu=None,
        replay_start_size=50000,
        minibatch_size=32,
        update_interval=1,
        target_update_interval=10000,
        phi=lambda x: x,
        target_update_method="hard",
        soft_update_tau=1e-2,
        n_times_update=1,
        recurrent=False,
        episodic_update_len=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        burnin_action_func=None,
    ):
        self.model = nn.ModuleList([policy, q_func])
        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.model.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        self.recurrent = recurrent
        assert not self.recurrent, "recurrent=True is not yet implemented"
        if self.recurrent:
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=recurrent,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )
        self.batch_states = batch_states
        self.burnin_action_func = burnin_action_func

        self.t = 0
        self.last_state = None
        self.last_action = None
        self.target_model = copy.deepcopy(self.model)
        self.target_model.eval()
        self.q_record = collections.deque(maxlen=1000)
        self.actor_loss_record = collections.deque(maxlen=100)
        self.critic_loss_record = collections.deque(maxlen=100)
        self.n_updates = 0

        # Aliases for convenience
        self.policy, self.q_function = self.model
        self.target_policy, self.target_q_function = self.target_model

        self.sync_target_network()

    def sync_target_network(self):
        """Synchronize target network with current network."""
        synchronize_parameters(
            src=self.model,
            dst=self.target_model,
            method=self.target_update_method,
            tau=self.soft_update_tau,
        )

    # Update Q-function
    def compute_critic_loss(self, batch):
        """Compute loss for critic."""

        batch_next_state = batch["next_state"]
        batch_rewards = batch["reward"]
        batch_terminal = batch["is_state_terminal"]
        batch_state = batch["state"]
        batch_actions = batch["action"]
        batchsize = len(batch_rewards)

        with torch.no_grad():
            assert not self.recurrent
            next_actions = self.target_policy(batch_next_state).sample()
            next_q = self.target_q_function((batch_next_state, next_actions))
            target_q = batch_rewards + self.gamma * (
                1.0 - batch_terminal
            ) * next_q.reshape((batchsize,))

        predict_q = self.q_function((batch_state, batch_actions)).reshape((batchsize,))

        loss = F.mse_loss(target_q, predict_q)

        # Update stats
        self.critic_loss_record.append(float(loss.detach().cpu().numpy()))

        return loss

    def compute_actor_loss(self, batch):
        """Compute loss for actor."""

        batch_state = batch["state"]
        onpolicy_actions = self.policy(batch_state).rsample()
        q = self.q_function((batch_state, onpolicy_actions))
        loss = -q.mean()

        # Update stats
        self.q_record.extend(q.detach().cpu().numpy())
        self.actor_loss_record.append(float(loss.detach().cpu().numpy()))

        return loss

    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""

        batch = batch_experiences(experiences, self.device, self.phi, self.gamma)

        self.critic_optimizer.zero_grad()
        self.compute_critic_loss(batch).backward()
        self.critic_optimizer.step()

        self.actor_optimizer.zero_grad()
        self.compute_actor_loss(batch).backward()
        self.actor_optimizer.step()

        self.n_updates += 1

    def update_from_episodes(self, episodes, errors_out=None):
        raise NotImplementedError
        # Sort episodes desc by their lengths
        sorted_episodes = list(reversed(sorted(episodes, key=len)))
        max_epi_len = len(sorted_episodes[0])

        # Precompute all the input batches
        batches = []
        for i in range(max_epi_len):
            transitions = []
            for ep in sorted_episodes:
                if len(ep) <= i:
                    break
                transitions.append([ep[i]])
            batch = batch_experiences(
                transitions, xp=self.device, phi=self.phi, gamma=self.gamma
            )
            batches.append(batch)

        with self.model.state_reset(), self.target_model.state_reset():

            # Since the target model is evaluated one-step ahead,
            # its internal states need to be updated
            self.target_q_function.update_state(
                batches[0]["state"], batches[0]["action"]
            )
            self.target_policy(batches[0]["state"])

            # Update critic through time
            critic_loss = 0
            for batch in batches:
                critic_loss += self.compute_critic_loss(batch)
            self.critic_optimizer.update(lambda: critic_loss / max_epi_len)

        with self.model.state_reset():

            # Update actor through time
            actor_loss = 0
            for batch in batches:
                actor_loss += self.compute_actor_loss(batch)
            self.actor_optimizer.update(lambda: actor_loss / max_epi_len)

    def batch_act(self, batch_obs):
        if self.training:
            return self._batch_act_train(batch_obs)
        else:
            return self._batch_act_eval(batch_obs)

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.training:
            self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset)

    def _batch_select_greedy_actions(self, batch_obs):
        with torch.no_grad(), evaluating(self.policy):
            batch_xs = self.batch_states(batch_obs, self.device, self.phi)
            batch_action = self.policy(batch_xs).sample()
            return batch_action.cpu().numpy()

    def _batch_act_eval(self, batch_obs):
        assert not self.training
        return self._batch_select_greedy_actions(batch_obs)

    def _batch_act_train(self, batch_obs):
        assert self.training
        if self.burnin_action_func is not None and self.n_updates == 0:
            batch_action = [self.burnin_action_func() for _ in range(len(batch_obs))]
        else:
            batch_greedy_action = self._batch_select_greedy_actions(batch_obs)
            batch_action = [
                self.explorer.select_action(self.t, lambda: batch_greedy_action[i])
                for i in range(len(batch_greedy_action))
            ]

        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        return batch_action

    def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset):
        assert self.training
        for i in range(len(batch_obs)):
            self.t += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                self.replay_buffer.append(
                    state=self.batch_last_obs[i],
                    action=self.batch_last_action[i],
                    reward=batch_reward[i],
                    next_state=batch_obs[i],
                    next_action=None,
                    is_state_terminal=batch_done[i],
                    env_id=i,
                )
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

    def get_statistics(self):
        return [
            ("average_q", _mean_or_nan(self.q_record)),
            ("average_actor_loss", _mean_or_nan(self.actor_loss_record)),
            ("average_critic_loss", _mean_or_nan(self.critic_loss_record)),
            ("n_updates", self.n_updates),
        ]
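

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module above): building and driving a DDPG
# agent with randomly generated observations in place of a real environment.
# The layer sizes, hyperparameters, and dimensions are illustrative
# assumptions; the head modules (DeterministicHead, ConcatObsAndAction),
# explorer, and replay buffer follow the style of pfrl's DDPG example
# scripts, so check them against your pfrl version.

import numpy as np
import torch
from torch import nn

from pfrl.agents import DDPG
from pfrl.explorers import AdditiveGaussian
from pfrl.nn import ConcatObsAndAction
from pfrl.policies import DeterministicHead
from pfrl.replay_buffers import ReplayBuffer

obs_size, action_size = 11, 3  # assumed observation/action dimensions

# DDPG calls ``policy(obs).sample()`` (and ``.rsample()`` during updates),
# so the policy must end with a head that returns a distribution.
policy = nn.Sequential(
    nn.Linear(obs_size, 64),
    nn.ReLU(),
    nn.Linear(64, action_size),
    nn.Tanh(),
    DeterministicHead(),
)

# The Q-function is called with an (observation, action) tuple.
q_func = nn.Sequential(
    ConcatObsAndAction(),
    nn.Linear(obs_size + action_size, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
)

agent = DDPG(
    policy,
    q_func,
    torch.optim.Adam(policy.parameters(), lr=1e-3),
    torch.optim.Adam(q_func.parameters(), lr=1e-3),
    replay_buffer=ReplayBuffer(capacity=10 ** 5),
    gamma=0.99,
    explorer=AdditiveGaussian(scale=0.1),
    replay_start_size=64,
    minibatch_size=32,
    target_update_method="soft",
    soft_update_tau=5e-3,
)

# Interaction loop over a "batch" of one fake environment; with a real
# vectorized env, the observations, rewards, and done/reset flags would
# come from env.step() instead of the random placeholders below.
obs = np.random.rand(obs_size).astype(np.float32)
for step in range(200):
    action = agent.batch_act([obs])[0]
    obs = np.random.rand(obs_size).astype(np.float32)
    reward, done, reset = float(np.random.rand()), False, False
    agent.batch_observe([obs], [reward], [done], [reset])

print(agent.get_statistics())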