Source code for pfrl.agents.soft_actor_critic

import collections
import copy
from logging import getLogger

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

import pfrl
from pfrl.agent import AttributeSavingMixin, BatchAgent
from pfrl.replay_buffer import ReplayUpdater, batch_experiences
from pfrl.utils import clip_l2_grad_norm_
from pfrl.utils.batch_states import batch_states
from pfrl.utils.copy_param import synchronize_parameters
from pfrl.utils.mode_of_distribution import mode_of_distribution


def _mean_or_nan(xs):
    """Return its mean a non-empty sequence, numpy.nan for a empty one."""
    return np.mean(xs) if xs else np.nan


class TemperatureHolder(nn.Module):
    """Module that holds a temperature as a learnable value.

    Args:
        initial_log_temperature (float): Initial value of log(temperature).
    """

    def __init__(self, initial_log_temperature=0):
        super().__init__()
        self.log_temperature = nn.Parameter(
            torch.tensor(initial_log_temperature, dtype=torch.float32)
        )

    def forward(self):
        """Return a temperature as a torch.Tensor."""
        return torch.exp(self.log_temperature)
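

# Usage sketch (illustrative, not part of the library source). Because
# log(temperature) is stored as an nn.Parameter, an optimizer built over
# holder.parameters() can adjust the temperature by gradient descent while
# exp() keeps it positive; this is how SoftActorCritic drives it toward an
# entropy target. The helper below is hypothetical.
def _example_temperature_step(holder, optimizer, log_prob, entropy_target):
    """One temperature step, e.g. with holder = TemperatureHolder() and
    optimizer = torch.optim.Adam(holder.parameters(), lr=3e-4)."""
    # Same objective as SoftActorCritic.update_temperature:
    # minimize -E[temperature * (log_prob + entropy_target)].
    loss = -torch.mean(holder() * (log_prob + entropy_target))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(holder())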


class SoftActorCritic(AttributeSavingMixin, BatchAgent):
    """Soft Actor-Critic (SAC).

    See https://arxiv.org/abs/1812.05905

    Args:
        policy (Policy): Policy.
        q_func1 (Module): First Q-function that takes state-action pairs as
            input and outputs predicted Q-values.
        q_func2 (Module): Second Q-function that takes state-action pairs as
            input and outputs predicted Q-values.
        policy_optimizer (Optimizer): Optimizer set up with the policy.
        q_func1_optimizer (Optimizer): Optimizer set up with the first
            Q-function.
        q_func2_optimizer (Optimizer): Optimizer set up with the second
            Q-function.
        replay_buffer (ReplayBuffer): Replay buffer.
        gamma (float): Discount factor.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): If the replay buffer's size is less than
            replay_start_size, updates are skipped.
        minibatch_size (int): Minibatch size.
        update_interval (int): Model update interval in steps.
        phi (callable): Feature extractor applied to observations.
        soft_update_tau (float): Tau of soft target update.
        max_grad_norm (float or None): Maximum L2 norm of gradients used in
            updates. If None, gradients are not clipped.
        logger (Logger): Logger used.
        batch_states (callable): Method which makes a batch of observations.
            Default is `pfrl.utils.batch_states.batch_states`.
        burnin_action_func (callable or None): If not None, this callable
            object is used to select actions before the model is updated
            one or more times during training.
        initial_temperature (float): Initial temperature value. If
            `entropy_target` is set to None, the temperature is fixed to it.
        entropy_target (float or None): If set to a float, the temperature is
            adjusted during training to match the policy's entropy to it.
        temperature_optimizer_lr (float): Learning rate of the temperature
            optimizer. If set to None, Adam with default hyperparameters is
            used.
        act_deterministically (bool): If set to True, choose most probable
            actions in the act method instead of sampling from distributions.
""" saved_attributes = ( "policy", "q_func1", "q_func2", "target_q_func1", "target_q_func2", "policy_optimizer", "q_func1_optimizer", "q_func2_optimizer", "temperature_holder", "temperature_optimizer", ) def __init__( self, policy, q_func1, q_func2, policy_optimizer, q_func1_optimizer, q_func2_optimizer, replay_buffer, gamma, gpu=None, replay_start_size=10000, minibatch_size=100, update_interval=1, phi=lambda x: x, soft_update_tau=5e-3, max_grad_norm=None, logger=getLogger(__name__), batch_states=batch_states, burnin_action_func=None, initial_temperature=1.0, entropy_target=None, temperature_optimizer_lr=None, act_deterministically=True, ): self.policy = policy self.q_func1 = q_func1 self.q_func2 = q_func2 if gpu is not None and gpu >= 0: assert torch.cuda.is_available() self.device = torch.device("cuda:{}".format(gpu)) self.policy.to(self.device) self.q_func1.to(self.device) self.q_func2.to(self.device) else: self.device = torch.device("cpu") self.replay_buffer = replay_buffer self.gamma = gamma self.gpu = gpu self.phi = phi self.soft_update_tau = soft_update_tau self.logger = logger self.policy_optimizer = policy_optimizer self.q_func1_optimizer = q_func1_optimizer self.q_func2_optimizer = q_func2_optimizer self.replay_updater = ReplayUpdater( replay_buffer=replay_buffer, update_func=self.update, batchsize=minibatch_size, n_times_update=1, replay_start_size=replay_start_size, update_interval=update_interval, episodic_update=False, ) self.max_grad_norm = max_grad_norm self.batch_states = batch_states self.burnin_action_func = burnin_action_func self.initial_temperature = initial_temperature self.entropy_target = entropy_target if self.entropy_target is not None: self.temperature_holder = TemperatureHolder( initial_log_temperature=np.log(initial_temperature) ) if temperature_optimizer_lr is not None: self.temperature_optimizer = torch.optim.Adam( self.temperature_holder.parameters(), lr=temperature_optimizer_lr ) else: self.temperature_optimizer = torch.optim.Adam( self.temperature_holder.parameters() ) if gpu is not None and gpu >= 0: self.temperature_holder.to(self.device) else: self.temperature_holder = None self.temperature_optimizer = None self.act_deterministically = act_deterministically self.t = 0 # Target model self.target_q_func1 = copy.deepcopy(self.q_func1).eval().requires_grad_(False) self.target_q_func2 = copy.deepcopy(self.q_func2).eval().requires_grad_(False) # Statistics self.q1_record = collections.deque(maxlen=1000) self.q2_record = collections.deque(maxlen=1000) self.entropy_record = collections.deque(maxlen=1000) self.q_func1_loss_record = collections.deque(maxlen=100) self.q_func2_loss_record = collections.deque(maxlen=100) self.n_policy_updates = 0 @property def temperature(self): if self.entropy_target is None: return self.initial_temperature else: with torch.no_grad(): return float(self.temperature_holder()) def sync_target_network(self): """Synchronize target network with current network.""" synchronize_parameters( src=self.q_func1, dst=self.target_q_func1, method="soft", tau=self.soft_update_tau, ) synchronize_parameters( src=self.q_func2, dst=self.target_q_func2, method="soft", tau=self.soft_update_tau, ) def update_q_func(self, batch): """Compute loss for a given Q-function.""" batch_next_state = batch["next_state"] batch_rewards = batch["reward"] batch_terminal = batch["is_state_terminal"] batch_state = batch["state"] batch_actions = batch["action"] batch_discount = batch["discount"] with torch.no_grad(), pfrl.utils.evaluating(self.policy), 
    def update_q_func(self, batch):
        """Compute loss for a given Q-function."""

        batch_next_state = batch["next_state"]
        batch_rewards = batch["reward"]
        batch_terminal = batch["is_state_terminal"]
        batch_state = batch["state"]
        batch_actions = batch["action"]
        batch_discount = batch["discount"]

        with torch.no_grad(), pfrl.utils.evaluating(self.policy), pfrl.utils.evaluating(
            self.target_q_func1
        ), pfrl.utils.evaluating(self.target_q_func2):
            next_action_distrib = self.policy(batch_next_state)
            next_actions = next_action_distrib.sample()
            next_log_prob = next_action_distrib.log_prob(next_actions)
            next_q1 = self.target_q_func1((batch_next_state, next_actions))
            next_q2 = self.target_q_func2((batch_next_state, next_actions))
            next_q = torch.min(next_q1, next_q2)
            entropy_term = self.temperature * next_log_prob[..., None]
            assert next_q.shape == entropy_term.shape

            target_q = batch_rewards + batch_discount * (
                1.0 - batch_terminal
            ) * torch.flatten(next_q - entropy_term)

        predict_q1 = torch.flatten(self.q_func1((batch_state, batch_actions)))
        predict_q2 = torch.flatten(self.q_func2((batch_state, batch_actions)))

        loss1 = 0.5 * F.mse_loss(target_q, predict_q1)
        loss2 = 0.5 * F.mse_loss(target_q, predict_q2)

        # Update stats
        self.q1_record.extend(predict_q1.detach().cpu().numpy())
        self.q2_record.extend(predict_q2.detach().cpu().numpy())
        self.q_func1_loss_record.append(float(loss1))
        self.q_func2_loss_record.append(float(loss2))

        self.q_func1_optimizer.zero_grad()
        loss1.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
        self.q_func1_optimizer.step()

        self.q_func2_optimizer.zero_grad()
        loss2.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
        self.q_func2_optimizer.step()

    def update_temperature(self, log_prob):
        assert not log_prob.requires_grad
        loss = -torch.mean(self.temperature_holder() * (log_prob + self.entropy_target))
        self.temperature_optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.temperature_holder.parameters(), self.max_grad_norm)
        self.temperature_optimizer.step()
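
    # Note on the actor update below: the policy is trained to minimize
    # E[alpha * log pi(a|s) - min(Q1, Q2)(s, a)] with actions drawn via
    # rsample(), so gradients flow through the reparameterized action into the
    # policy parameters. When entropy_target is set, the temperature alpha is
    # then adjusted toward that target entropy.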
    def update_policy_and_temperature(self, batch):
        """Compute loss for actor."""

        batch_state = batch["state"]

        action_distrib = self.policy(batch_state)
        actions = action_distrib.rsample()
        log_prob = action_distrib.log_prob(actions)
        q1 = self.q_func1((batch_state, actions))
        q2 = self.q_func2((batch_state, actions))
        q = torch.min(q1, q2)

        entropy_term = self.temperature * log_prob[..., None]
        assert q.shape == entropy_term.shape
        loss = torch.mean(entropy_term - q)

        self.policy_optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy_optimizer.step()

        self.n_policy_updates += 1

        if self.entropy_target is not None:
            self.update_temperature(log_prob.detach())

        # Record entropy
        with torch.no_grad():
            try:
                self.entropy_record.extend(
                    action_distrib.entropy().detach().cpu().numpy()
                )
            except NotImplementedError:
                # Record - log p(x) instead
                self.entropy_record.extend(-log_prob.detach().cpu().numpy())

    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""

        batch = batch_experiences(experiences, self.device, self.phi, self.gamma)
        self.update_q_func(batch)
        self.update_policy_and_temperature(batch)
        self.sync_target_network()

    def batch_select_greedy_action(self, batch_obs, deterministic=False):
        with torch.no_grad(), pfrl.utils.evaluating(self.policy):
            batch_xs = self.batch_states(batch_obs, self.device, self.phi)
            policy_out = self.policy(batch_xs)
            if deterministic:
                batch_action = mode_of_distribution(policy_out).cpu().numpy()
            else:
                batch_action = policy_out.sample().cpu().numpy()
        return batch_action

    def batch_act(self, batch_obs):
        if self.training:
            return self._batch_act_train(batch_obs)
        else:
            return self._batch_act_eval(batch_obs)

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.training:
            self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset)

    def _batch_act_eval(self, batch_obs):
        assert not self.training
        return self.batch_select_greedy_action(
            batch_obs, deterministic=self.act_deterministically
        )

    def _batch_act_train(self, batch_obs):
        assert self.training
        if self.burnin_action_func is not None and self.n_policy_updates == 0:
            batch_action = [self.burnin_action_func() for _ in range(len(batch_obs))]
        else:
            batch_action = self.batch_select_greedy_action(batch_obs)
        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        return batch_action

    def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset):
        assert self.training
        for i in range(len(batch_obs)):
            self.t += 1
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                self.replay_buffer.append(
                    state=self.batch_last_obs[i],
                    action=self.batch_last_action[i],
                    reward=batch_reward[i],
                    next_state=batch_obs[i],
                    next_action=None,
                    is_state_terminal=batch_done[i],
                    env_id=i,
                )
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

    def get_statistics(self):
        return [
            ("average_q1", _mean_or_nan(self.q1_record)),
            ("average_q2", _mean_or_nan(self.q2_record)),
            ("average_q_func1_loss", _mean_or_nan(self.q_func1_loss_record)),
            ("average_q_func2_loss", _mean_or_nan(self.q_func2_loss_record)),
            ("n_updates", self.n_policy_updates),
            ("average_entropy", _mean_or_nan(self.entropy_record)),
            ("temperature", self.temperature),
        ]
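

# ---------------------------------------------------------------------------
# Construction sketch (illustrative, not part of the library source). It
# assumes a continuous-action task with flat observations; the architectures,
# layer sizes, learning rates, and the use of pfrl.replay_buffers.ReplayBuffer
# are assumptions made for this example, not the canonical PFRL SAC setup.
# The requirements visible from the agent code above are only that
# policy(obs_batch) returns a torch.distributions object supporting
# rsample()/log_prob(), and that each Q-function accepts an (obs, action)
# tuple.
def _example_build_sac_agent(obs_size, action_size):
    from torch import distributions

    class SquashedGaussianPolicy(nn.Module):
        """Map observations to a tanh-squashed diagonal Gaussian."""

        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(obs_size, 256),
                nn.ReLU(),
                nn.Linear(256, 2 * action_size),
            )

        def forward(self, obs):
            mean, log_std = torch.chunk(self.net(obs), 2, dim=-1)
            log_std = torch.clamp(log_std, -20.0, 2.0)
            base = distributions.Independent(
                distributions.Normal(mean, log_std.exp()), 1
            )
            return distributions.TransformedDistribution(
                base, [distributions.transforms.TanhTransform(cache_size=1)]
            )

    class QFunction(nn.Module):
        """Take an (obs, action) tuple and return one Q-value per sample."""

        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(obs_size + action_size, 256),
                nn.ReLU(),
                nn.Linear(256, 1),
            )

        def forward(self, obs_and_action):
            obs, action = obs_and_action
            return self.net(torch.cat([obs, action], dim=-1))

    policy = SquashedGaussianPolicy()
    q_func1 = QFunction()
    q_func2 = QFunction()

    # The returned agent is driven through batch_act()/batch_observe(),
    # e.g. by pfrl.experiments.train_agent_batch_with_evaluation.
    return SoftActorCritic(
        policy,
        q_func1,
        q_func2,
        torch.optim.Adam(policy.parameters(), lr=3e-4),
        torch.optim.Adam(q_func1.parameters(), lr=3e-4),
        torch.optim.Adam(q_func2.parameters(), lr=3e-4),
        pfrl.replay_buffers.ReplayBuffer(capacity=10 ** 6),
        gamma=0.99,
        replay_start_size=1000,
        minibatch_size=256,
        # Standard SAC heuristic: a target entropy of -|A| enables automatic
        # temperature adjustment via TemperatureHolder.
        entropy_target=-action_size,
        temperature_optimizer_lr=3e-4,
    )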