Source code for pfrl.nn.noisy_linear

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def init_lecun_uniform(tensor, scale=1.0):
    """Initializes the tensor with LeCunUniform."""
    fan_in = torch.nn.init._calculate_correct_fan(tensor, "fan_in")
    s = scale * np.sqrt(3.0 / fan_in)
    with torch.no_grad():
        return tensor.uniform_(-s, s)


def init_variance_scaling_constant(tensor, scale=1.0):

    if tensor.ndim == 1:
        s = scale / np.sqrt(tensor.shape[0])
    else:
        fan_in = torch.nn.init._calculate_correct_fan(tensor, "fan_in")
        s = scale / np.sqrt(fan_in)
    with torch.no_grad():
        return tensor.fill_(s)


[docs]class FactorizedNoisyLinear(nn.Module): """Linear layer in Factorized Noisy Network Args: mu_link (nn.Linear): Linear link that computes mean of output. sigma_scale (float): The hyperparameter sigma_0 in the original paper. Scaling factor of the initial weights of noise-scaling parameters. """ def __init__(self, mu_link, sigma_scale=0.4): super(FactorizedNoisyLinear, self).__init__() self._kernel = None self.out_size = mu_link.out_features self.hasbias = mu_link.bias is not None in_size = mu_link.weight.shape[1] device = mu_link.weight.device self.mu = nn.Linear(in_size, self.out_size, bias=self.hasbias) init_lecun_uniform(self.mu.weight, scale=1 / np.sqrt(3)) self.sigma = nn.Linear(in_size, self.out_size, bias=self.hasbias) init_variance_scaling_constant(self.sigma.weight, scale=sigma_scale) if self.hasbias: init_variance_scaling_constant(self.sigma.bias, scale=sigma_scale) self.mu.to(device) self.sigma.to(device) def _eps(self, shape, dtype, device): r = torch.normal(mean=0.0, std=1.0, size=(shape,), dtype=dtype, device=device) return torch.abs(torch.sqrt(torch.abs(r))) * torch.sign(r) def forward(self, x): # use info of sigma.W to avoid strange error messages dtype = self.sigma.weight.dtype out_size, in_size = self.sigma.weight.shape eps = self._eps(in_size + out_size, dtype, self.sigma.weight.device) eps_x = eps[:in_size] eps_y = eps[in_size:] W = torch.addcmul(self.mu.weight, self.sigma.weight, torch.ger(eps_y, eps_x)) if self.hasbias: b = torch.addcmul(self.mu.bias, self.sigma.bias, eps_y) return F.linear(x, W, b) else: return F.linear(x, W)