Source code for pfrl.q_functions.state_action_q_functions

import torch
import torch.nn as nn
import torch.nn.functional as F

from pfrl.initializers import init_lecun_normal
from pfrl.nn.mlp import MLP
from pfrl.nn.mlp_bn import MLPBN
from pfrl.q_function import StateActionQFunction


class SingleModelStateActionQFunction(nn.Module, StateActionQFunction):
    """(s,a)-input Q-function that wraps a single callable model.

    Args:
        model (nn.Module): Module that is callable and outputs action values.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, a):
        h = self.model(x, a)
        return h
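
# A minimal usage sketch. The ConcatCritic module and the 4-/2-dimensional
# shapes below are assumptions for illustration, not part of this module:
# any nn.Module whose forward accepts a state batch and an action batch can
# be wrapped.
#
#     class ConcatCritic(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.net = nn.Sequential(
#                 nn.Linear(4 + 2, 64), nn.ReLU(), nn.Linear(64, 1)
#             )
#
#         def forward(self, x, a):
#             return self.net(torch.cat((x, a), dim=1))
#
#     q_func = SingleModelStateActionQFunction(ConcatCritic())
#     q = q_func(torch.zeros(8, 4), torch.zeros(8, 2))  # -> shape (8, 1)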

class FCSAQFunction(MLP, StateActionQFunction):
    """Fully-connected (s,a)-input Q-function.

    Args:
        n_dim_obs (int): Number of dimensions of observation space.
        n_dim_action (int): Number of dimensions of action space.
        n_hidden_channels (int): Number of hidden channels.
        n_hidden_layers (int): Number of hidden layers.
        nonlinearity (callable): Nonlinearity between layers. It must accept
            a Tensor as an argument and return a Tensor with the same shape.
            Nonlinearities with learnable parameters such as PReLU are not
            supported. It is not used if n_hidden_layers is zero.
        last_wscale (float): Scale of weight initialization of the last layer.
    """

    def __init__(
        self,
        n_dim_obs,
        n_dim_action,
        n_hidden_channels,
        n_hidden_layers,
        nonlinearity=F.relu,
        last_wscale=1.0,
    ):
        self.n_input_channels = n_dim_obs + n_dim_action
        self.n_hidden_layers = n_hidden_layers
        self.n_hidden_channels = n_hidden_channels
        self.nonlinearity = nonlinearity
        super().__init__(
            in_size=self.n_input_channels,
            out_size=1,
            hidden_sizes=[self.n_hidden_channels] * self.n_hidden_layers,
            nonlinearity=nonlinearity,
            last_wscale=last_wscale,
        )

    def forward(self, state, action):
        h = torch.cat((state, action), dim=1)
        return super().forward(h)
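
# A quick sanity check of the shapes involved (the dimensions below are
# arbitrary assumptions): observation and action batches are concatenated
# along dim=1 and mapped to one Q-value per sample.
#
#     obs_size, action_size = 3, 1  # assumed dimensions
#     q_func = FCSAQFunction(
#         n_dim_obs=obs_size,
#         n_dim_action=action_size,
#         n_hidden_channels=32,
#         n_hidden_layers=2,
#     )
#     q_values = q_func(torch.randn(16, obs_size), torch.randn(16, action_size))
#     print(q_values.shape)  # torch.Size([16, 1])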

class FCLSTMSAQFunction(nn.Module, StateActionQFunction):
    """Fully-connected + LSTM (s,a)-input Q-function.

    Args:
        n_dim_obs (int): Number of dimensions of observation space.
        n_dim_action (int): Number of dimensions of action space.
        n_hidden_channels (int): Number of hidden channels.
        n_hidden_layers (int): Number of hidden layers.
        nonlinearity (callable): Nonlinearity between layers. It must accept
            a Tensor as an argument and return a Tensor with the same shape.
            Nonlinearities with learnable parameters such as PReLU are not
            supported.
        last_wscale (float): Scale of weight initialization of the last layer.
    """

    def __init__(
        self,
        n_dim_obs,
        n_dim_action,
        n_hidden_channels,
        n_hidden_layers,
        nonlinearity=F.relu,
        last_wscale=1.0,
    ):
        # This recurrent Q-function is not supported yet; the code below the
        # raise is unreachable and kept only as a draft.
        raise NotImplementedError()
        self.n_input_channels = n_dim_obs + n_dim_action
        self.n_hidden_layers = n_hidden_layers
        self.n_hidden_channels = n_hidden_channels
        self.nonlinearity = nonlinearity
        super().__init__()
        self.fc = MLP(
            self.n_input_channels,
            n_hidden_channels,
            [self.n_hidden_channels] * self.n_hidden_layers,
            nonlinearity=nonlinearity,
        )
        self.lstm = nn.LSTM(
            num_layers=1, input_size=n_hidden_channels, hidden_size=n_hidden_channels
        )
        self.out = nn.Linear(n_hidden_channels, 1)
        for n, p in self.lstm.named_parameters():
            if "weight" in n:
                init_lecun_normal(p)
            else:
                nn.init.zeros_(p)
        init_lecun_normal(self.out.weight, scale=last_wscale)
        nn.init.zeros_(self.out.bias)

    def forward(self, x, a):
        h = torch.cat((x, a), dim=1)
        h = self.nonlinearity(self.fc(h))
        h = self.lstm(h)
        return self.out(h)

class FCBNSAQFunction(MLPBN, StateActionQFunction):
    """Fully-connected + BN (s,a)-input Q-function.

    Args:
        n_dim_obs (int): Number of dimensions of observation space.
        n_dim_action (int): Number of dimensions of action space.
        n_hidden_channels (int): Number of hidden channels.
        n_hidden_layers (int): Number of hidden layers.
        normalize_input (bool): If set to True, Batch Normalization is applied
            to both observations and actions.
        nonlinearity (callable): Nonlinearity between layers. It must accept
            a Tensor as an argument and return a Tensor with the same shape.
            Nonlinearities with learnable parameters such as PReLU are not
            supported. It is not used if n_hidden_layers is zero.
        last_wscale (float): Scale of weight initialization of the last layer.
    """

    def __init__(
        self,
        n_dim_obs,
        n_dim_action,
        n_hidden_channels,
        n_hidden_layers,
        normalize_input=True,
        nonlinearity=F.relu,
        last_wscale=1.0,
    ):
        self.n_input_channels = n_dim_obs + n_dim_action
        self.n_hidden_layers = n_hidden_layers
        self.n_hidden_channels = n_hidden_channels
        self.normalize_input = normalize_input
        self.nonlinearity = nonlinearity
        super().__init__(
            in_size=self.n_input_channels,
            out_size=1,
            hidden_sizes=[self.n_hidden_channels] * self.n_hidden_layers,
            normalize_input=self.normalize_input,
            nonlinearity=nonlinearity,
            last_wscale=last_wscale,
        )

    def forward(self, state, action):
        h = torch.cat((state, action), dim=1)
        return super().forward(h)
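
# Because this variant contains BatchNorm layers, training mode needs batches
# larger than one, while eval mode uses running statistics. A sketch with
# assumed dimensions:
#
#     q_func = FCBNSAQFunction(
#         n_dim_obs=3,
#         n_dim_action=1,
#         n_hidden_channels=32,
#         n_hidden_layers=2,
#         normalize_input=True,
#     )
#     q_func.train()  # BatchNorm uses batch statistics; batch size must be > 1
#     q_train = q_func(torch.randn(16, 3), torch.randn(16, 1))
#     q_func.eval()   # running statistics; single samples are fine
#     q_eval = q_func(torch.randn(1, 3), torch.randn(1, 1))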

class FCBNLateActionSAQFunction(nn.Module, StateActionQFunction):
    """Fully-connected + BN (s,a)-input Q-function with late action input.

    Actions are not included until the second hidden layer and are not
    normalized. This architecture is used in the DDPG paper:
    http://arxiv.org/abs/1509.02971

    Args:
        n_dim_obs (int): Number of dimensions of observation space.
        n_dim_action (int): Number of dimensions of action space.
        n_hidden_channels (int): Number of hidden channels.
        n_hidden_layers (int): Number of hidden layers. It must be greater
            than or equal to 1.
        normalize_input (bool): If set to True, Batch Normalization is applied
            to observations.
        nonlinearity (callable): Nonlinearity between layers. It must accept
            a Tensor as an argument and return a Tensor with the same shape.
            Nonlinearities with learnable parameters such as PReLU are not
            supported.
        last_wscale (float): Scale of weight initialization of the last layer.
    """

    def __init__(
        self,
        n_dim_obs,
        n_dim_action,
        n_hidden_channels,
        n_hidden_layers,
        normalize_input=True,
        nonlinearity=F.relu,
        last_wscale=1.0,
    ):
        assert n_hidden_layers >= 1
        self.n_input_channels = n_dim_obs + n_dim_action
        self.n_hidden_layers = n_hidden_layers
        self.n_hidden_channels = n_hidden_channels
        self.normalize_input = normalize_input
        self.nonlinearity = nonlinearity
        super().__init__()
        # No need to pass nonlinearity to obs_mlp because it has no
        # hidden layers
        self.obs_mlp = MLPBN(
            in_size=n_dim_obs,
            out_size=n_hidden_channels,
            hidden_sizes=[],
            normalize_input=normalize_input,
            normalize_output=True,
        )
        self.mlp = MLP(
            in_size=n_hidden_channels + n_dim_action,
            out_size=1,
            hidden_sizes=[self.n_hidden_channels] * (self.n_hidden_layers - 1),
            nonlinearity=nonlinearity,
            last_wscale=last_wscale,
        )
        self.output = self.mlp.output

    def forward(self, state, action):
        h = self.nonlinearity(self.obs_mlp(state))
        h = torch.cat((h, action), dim=1)
        return self.mlp(h)
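
# A short sketch (assumed dimensions) showing that only the observation passes
# through the normalized first layer; the action is concatenated afterwards,
# so its dimensionality only affects the second layer's input size.
#
#     q_func = FCBNLateActionSAQFunction(
#         n_dim_obs=3, n_dim_action=1, n_hidden_channels=64, n_hidden_layers=2
#     )
#     q_func.eval()  # BatchNorm in obs_mlp uses running statistics here
#     q = q_func(torch.randn(4, 3), torch.randn(4, 1))  # shape (4, 1)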

class FCLateActionSAQFunction(nn.Module, StateActionQFunction):
    """Fully-connected (s,a)-input Q-function with late action input.

    Actions are not included until the second hidden layer and are not
    normalized. This architecture is used in the DDPG paper:
    http://arxiv.org/abs/1509.02971

    Args:
        n_dim_obs (int): Number of dimensions of observation space.
        n_dim_action (int): Number of dimensions of action space.
        n_hidden_channels (int): Number of hidden channels.
        n_hidden_layers (int): Number of hidden layers. It must be greater
            than or equal to 1.
        nonlinearity (callable): Nonlinearity between layers. It must accept
            a Tensor as an argument and return a Tensor with the same shape.
            Nonlinearities with learnable parameters such as PReLU are not
            supported.
        last_wscale (float): Scale of weight initialization of the last layer.
    """

    def __init__(
        self,
        n_dim_obs,
        n_dim_action,
        n_hidden_channels,
        n_hidden_layers,
        nonlinearity=F.relu,
        last_wscale=1.0,
    ):
        assert n_hidden_layers >= 1
        self.n_input_channels = n_dim_obs + n_dim_action
        self.n_hidden_layers = n_hidden_layers
        self.n_hidden_channels = n_hidden_channels
        self.nonlinearity = nonlinearity
        super().__init__()
        # No need to pass nonlinearity to obs_mlp because it has no
        # hidden layers
        self.obs_mlp = MLP(
            in_size=n_dim_obs, out_size=n_hidden_channels, hidden_sizes=[]
        )
        self.mlp = MLP(
            in_size=n_hidden_channels + n_dim_action,
            out_size=1,
            hidden_sizes=[self.n_hidden_channels] * (self.n_hidden_layers - 1),
            nonlinearity=nonlinearity,
            last_wscale=last_wscale,
        )
        self.output = self.mlp.output

    def forward(self, state, action):
        h = self.nonlinearity(self.obs_mlp(state))
        h = torch.cat((h, action), dim=1)
        return self.mlp(h)
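
# A usage sketch for the batch-norm-free late-action critic. The dimensions
# are assumptions; 300 hidden channels loosely mirrors the critic sizes used
# in the DDPG paper.
#
#     q_func = FCLateActionSAQFunction(
#         n_dim_obs=3, n_dim_action=1, n_hidden_channels=300, n_hidden_layers=2
#     )
#     q = q_func(torch.randn(32, 3), torch.randn(32, 1))  # shape (32, 1)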