# Source code for pfrl.nn.noisy_chain

"""Noisy Networks

See http://arxiv.org/abs/1706.10295
"""

import torch.nn as nn

from pfrl.nn.noisy_linear import FactorizedNoisyLinear


def to_factorized_noisy(module, *args, **kwargs):
    """Add noisiness to components of given module.

    Currently this function only supports torch.nn.Linear (with and without
    bias).

    Every ``torch.nn.Linear`` found among the descendants of ``module`` is
    replaced, in place, by ``FactorizedNoisyLinear(linear, *args, **kwargs)``.
    Only children are visited, so when ``module`` itself is a
    ``torch.nn.Linear`` it is left unchanged.

    Args:
        module: Module whose descendants are searched; it must provide
            ``named_children()`` and ``_modules`` (presumably a
            ``torch.nn.Module``).
        *args: Extra positional arguments forwarded to
            ``FactorizedNoisyLinear``.
        **kwargs: Extra keyword arguments forwarded to
            ``FactorizedNoisyLinear``.

    Returns:
        None. ``module`` is modified in place.
    """

    def func_to_factorized_noisy(module):
        # Wrap linear layers; hand every other module back untouched so the
        # traversal knows it still has to look inside it.
        if isinstance(module, nn.Linear):
            return FactorizedNoisyLinear(module, *args, **kwargs)
        return module

    _map_modules(func_to_factorized_noisy, module)
def _map_modules(func, module):
    """Replace descendants of ``module`` in place according to ``func``.

    ``func`` is called on each direct child of ``module``. If it returns a
    different object, that object is stored under the child's name in
    ``module._modules`` and the replacement is not descended into. If it
    returns the very same object, the child is kept and its own children are
    processed recursively.

    Args:
        func: Callable taking a child module and returning either the same
            object (keep it) or a replacement.
        module: Object providing ``named_children()`` and a ``_modules``
            mapping (presumably a ``torch.nn.Module``).

    Returns:
        None. ``module`` is modified in place.
    """
    # Snapshot the children first so that assigning into ``module._modules``
    # below never happens while an iterator over it is still live.
    for name, child in list(module.named_children()):
        new_child = func(child)
        if new_child is child:
            # func left this child untouched, so recurse into it
            _map_modules(func, child)
        else:
            module._modules[name] = new_child