Source code for pfrl.replay_buffers.prioritized

import collections

import numpy as np

from pfrl.collections.prioritized import PrioritizedBuffer
from pfrl.replay_buffers.replay_buffer import ReplayBuffer  # NOQA


class PriorityWeightError(object):
    """Helper for proportional prioritization of replay errors.

    Converts per-transition errors into sampling priorities and sampling
    probabilities into importance-sampling weights.

    ``alpha`` is the exponent applied to ``clipped_error + eps`` when a
    priority is computed. ``beta`` is the (negated) exponent applied when
    weights are computed; it starts at ``beta0`` and is increased by
    ``(1 - beta0) / betasteps`` after every call to
    ``weights_from_probabilities``, never exceeding 1.

    Args:
        alpha (float): Non-negative exponent used to turn errors into
            priorities.
        beta0 (float): Initial value of beta, in ``[0, 1]``.
        betasteps (float or None): Number of weight computations over which
            beta is annealed to 1. ``None`` keeps beta constant.
        eps (float): Constant added to every (clipped) error before
            exponentiation.
        normalize_by_max (str or bool): How weights are normalized.
            ``'batch'`` or ``True``: probabilities are divided by the
            smallest probability in the given batch. ``'memory'``:
            probabilities are divided by the ``min_probability`` passed by
            the caller. ``False``: no normalization; probabilities are
            scaled by ``len(self.memory)`` instead, so the object must then
            provide a ``memory`` attribute.
        error_min (float or None): Lower clipping bound for errors;
            ``None`` disables it.
        error_max (float or None): Upper clipping bound for errors;
            ``None`` disables it.
    """

    def __init__(
        self, alpha, beta0, betasteps, eps, normalize_by_max, error_min, error_max
    ):
        assert 0.0 <= alpha
        assert 0.0 <= beta0 <= 1.0
        self.alpha = alpha
        self.beta = beta0
        # Without an annealing horizon, beta never moves away from beta0.
        self.beta_add = 0 if betasteps is None else (1.0 - beta0) / betasteps
        self.eps = eps
        # ``True`` is accepted as an alias for the 'batch' mode.
        mode = "batch" if normalize_by_max is True else normalize_by_max
        assert mode in (False, "batch", "memory")
        self.normalize_by_max = mode
        self.error_min = error_min
        self.error_max = error_max

    def priority_from_errors(self, errors):
        """Return ``(clip(error) + eps) ** alpha`` for every error, as a list."""
        lower, upper = self.error_min, self.error_max
        priorities = []
        for err in errors:
            # Lower bound is applied first, then the upper bound.
            if lower is not None:
                err = max(lower, err)
            if upper is not None:
                err = min(upper, err)
            priorities.append((err + self.eps) ** self.alpha)
        return priorities

    def weights_from_probabilities(self, probabilities, min_probability):
        """Turn sampling probabilities into weights and advance beta.

        Args:
            probabilities: Sampling probabilities of the drawn items.
            min_probability: Minimum probability supplied by the caller;
                only used when ``normalize_by_max == 'memory'``.

        Returns:
            list: One weight per entry of ``probabilities``.
        """
        mode = self.normalize_by_max
        exponent = -self.beta
        if mode:
            # 'batch' ignores the caller-supplied minimum and uses the
            # smallest probability of this batch instead.
            reference = np.min(probabilities) if mode == "batch" else min_probability
            weights = [(p / reference) ** exponent for p in probabilities]
        else:
            weights = [(len(self.memory) * p) ** exponent for p in probabilities]
        # One annealing step per call, capped at 1.
        self.beta = min(1.0, self.beta + self.beta_add)
        return weights


class PrioritizedReplayBuffer(ReplayBuffer, PriorityWeightError):
    """Stochastic Prioritization

    https://arxiv.org/pdf/1511.05952.pdf Section 3.3
    proportional prioritization

    Args:
        capacity (int): capacity in terms of number of transitions
        alpha (float): Exponent of errors to compute probabilities to sample
        beta0 (float): Initial value of beta
        betasteps (int): Steps to anneal beta to 1
        eps (float): To revisit a step after its error becomes near zero
        normalize_by_max (str or bool): Method to normalize weights.
            ``'batch'`` or ``True`` (default): divide by the maximum weight
            in the sampled batch. ``'memory'``: divide by the maximum weight
            in the memory. ``False``: do not normalize
        error_min (float or None): Lower bound errors are clipped to before
            being converted to priorities. ``None`` disables the bound.
        error_max (float or None): Upper bound errors are clipped to before
            being converted to priorities. ``None`` disables the bound.
        num_steps (int): Maximum length of each deque kept in
            ``last_n_transitions``. Must be positive.
    """

    def __init__(
        self,
        capacity=None,
        alpha=0.6,
        beta0=0.4,
        betasteps=2e5,
        eps=0.01,
        normalize_by_max=True,
        error_min=0,
        error_max=1,
        num_steps=1,
    ):
        self.capacity = capacity
        assert num_steps > 0
        self.num_steps = num_steps
        self.memory = PrioritizedBuffer(capacity=capacity)
        # One bounded deque per key, created lazily on first access.
        self.last_n_transitions = collections.defaultdict(
            lambda: collections.deque([], maxlen=num_steps)
        )
        # Only the prioritization state is initialized through a parent
        # initializer; the buffer attributes are set directly above.
        PriorityWeightError.__init__(
            self,
            alpha,
            beta0,
            betasteps,
            eps,
            normalize_by_max,
            error_min=error_min,
            error_max=error_max,
        )

    def sample(self, n):
        """Sample ``n`` entries and annotate them with sampling weights.

        The weight of each sampled entry is stored under the ``"weight"``
        key of the entry's first element. Calling this also advances the
        beta schedule by one step.

        Args:
            n (int): Number of entries to sample; must not exceed the
                current size of the memory.

        Returns:
            The sampled entries as returned by ``self.memory.sample``.
        """
        assert len(self.memory) >= n
        sampled, probabilities, min_prob = self.memory.sample(n)
        weights = self.weights_from_probabilities(probabilities, min_prob)
        for e, w in zip(sampled, weights):
            e[0]["weight"] = w
        return sampled

    def update_errors(self, errors):
        """Convert ``errors`` to priorities and pass them to the memory.

        Args:
            errors: Iterable of errors; presumably one per entry of the
                most recent ``sample`` call (verify against
                ``PrioritizedBuffer.set_last_priority``).
        """
        self.memory.set_last_priority(self.priority_from_errors(errors))