Source code for pfrl.nn.recurrent_branched

from torch import nn

from pfrl.nn.recurrent import Recurrent


class RecurrentBranched(Recurrent, nn.ModuleList):
    """Recurrent module that bundles parallel branches.

    This is a recurrent analog to `pfrl.nn.Branched`. It bundles multiple
    recurrent modules and feeds the same input to each of them.

    Args:
        *modules: Child modules. Each module should be recurrent and callable.
    """

    def __init__(self, *modules):
        super().__init__(modules)

    def forward(self, sequences, recurrent_state):
        """Run every child module on the same input.

        Args:
            sequences: Input forwarded unchanged to each child module.
            recurrent_state: Either ``None`` (every child then receives
                ``None`` as its state) or an iterable holding exactly one
                state per child module, in the same order as the children.

        Returns:
            A pair ``(child_ys, new_states)`` of tuples. ``child_ys[i]`` and
            ``new_states[i]`` are the first and second element of the pair
            returned by the ``i``-th child. Both tuples are empty when there
            are no children.

        Raises:
            ValueError: If ``recurrent_state`` is not ``None`` and does not
                contain exactly one entry per child module.
        """
        n_branches = len(self)
        if recurrent_state is None:
            branch_states = (None,) * n_branches
        else:
            branch_states = tuple(recurrent_state)
            # zip() would silently truncate to the shorter side, skipping
            # branches or dropping states, so reject a mismatch explicitly.
            if len(branch_states) != n_branches:
                raise ValueError(
                    "recurrent_state must contain one state per branch: "
                    "expected {}, got {}".format(n_branches, len(branch_states))
                )

        if n_branches == 0:
            # zip(*[]) yields nothing, so the two-way unpacking below would
            # fail; return empty outputs and states instead.
            return (), ()

        results = [
            branch(sequences, state) for branch, state in zip(self, branch_states)
        ]
        # Transpose [(y0, s0), (y1, s1), ...] into ((y0, y1, ...), (s0, s1, ...)).
        child_ys, new_states = zip(*results)
        return child_ys, new_states