Source code for pfrl.nn.mlp

import torch.nn as nn
import torch.nn.functional as F

from pfrl.initializers import init_chainer_default, init_lecun_normal


class MLP(nn.Module):
    """Multi-Layer Perceptron.

    A stack of fully connected (``nn.Linear``) layers. ``nonlinearity`` is
    applied after every hidden layer but NOT after the output layer, so the
    network's output is linear in the last hidden activation.

    Args:
        in_size (int): Size of the input features.
        out_size (int): Size of the output features.
        hidden_sizes (sequence of int): Sizes of the hidden layers, in order.
            Must support indexing and slicing. If empty (falsy), the network
            is a single linear layer mapping ``in_size`` to ``out_size``.
        nonlinearity (callable): Activation applied to each hidden layer's
            output. May be a plain function (default ``F.relu``) or an
            ``nn.Module`` instance such as ``nn.ReLU()``.
        last_wscale (float): Scale passed to ``init_lecun_normal`` when
            initializing the output layer's weight.
    """

    def __init__(
        self, in_size, out_size, hidden_sizes, nonlinearity=F.relu, last_wscale=1
    ):
        # nn.Module must be initialized before any attribute assignment:
        # assigning a Module-valued attribute (e.g. nonlinearity=nn.ReLU())
        # before this call makes nn.Module.__setattr__ raise AttributeError.
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.hidden_sizes = hidden_sizes
        self.nonlinearity = nonlinearity

        if hidden_sizes:
            # in_size -> h0 -> h1 -> ... -> h_last, one Linear per transition.
            self.hidden_layers = nn.ModuleList()
            self.hidden_layers.append(nn.Linear(in_size, hidden_sizes[0]))
            for hin, hout in zip(hidden_sizes, hidden_sizes[1:]):
                self.hidden_layers.append(nn.Linear(hin, hout))
            # Apply pfrl's init_chainer_default to every hidden submodule
            # (presumably Chainer-style default init; see pfrl.initializers).
            self.hidden_layers.apply(init_chainer_default)
            self.output = nn.Linear(hidden_sizes[-1], out_size)
        else:
            # No hidden layers: a single affine map from input to output.
            self.output = nn.Linear(in_size, out_size)

        # The output layer is initialized separately so its weight scale can
        # be tuned via last_wscale; its bias always starts at zero.
        init_lecun_normal(self.output.weight, scale=last_wscale)
        nn.init.zeros_(self.output.bias)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x (torch.Tensor): Input whose last dimension is ``in_size``
                (``nn.Linear`` acts on the last dimension).

        Returns:
            torch.Tensor: Output of the final linear layer, with last
            dimension ``out_size``. No activation is applied to it.
        """
        h = x
        if self.hidden_sizes:
            for layer in self.hidden_layers:
                h = self.nonlinearity(layer(h))
        return self.output(h)