# Source code for pfrl.agents.categorical_dqn

import torch

from pfrl.agents import dqn
from pfrl.utils.recurrent import pack_and_forward


def _apply_categorical_projection(y, y_probs, z):
    """Apply categorical projection.

    See Algorithm 1 in https://arxiv.org/abs/1707.06887.

    Args:
        y (torch.Tensor): Values of atoms before projection. Its shape must be
            (batch_size, n_atoms).
        y_probs (torch.Tensor): Probabilities of atoms whose values are y.
            Its shape must be (batch_size, n_atoms).
        z (torch.Tensor): Values of atoms after projection. Its shape must be
            (n_atoms,). It is assumed that the values are sorted in ascending
            order and evenly spaced.

    Returns:
        torch.Tensor: Probabilities of atoms whose values are z.
    """
    batch_size, n_atoms = y.shape
    assert z.shape == (n_atoms,)
    assert y_probs.shape == (batch_size, n_atoms)
    delta_z = z[1] - z[0]
    v_min = z[0]
    v_max = z[-1]
    y = torch.clamp(y, v_min, v_max)

    # bj: (batch_size, n_atoms)
    bj = (y - v_min) / delta_z
    assert bj.shape == (batch_size, n_atoms)
    # Avoid out-of-range indices caused by floating-point error in delta_z
    bj = torch.clamp(bj, 0, n_atoms - 1)

    # l, u: (batch_size, n_atoms)
    l, u = torch.floor(bj), torch.ceil(bj)
    assert l.shape == (batch_size, n_atoms)
    assert u.shape == (batch_size, n_atoms)

    z_probs = torch.zeros((batch_size, n_atoms), dtype=torch.float32, device=y.device)
    offset = torch.arange(
        0, batch_size * n_atoms, n_atoms, dtype=torch.int32, device=y.device
    )[..., None]
    # Accumulate m_l
    # Note that u - bj in the original paper is replaced with 1 - (bj - l) to
    # deal with the case when bj is an integer, i.e., l = u = bj
    z_probs.view(-1).scatter_add_(
        0, (l.long() + offset).view(-1), (y_probs * (1 - (bj - l))).view(-1)
    )
    # Accumulate m_u
    z_probs.view(-1).scatter_add_(
        0, (u.long() + offset).view(-1), (y_probs * (bj - l)).view(-1)
    )
    return z_probs
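
# A minimal usage sketch of the projection above (illustrative, not part of
# this module; the tensors are toy values). Projecting a shifted support back
# onto the canonical atoms conserves probability mass:
#
#   z = torch.linspace(-1.0, 1.0, steps=5)   # canonical atoms, evenly spaced
#   y = z[None, :] * 0.9 + 0.1               # (1, 5) shifted/scaled atoms
#   y_probs = torch.full((1, 5), 0.2)        # uniform source distribution
#   z_probs = _apply_categorical_projection(y, y_probs, z)
#   assert torch.allclose(z_probs.sum(dim=1), torch.ones(1))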


def compute_value_loss(eltwise_loss, batch_accumulator="mean"):
    """Compute a loss for value prediction problem.

    Args:
        eltwise_loss (Variable): Element-wise loss per example per atom
        batch_accumulator (str): 'mean' or 'sum'. 'mean' will use the mean of
            the loss values in a batch. 'sum' will use the sum.
    Returns:
        (Variable) scalar loss
    """
    assert batch_accumulator in ("mean", "sum")

    if batch_accumulator == "sum":
        loss = eltwise_loss.sum()
    else:
        loss = eltwise_loss.sum(dim=1).mean()
    return loss
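
# For example (hypothetical tensors): with batch_accumulator="mean", a (2, 3)
# element-wise loss is first summed over atoms (dim=1), then averaged over
# the batch:
#
#   eltwise_loss = torch.ones(2, 3)
#   compute_value_loss(eltwise_loss, batch_accumulator="mean")  # tensor(3.)
#   compute_value_loss(eltwise_loss, batch_accumulator="sum")   # tensor(6.)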


def compute_weighted_value_loss(
    eltwise_loss, batch_size, weights, batch_accumulator="mean"
):
    """Compute a loss for value prediction problem.

    Args:
        eltwise_loss (Variable): Element-wise loss per example per atom
        weights (ndarray): Weights for y, t.
        batch_accumulator (str): 'mean' will divide loss by batchsize
    Returns:
        (Variable) scalar loss
    """
    assert batch_accumulator in ("mean", "sum")

    # eltwise_loss is a (batch_size, n_atoms) array of losses
    # weights is an array of shape (batch_size,)
    # Sum the loss across atoms, then apply the per-example weight
    weights = weights.to(eltwise_loss.device)
    loss_sum = torch.matmul(eltwise_loss.sum(dim=1), weights)
    if batch_accumulator == "mean":
        loss = loss_sum / batch_size
    elif batch_accumulator == "sum":
        loss = loss_sum
    return loss
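
# Sanity-check sketch (hypothetical tensors): with unit weights this reduces
# to compute_value_loss with batch_accumulator="mean":
#
#   eltwise_loss = torch.ones(2, 3)
#   weights = torch.tensor([1.0, 1.0])
#   compute_weighted_value_loss(eltwise_loss, 2, weights)  # tensor(3.)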


class CategoricalDQN(dqn.DQN):
    """Categorical DQN.

    See https://arxiv.org/abs/1707.06887.

    Arguments are the same as those of DQN except q_function must return
    DistributionalDiscreteActionValue and clip_delta is ignored.
    """

    def _compute_target_values(self, exp_batch):
        """Compute a batch of target return distributions."""

        batch_next_state = exp_batch["next_state"]

        if self.recurrent:
            target_next_qout, _ = pack_and_forward(
                self.target_model,
                batch_next_state,
                exp_batch["next_recurrent_state"],
            )
        else:
            target_next_qout = self.target_model(batch_next_state)

        batch_rewards = exp_batch["reward"]
        batch_terminal = exp_batch["is_state_terminal"]

        batch_size = exp_batch["reward"].shape[0]
        z_values = target_next_qout.z_values
        n_atoms = z_values.size()[0]

        # next_q_max: (batch_size, n_atoms)
        next_q_max = target_next_qout.max_as_distribution.detach()
        assert next_q_max.shape == (batch_size, n_atoms), next_q_max.shape

        # Tz: (batch_size, n_atoms)
        Tz = (
            batch_rewards[..., None]
            + (1.0 - batch_terminal[..., None])
            * torch.unsqueeze(exp_batch["discount"], 1)
            * z_values[None]
        )
        return _apply_categorical_projection(Tz, next_q_max, z_values)

    def _compute_y_and_t(self, exp_batch):
        """Compute a batch of predicted/target return distributions."""

        batch_size = exp_batch["reward"].shape[0]

        # Compute Q-values for current states
        batch_state = exp_batch["state"]

        # (batch_size, n_actions, n_atoms)
        if self.recurrent:
            qout, _ = pack_and_forward(
                self.model, batch_state, exp_batch["recurrent_state"]
            )
        else:
            qout = self.model(batch_state)
        n_atoms = qout.z_values.size()[0]

        batch_actions = exp_batch["action"]
        batch_q = qout.evaluate_actions_as_distribution(batch_actions)
        assert batch_q.shape == (batch_size, n_atoms)

        with torch.no_grad():
            batch_q_target = self._compute_target_values(exp_batch)
            assert batch_q_target.shape == (batch_size, n_atoms)

        # for `agent.get_statistics()`
        batch_q_scalars = qout.evaluate_actions(batch_actions)
        self.q_record.extend(batch_q_scalars.detach().cpu().numpy().ravel())

        return batch_q, batch_q_target

    def _compute_loss(self, exp_batch, errors_out=None):
        """Compute a loss of categorical DQN."""
        y, t = self._compute_y_and_t(exp_batch)

        # Minimize the cross entropy
        # y is clipped to avoid log(0)
        eltwise_loss = -t * torch.log(torch.clamp(y, 1e-10, 1.0))

        if errors_out is not None:
            del errors_out[:]
            # The loss per example is the sum of the atom-wise loss
            # Prioritization by KL-divergence
            delta = eltwise_loss.sum(dim=1)
            delta = delta.detach().cpu().numpy()
            for e in delta:
                errors_out.append(e)

        if "weights" in exp_batch:
            return compute_weighted_value_loss(
                eltwise_loss,
                y.shape[0],
                exp_batch["weights"],
                batch_accumulator=self.batch_accumulator,
            )
        else:
            return compute_value_loss(
                eltwise_loss, batch_accumulator=self.batch_accumulator
            )
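
# A minimal construction sketch (illustrative, not part of this module): the
# q-function, explorer, and replay buffer classes referenced below are assumed
# to come from elsewhere in pfrl, and the hyperparameter values are
# placeholders, not recommendations.
#
#   import numpy as np
#   import pfrl
#
#   q_func = pfrl.q_functions.DistributionalFCStateQFunctionWithDiscreteAction(
#       ndim_obs=4, n_actions=2, n_atoms=51, v_min=-10.0, v_max=10.0,
#       n_hidden_channels=64, n_hidden_layers=2,
#   )
#   opt = torch.optim.Adam(q_func.parameters(), lr=1e-3)
#   rbuf = pfrl.replay_buffers.ReplayBuffer(capacity=10 ** 5)
#   explorer = pfrl.explorers.ConstantEpsilonGreedy(
#       epsilon=0.1, random_action_func=lambda: np.random.randint(2)
#   )
#   agent = CategoricalDQN(
#       q_func, opt, rbuf, gamma=0.99, explorer=explorer,
#       replay_start_size=1000, target_update_interval=100,
#   )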