Source code for pfrl.q_functions.state_q_functions

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from pfrl.action_value import (
    DiscreteActionValue,
    DistributionalDiscreteActionValue,
    QuadraticActionValue,
)
from pfrl.functions.lower_triangular_matrix import lower_triangular_matrix
from pfrl.initializers import init_chainer_default
from pfrl.nn import Lambda
from pfrl.nn.mlp import MLP
from pfrl.q_function import StateQFunction


def scale_by_tanh(x, low, high):
    """Squash x elementwise into (low, high) with tanh.

    Args:
        x (torch.Tensor): Input tensor.
        low (ndarray): Lower bounds, one per action dimension.
        high (ndarray): Upper bounds, one per action dimension.
    """
    scale = (high - low) / 2
    scale = torch.unsqueeze(torch.from_numpy(scale), dim=0).to(x.device)
    mean = (high + low) / 2
    mean = torch.unsqueeze(torch.from_numpy(mean), dim=0).to(x.device)
    return torch.tanh(x) * scale + mean


class SingleModelStateQFunctionWithDiscreteAction(nn.Module, StateQFunction):
    """Q-function with discrete actions.

    Args:
        model (nn.Module): Model that is callable and outputs action values.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        h = self.model(x)
        return DiscreteActionValue(h)
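

# --- Illustrative usage sketch (not part of the original module) ---
# Wraps an arbitrary torch network that maps observations to per-action
# Q-values. The network and its sizes below are hypothetical; greedy_actions
# is the argmax property provided by pfrl.action_value.DiscreteActionValue.
def _example_single_model_q_function():
    custom_model = nn.Sequential(
        nn.Linear(4, 32),
        nn.ReLU(),
        nn.Linear(32, 2),  # one output per discrete action
    )
    q_func = SingleModelStateQFunctionWithDiscreteAction(custom_model)
    obs = torch.zeros(1, 4)  # batch of one 4-dimensional observation
    action_value = q_func(obs)  # DiscreteActionValue
    return action_value.greedy_actions  # argmax action for each batch element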


class FCStateQFunctionWithDiscreteAction(SingleModelStateQFunctionWithDiscreteAction):
    """Fully-connected state-input Q-function with discrete actions.

    Args:
        ndim_obs (int): Number of dimensions of observation space.
        n_actions (int): Number of actions in action space.
        n_hidden_channels (int): Number of hidden channels.
        n_hidden_layers (int): Number of hidden layers.
        nonlinearity (callable): Nonlinearity applied after each hidden layer.
        last_wscale (float): Weight scale of the last layer.
    """

    def __init__(
        self,
        ndim_obs,
        n_actions,
        n_hidden_channels,
        n_hidden_layers,
        nonlinearity=F.relu,
        last_wscale=1.0,
    ):
        super().__init__(
            model=MLP(
                in_size=ndim_obs,
                out_size=n_actions,
                hidden_sizes=[n_hidden_channels] * n_hidden_layers,
                nonlinearity=nonlinearity,
                last_wscale=last_wscale,
            )
        )
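

# --- Illustrative usage sketch (not part of the original module) ---
# Builds the fully-connected Q-function for a hypothetical environment with a
# 4-dimensional observation space and 2 discrete actions, then reads the
# per-action Q-values stored on the returned DiscreteActionValue.
def _example_fc_q_function():
    q_func = FCStateQFunctionWithDiscreteAction(
        ndim_obs=4,
        n_actions=2,
        n_hidden_channels=64,
        n_hidden_layers=2,
    )
    obs = torch.zeros(1, 4)
    action_value = q_func(obs)
    return action_value.q_values  # shape (1, 2): Q(s, a) for each action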


class DistributionalSingleModelStateQFunctionWithDiscreteAction(
    nn.Module, StateQFunction
):
    """Distributional Q-function with discrete actions.

    Args:
        model (nn.Module): Model that is callable and outputs atoms for
            each action.
        z_values (ndarray): Returns represented by atoms. Its shape must be
            (n_atoms,).
    """

    def __init__(self, model, z_values):
        super().__init__()
        self.model = model
        self.register_buffer("z_values", torch.from_numpy(z_values))

    def forward(self, x):
        h = self.model(x)
        return DistributionalDiscreteActionValue(h, self.z_values)


class DistributionalFCStateQFunctionWithDiscreteAction(
    DistributionalSingleModelStateQFunctionWithDiscreteAction
):
    """Distributional fully-connected Q-function with discrete actions.

    Args:
        ndim_obs (int): Number of dimensions of observation space.
        n_actions (int): Number of actions in action space.
        n_atoms (int): Number of atoms of return distribution.
        v_min (float): Minimum value this model can approximate.
        v_max (float): Maximum value this model can approximate.
        n_hidden_channels (int): Number of hidden channels.
        n_hidden_layers (int): Number of hidden layers.
        nonlinearity (callable): Nonlinearity applied after each hidden layer.
        last_wscale (float): Weight scale of the last layer.
    """

    def __init__(
        self,
        ndim_obs,
        n_actions,
        n_atoms,
        v_min,
        v_max,
        n_hidden_channels,
        n_hidden_layers,
        nonlinearity=F.relu,
        last_wscale=1.0,
    ):
        assert n_atoms >= 2
        assert v_min < v_max
        z_values = np.linspace(v_min, v_max, num=n_atoms, dtype=np.float32)
        model = nn.Sequential(
            MLP(
                in_size=ndim_obs,
                out_size=n_actions * n_atoms,
                hidden_sizes=[n_hidden_channels] * n_hidden_layers,
                nonlinearity=nonlinearity,
                last_wscale=last_wscale,
            ),
            Lambda(lambda x: torch.reshape(x, (-1, n_actions, n_atoms))),
            nn.Softmax(dim=2),
        )
        super().__init__(model=model, z_values=z_values)
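

# --- Illustrative usage sketch (not part of the original module) ---
# A C51-style distributional Q-function: 51 atoms spanning [-10, 10] for a
# hypothetical 4-dimensional observation space with 2 discrete actions.
# q_values on the returned DistributionalDiscreteActionValue are the
# expectations of the per-action return distributions.
def _example_distributional_fc_q_function():
    q_func = DistributionalFCStateQFunctionWithDiscreteAction(
        ndim_obs=4,
        n_actions=2,
        n_atoms=51,
        v_min=-10.0,
        v_max=10.0,
        n_hidden_channels=64,
        n_hidden_layers=2,
    )
    obs = torch.zeros(1, 4)
    action_value = q_func(obs)  # DistributionalDiscreteActionValue
    return action_value.q_values  # shape (1, 2)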


class FCQuadraticStateQFunction(nn.Module, StateQFunction):
    """Fully-connected state-input continuous Q-function.

    See: https://arxiv.org/abs/1603.00748

    Args:
        n_input_channels (int): Number of input channels.
        n_dim_action (int): Number of dimensions of action space.
        n_hidden_channels (int): Number of hidden channels.
        n_hidden_layers (int): Number of hidden layers.
        action_space: Action space; its low/high bounds are used to scale mu
            and bound greedy actions.
        scale_mu (bool): Scale mu by applying tanh if True.
    """

    def __init__(
        self,
        n_input_channels,
        n_dim_action,
        n_hidden_channels,
        n_hidden_layers,
        action_space,
        scale_mu=True,
    ):
        self.n_input_channels = n_input_channels
        self.n_hidden_layers = n_hidden_layers
        self.n_hidden_channels = n_hidden_channels
        self.n_dim_action = n_dim_action
        assert action_space is not None
        self.scale_mu = scale_mu
        self.action_space = action_space
        super().__init__()
        hidden_layers = nn.ModuleList()
        assert n_hidden_layers >= 1
        hidden_layers.append(
            init_chainer_default(nn.Linear(n_input_channels, n_hidden_channels))
        )
        for _ in range(n_hidden_layers - 1):
            hidden_layers.append(
                init_chainer_default(nn.Linear(n_hidden_channels, n_hidden_channels))
            )
        self.hidden_layers = hidden_layers
        self.v = init_chainer_default(nn.Linear(n_hidden_channels, 1))
        self.mu = init_chainer_default(nn.Linear(n_hidden_channels, n_dim_action))
        self.mat_diag = init_chainer_default(nn.Linear(n_hidden_channels, n_dim_action))
        non_diag_size = n_dim_action * (n_dim_action - 1) // 2
        if non_diag_size > 0:
            self.mat_non_diag = init_chainer_default(
                nn.Linear(n_hidden_channels, non_diag_size)
            )

    def forward(self, state):
        h = state
        for layer in self.hidden_layers:
            h = F.relu(layer(h))
        v = self.v(h)
        mu = self.mu(h)
        if self.scale_mu:
            mu = scale_by_tanh(
                mu, high=self.action_space.high, low=self.action_space.low
            )
        mat_diag = torch.exp(self.mat_diag(h))
        if hasattr(self, "mat_non_diag"):
            mat_non_diag = self.mat_non_diag(h)
            tril = lower_triangular_matrix(mat_diag, mat_non_diag)
            mat = torch.matmul(tril, torch.transpose(tril, 1, 2))
        else:
            mat = torch.unsqueeze(mat_diag**2, dim=2)
        return QuadraticActionValue(
            mu,
            mat,
            v,
            min_action=self.action_space.low,
            max_action=self.action_space.high,
        )
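

# --- Illustrative usage sketch (not part of the original module) ---
# A NAF-style quadratic Q-function (https://arxiv.org/abs/1603.00748) for a
# hypothetical continuous-control task: 3 observation dimensions and a
# 2-dimensional Box action space. The gym dependency and the bounds below are
# assumptions for illustration only.
def _example_fc_quadratic_q_function():
    import gym.spaces

    action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
    q_func = FCQuadraticStateQFunction(
        n_input_channels=3,
        n_dim_action=2,
        n_hidden_channels=64,
        n_hidden_layers=2,
        action_space=action_space,
        scale_mu=True,
    )
    obs = torch.zeros(1, 3)
    action_value = q_func(obs)  # QuadraticActionValue
    # greedy_actions is mu(s) clipped to [low, high] of the action space.
    return action_value.greedy_actions  # shape (1, 2)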


class DiscreteActionValueHead(nn.Module):
    def forward(self, q_values):
        return DiscreteActionValue(q_values)
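

# --- Illustrative usage sketch (not part of the original module) ---
# DiscreteActionValueHead is typically placed at the end of a plain torch
# network so that the final linear outputs are interpreted as per-action
# Q-values. The layer sizes below are hypothetical.
def _example_discrete_action_value_head():
    q_func = nn.Sequential(
        nn.Linear(4, 32),
        nn.ReLU(),
        nn.Linear(32, 2),
        DiscreteActionValueHead(),
    )
    obs = torch.zeros(1, 4)
    return q_func(obs)  # DiscreteActionValue wrapping the two Q-values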