Source code for pfrl.agents.td3

import collections
import copy
from logging import getLogger

import numpy as np
import torch
from torch.nn import functional as F

import pfrl
from pfrl.agent import AttributeSavingMixin, BatchAgent
from pfrl.replay_buffer import ReplayUpdater, batch_experiences
from pfrl.utils import clip_l2_grad_norm_
from pfrl.utils.batch_states import batch_states
from pfrl.utils.copy_param import synchronize_parameters


def _mean_or_nan(xs):
    """Return its mean a non-empty sequence, numpy.nan for a empty one."""
    return np.mean(xs) if xs else np.nan


def default_target_policy_smoothing_func(batch_action):
    """Add noises to actions for target policy smoothing."""
    noise = torch.clamp(0.2 * torch.randn_like(batch_action), -0.5, 0.5)
    return torch.clamp(batch_action + noise, -1, 1)
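
# Example (illustrative, derived from the defaults above): Gaussian noise with
# std 0.2 is clipped to [-0.5, 0.5] before the final clamp to the assumed
# action range [-1, 1], so smoothing a batch of zero actions keeps every
# entry within [-0.5, 0.5]:
#
#     actions = torch.zeros(4, 2)
#     smoothed = default_target_policy_smoothing_func(actions)
#     assert smoothed.shape == actions.shape
#     assert (smoothed.abs() <= 0.5).all()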


class TD3(AttributeSavingMixin, BatchAgent):
    """Twin Delayed Deep Deterministic Policy Gradients (TD3).

    See http://arxiv.org/abs/1802.09477

    Args:
        policy (Policy): Policy.
        q_func1 (Module): First Q-function that takes state-action pairs as
            input and outputs predicted Q-values.
        q_func2 (Module): Second Q-function that takes state-action pairs as
            input and outputs predicted Q-values.
        policy_optimizer (Optimizer): Optimizer set up with the policy.
        q_func1_optimizer (Optimizer): Optimizer set up with the first
            Q-function.
        q_func2_optimizer (Optimizer): Optimizer set up with the second
            Q-function.
        replay_buffer (ReplayBuffer): Replay buffer.
        gamma (float): Discount factor.
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id. CPU is used if None or negative.
        replay_start_size (int): If the replay buffer's size is less than
            replay_start_size, updates are skipped.
        minibatch_size (int): Minibatch size.
        update_interval (int): Model update interval in steps.
        phi (callable): Feature extractor applied to observations.
        soft_update_tau (float): Tau of soft target update.
        logger (Logger): Logger used.
        batch_states (callable): Method that makes a batch of observations.
            Default is `pfrl.utils.batch_states.batch_states`.
        burnin_action_func (callable or None): If not None, this callable
            object is used to select actions before the model is updated
            one or more times during training.
        policy_update_delay (int): Delay of policy updates. The policy is
            updated once every `policy_update_delay` Q-function updates.
        target_policy_smoothing_func (callable): Callable that takes a batch
            of actions as input and outputs a noisy version of them. It is
            used for target policy smoothing when computing target Q-values.
    """

    saved_attributes = (
        "policy",
        "q_func1",
        "q_func2",
        "target_policy",
        "target_q_func1",
        "target_q_func2",
        "policy_optimizer",
        "q_func1_optimizer",
        "q_func2_optimizer",
    )

    def __init__(
        self,
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        replay_buffer,
        gamma,
        explorer,
        gpu=None,
        replay_start_size=10000,
        minibatch_size=100,
        update_interval=1,
        phi=lambda x: x,
        soft_update_tau=5e-3,
        n_times_update=1,
        max_grad_norm=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        burnin_action_func=None,
        policy_update_delay=2,
        target_policy_smoothing_func=default_target_policy_smoothing_func,
    ):
        self.policy = policy
        self.q_func1 = q_func1
        self.q_func2 = q_func2

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.policy.to(self.device)
            self.q_func1.to(self.device)
            self.q_func2.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.phi = phi
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.policy_optimizer = policy_optimizer
        self.q_func1_optimizer = q_func1_optimizer
        self.q_func2_optimizer = q_func2_optimizer
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=self.update,
            batchsize=minibatch_size,
            n_times_update=1,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
            episodic_update=False,
        )
        self.max_grad_norm = max_grad_norm
        self.batch_states = batch_states
        self.burnin_action_func = burnin_action_func
        self.policy_update_delay = policy_update_delay
        self.target_policy_smoothing_func = target_policy_smoothing_func

        self.t = 0
        self.policy_n_updates = 0
        self.q_func_n_updates = 0
        self.last_state = None
        self.last_action = None

        # Target model
        self.target_policy = copy.deepcopy(self.policy).eval().requires_grad_(False)
        self.target_q_func1 = copy.deepcopy(self.q_func1).eval().requires_grad_(False)
        self.target_q_func2 = copy.deepcopy(self.q_func2).eval().requires_grad_(False)

        # Statistics
        self.q1_record = collections.deque(maxlen=1000)
        self.q2_record = collections.deque(maxlen=1000)
        self.q_func1_loss_record = collections.deque(maxlen=100)
        self.q_func2_loss_record = collections.deque(maxlen=100)
        self.policy_loss_record = collections.deque(maxlen=100)

    def sync_target_network(self):
        """Synchronize target network with current network."""
        synchronize_parameters(
            src=self.policy,
            dst=self.target_policy,
            method="soft",
            tau=self.soft_update_tau,
        )
        synchronize_parameters(
            src=self.q_func1,
            dst=self.target_q_func1,
            method="soft",
            tau=self.soft_update_tau,
        )
        synchronize_parameters(
            src=self.q_func2,
            dst=self.target_q_func2,
            method="soft",
            tau=self.soft_update_tau,
        )

    def update_q_func(self, batch):
        """Compute loss for a given Q-function."""

        batch_next_state = batch["next_state"]
        batch_rewards = batch["reward"]
        batch_terminal = batch["is_state_terminal"]
        batch_state = batch["state"]
        batch_actions = batch["action"]
        batch_discount = batch["discount"]

        with torch.no_grad(), pfrl.utils.evaluating(
            self.target_policy
        ), pfrl.utils.evaluating(self.target_q_func1), pfrl.utils.evaluating(
            self.target_q_func2
        ):
            next_actions = self.target_policy_smoothing_func(
                self.target_policy(batch_next_state).sample()
            )
            next_q1 = self.target_q_func1((batch_next_state, next_actions))
            next_q2 = self.target_q_func2((batch_next_state, next_actions))
            next_q = torch.min(next_q1, next_q2)
            target_q = batch_rewards + batch_discount * (
                1.0 - batch_terminal
            ) * torch.flatten(next_q)

        predict_q1 = torch.flatten(self.q_func1((batch_state, batch_actions)))
        predict_q2 = torch.flatten(self.q_func2((batch_state, batch_actions)))

        loss1 = F.mse_loss(target_q, predict_q1)
        loss2 = F.mse_loss(target_q, predict_q2)

        # Update stats
        self.q1_record.extend(predict_q1.detach().cpu().numpy())
        self.q2_record.extend(predict_q2.detach().cpu().numpy())
        self.q_func1_loss_record.append(float(loss1))
        self.q_func2_loss_record.append(float(loss2))

        self.q_func1_optimizer.zero_grad()
        loss1.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
        self.q_func1_optimizer.step()

        self.q_func2_optimizer.zero_grad()
        loss2.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
        self.q_func2_optimizer.step()

        self.q_func_n_updates += 1
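
    # Note on ``update_q_func`` above: both critics regress toward the same
    # clipped double-Q target
    #
    #     y = r + discount * (1 - terminal) * min(Q1'(s', a'), Q2'(s', a')),
    #
    # where a' = target_policy_smoothing_func(target_policy(s')) adds clipped
    # Gaussian noise to the target policy's action (target policy smoothing).
    # Taking the minimum over the two target critics counteracts the
    # overestimation bias of a single critic.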

    def update_policy(self, batch):
        """Compute loss for actor."""

        batch_state = batch["state"]

        onpolicy_actions = self.policy(batch_state).rsample()
        q = self.q_func1((batch_state, onpolicy_actions))

        # Since we want to maximize Q, loss is negation of Q
        loss = -torch.mean(q)

        self.policy_loss_record.append(float(loss))
        self.policy_optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy_optimizer.step()
        self.policy_n_updates += 1

    def update(self, experiences, errors_out=None):
        """Update the model from experiences."""

        batch = batch_experiences(experiences, self.device, self.phi, self.gamma)
        self.update_q_func(batch)
        if self.q_func_n_updates % self.policy_update_delay == 0:
            self.update_policy(batch)
            self.sync_target_network()

    def batch_select_onpolicy_action(self, batch_obs):
        with torch.no_grad(), pfrl.utils.evaluating(self.policy):
            batch_xs = self.batch_states(batch_obs, self.device, self.phi)
            batch_action = self.policy(batch_xs).sample().cpu().numpy()
        return list(batch_action)

    def batch_act(self, batch_obs):
        if self.training:
            return self._batch_act_train(batch_obs)
        else:
            return self._batch_act_eval(batch_obs)

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.training:
            self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset)

    def _batch_act_eval(self, batch_obs):
        assert not self.training
        return self.batch_select_onpolicy_action(batch_obs)

    def _batch_act_train(self, batch_obs):
        assert self.training
        if self.burnin_action_func is not None and self.policy_n_updates == 0:
            batch_action = [self.burnin_action_func() for _ in range(len(batch_obs))]
        else:
            batch_onpolicy_action = self.batch_select_onpolicy_action(batch_obs)
            batch_action = [
                self.explorer.select_action(self.t, lambda: batch_onpolicy_action[i])
                for i in range(len(batch_onpolicy_action))
            ]
        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)
        return batch_action

    def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset):
        assert self.training
        for i in range(len(batch_obs)):
            self.t += 1
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                self.replay_buffer.append(
                    state=self.batch_last_obs[i],
                    action=self.batch_last_action[i],
                    reward=batch_reward[i],
                    next_state=batch_obs[i],
                    next_action=None,
                    is_state_terminal=batch_done[i],
                    env_id=i,
                )
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

    def get_statistics(self):
        return [
            ("average_q1", _mean_or_nan(self.q1_record)),
            ("average_q2", _mean_or_nan(self.q2_record)),
            ("average_q_func1_loss", _mean_or_nan(self.q_func1_loss_record)),
            ("average_q_func2_loss", _mean_or_nan(self.q_func2_loss_record)),
            ("average_policy_loss", _mean_or_nan(self.policy_loss_record)),
            ("policy_n_updates", self.policy_n_updates),
            ("q_func_n_updates", self.q_func_n_updates),
        ]
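

# Illustrative construction sketch: one way the agent above might be assembled,
# loosely following the TD3 example in the PFRL repository. The observation and
# action sizes and the hidden-layer widths are arbitrary placeholder values,
# not settings prescribed by this module.
if __name__ == "__main__":
    from torch import nn, optim

    obs_size, action_size = 11, 3  # hypothetical environment dimensions

    # Deterministic policy: maps observations to tanh-squashed actions.
    example_policy = nn.Sequential(
        nn.Linear(obs_size, 400),
        nn.ReLU(),
        nn.Linear(400, 300),
        nn.ReLU(),
        nn.Linear(300, action_size),
        nn.Tanh(),
        pfrl.policies.DeterministicHead(),
    )
    example_policy_optimizer = optim.Adam(example_policy.parameters())

    def make_q_func_with_optimizer():
        # Each critic takes a (state, action) pair and outputs a scalar Q-value.
        q_func = nn.Sequential(
            pfrl.nn.ConcatObsAndAction(),
            nn.Linear(obs_size + action_size, 400),
            nn.ReLU(),
            nn.Linear(400, 300),
            nn.ReLU(),
            nn.Linear(300, 1),
        )
        return q_func, optim.Adam(q_func.parameters())

    example_q_func1, example_q_func1_optimizer = make_q_func_with_optimizer()
    example_q_func2, example_q_func2_optimizer = make_q_func_with_optimizer()

    agent = TD3(
        example_policy,
        example_q_func1,
        example_q_func2,
        example_policy_optimizer,
        example_q_func1_optimizer,
        example_q_func2_optimizer,
        replay_buffer=pfrl.replay_buffers.ReplayBuffer(10 ** 6),
        gamma=0.99,
        explorer=pfrl.explorers.AdditiveGaussian(scale=0.1),
        gpu=None,
        replay_start_size=10000,
        minibatch_size=100,
    )

    # One interaction step with a dummy observation; no update happens yet
    # because the replay buffer is still smaller than replay_start_size.
    obs = np.zeros(obs_size, dtype=np.float32)
    action = agent.batch_act([obs])[0]
    agent.batch_observe([obs], [0.0], [False], [False])
    print(action)
    print(agent.get_statistics())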