Source code for pfrl.nn.branched

import torch


[docs]class Branched(torch.nn.Module): """Module that calls forward functions of child modules in parallel. When the `forward` method of this module is called, all the arguments are forwarded to each child module's `forward` method. The returned values from the child modules are returned as a tuple. Args: *modules: Child modules. Each module should be callable. """ def __init__(self, *modules): super().__init__() self.child_modules = torch.nn.ModuleList(modules) def forward(self, *args, **kwargs): """Forward the arguments to the child modules. Args: *args, **kwargs: Any arguments forwarded to child modules. Each child module should be able to accept the arguments. Returns: tuple: Tuple of the returned values from the child modules. """ return tuple(mod(*args, **kwargs) for mod in self.child_modules)