import numpy as np
import torch
from torch import nn

class EmpiricalNormalization(nn.Module):
    """Normalize values with running estimates of their mean and variance.

    The running statistics are stored as module buffers (``_mean``, ``_var``
    and ``count``), so they are included in ``state_dict()`` and follow the
    module across ``.to()`` / ``.cuda()`` calls.

    Args:
        shape (int or tuple of int): Shape of input values except batch axis.
        batch_axis (int): Batch axis.
        eps (float): Small value added to the variance for stability.
        dtype (dtype): NumPy dtype used for the statistics buffers.
        until (int or None): If specified, the statistics stop being updated
            once the accumulated number of samples reaches this value.
        clip_threshold (float or None): If specified, normalized outputs are
            clamped to ``[-clip_threshold, clip_threshold]``.
    """

    def __init__(
        self,
        shape,
        batch_axis=0,
        eps=1e-2,
        dtype=np.float32,
        until=None,
        clip_threshold=None,
    ):
        super().__init__()
        dtype = np.dtype(dtype)
        self.batch_axis = batch_axis
        self.eps = dtype.type(eps)
        self.until = until
        self.clip_threshold = clip_threshold
        # The statistics keep a singleton batch axis so that they broadcast
        # directly against batched inputs.
        self.register_buffer(
            "_mean",
            torch.tensor(np.expand_dims(np.zeros(shape, dtype=dtype), batch_axis)),
        )
        self.register_buffer(
            "_var",
            torch.tensor(np.expand_dims(np.ones(shape, dtype=dtype), batch_axis)),
        )
        self.register_buffer("count", torch.tensor(0))

    @property
    def mean(self):
        """Copy of the running mean, without the batch axis."""
        return torch.squeeze(self._mean, self.batch_axis).clone()

    @property
    def std(self):
        """Copy of the running standard deviation, without the batch axis."""
        return torch.sqrt(torch.squeeze(self._var, self.batch_axis)).clone()

    @property
    def _std_inverse(self):
        """Return ``1 / sqrt(var + eps)`` for the current variance buffer."""
        # Deliberately not cached: the buffers can change behind our back
        # (load_state_dict copies into them, .to()/.cuda() replaces them),
        # and a cached tensor would then be stale or on the wrong device.
        return (self._var + self.eps) ** -0.5

    def experience(self, x):
        """Update the running statistics from ``x`` without normalizing it.

        Args:
            x (torch.Tensor): Batch of input values; its size along
                ``batch_axis`` is the number of samples it contributes.
        """
        if self.until is not None and self.count >= self.until:
            return

        count_x = x.shape[self.batch_axis]
        if count_x == 0:
            return

        # The running statistics are constants from autograd's point of view.
        # Without no_grad, an input that requires grad would attach a graph to
        # the buffers that keeps growing on every call.
        with torch.no_grad():
            self.count += count_x
            rate = count_x / self.count.float()
            assert rate > 0
            assert rate <= 1

            var_x, mean_x = torch.var_mean(
                x, dim=self.batch_axis, keepdim=True, unbiased=False
            )
            delta_mean = mean_x - self._mean

            self._mean += rate * delta_mean
            # self._mean has already been updated here; the product of the
            # pre- and post-update deviations is what merges the batch
            # variance into the running variance.
            self._var += rate * (
                var_x - self._var + delta_mean * (mean_x - self._mean)
            )

    def forward(self, x, update=True):
        """Normalize ``x`` using the running mean and variance.

        Args:
            x (torch.Tensor): Input values.
            update (bool): If True, update the running statistics with ``x``
                before normalizing.

        Returns:
            torch.Tensor: Normalized values, clamped to
            ``[-clip_threshold, clip_threshold]`` when ``clip_threshold`` is
            set.
        """
        if update:
            self.experience(x)

        normalized = (x - self._mean) * self._std_inverse
        if self.clip_threshold is not None:
            normalized = torch.clamp(
                normalized, -self.clip_threshold, self.clip_threshold
            )
        return normalized

    def inverse(self, y):
        """Map normalized values back to the original scale.

        Clipping applied in ``forward`` is not undone.

        Args:
            y (torch.Tensor): Normalized values.

        Returns:
            torch.Tensor: ``y * sqrt(var + eps) + mean``.
        """
        std = torch.sqrt(self._var + self.eps)
        return y * std + self._mean