# Source code for pfrl.policies.softmax_policy
import torch
from torch import nn
class SoftmaxCategoricalHead(nn.Module):
    """Policy head that turns raw logits into a categorical distribution.

    The module has no learnable parameters. It wraps its input in
    ``torch.distributions.Categorical(logits=...)``. PyTorch treats those
    logits as unnormalized log-probabilities (softmax over the last
    dimension). Any leading dimensions become the distribution's batch shape.
    """

    def forward(self, logits: torch.Tensor) -> torch.distributions.Categorical:
        """Build a categorical distribution from ``logits``.

        Args:
            logits: Unnormalized log-probabilities. The last dimension indexes
                the categories (actions).

        Returns:
            A ``torch.distributions.Categorical`` parameterized by ``logits``.
        """
        return torch.distributions.Categorical(logits=logits)
import torch
from torch import nn
class SoftmaxCategoricalHead(nn.Module):
    """Policy head that turns raw logits into a categorical distribution.

    The module has no learnable parameters. It wraps its input in
    ``torch.distributions.Categorical(logits=...)``. PyTorch treats those
    logits as unnormalized log-probabilities (softmax over the last
    dimension). Any leading dimensions become the distribution's batch shape.
    """

    def forward(self, logits: torch.Tensor) -> torch.distributions.Categorical:
        """Build a categorical distribution from ``logits``.

        Args:
            logits: Unnormalized log-probabilities. The last dimension indexes
                the categories (actions).

        Returns:
            A ``torch.distributions.Categorical`` parameterized by ``logits``.
        """
        return torch.distributions.Categorical(logits=logits)