import copy
from logging import getLogger
import numpy as np
import torch
from torch import nn
from pfrl import agent
from pfrl.action_value import SingleActionValue
from pfrl.utils import clip_l2_grad_norm_, copy_param
from pfrl.utils.batch_states import batch_states
from pfrl.utils.mode_of_distribution import mode_of_distribution
from pfrl.utils.recurrent import detach_recurrent_state, one_step_forward
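# Helper functions for ACER (Wang et al., "Sample Efficient Actor-Critic with
# Experience Replay", https://arxiv.org/abs/1611.01224).
# compute_importance returns the importance weight pi(x) / mu(x) of a single
# action x, where pi is the current policy and mu is the behavior policy.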
def compute_importance(pi, mu, x):
with torch.no_grad():
return float(torch.exp(pi.log_prob(x) - mu.log_prob(x)))
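# Importance weights pi(a) / mu(a) for every discrete action at once.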
def compute_full_importance(pi, mu):
assert isinstance(pi, torch.distributions.Categorical)
assert isinstance(mu, torch.distributions.Categorical)
# Categorical.logits is already normalized, i.e., exp(logits[i]) = probs[i]
with torch.no_grad():
pimu = torch.exp(pi.logits - mu.logits)
return pimu
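# Bias correction term of ACER's truncated importance sampling, summed over
# all discrete actions. Each action is weighted by [1 - c / rho(a)]_+ * pi(a),
# where rho(a) = pi(a) / mu(a) and c is the truncation threshold, and
# multiplied by log pi(a) and the advantage Q(s, a) - V(s). The negative of
# this sum is returned so it can be added to the loss.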
def compute_policy_gradient_full_correction(
action_distrib, action_distrib_mu, action_value, v, truncation_threshold
):
"""Compute off-policy bias correction term wrt all actions."""
assert isinstance(action_distrib, torch.distributions.Categorical)
assert isinstance(action_distrib_mu, torch.distributions.Categorical)
assert truncation_threshold is not None
assert np.isscalar(v)
with torch.no_grad():
rho_all_inv = compute_full_importance(action_distrib_mu, action_distrib)
correction_weight = (
torch.nn.functional.relu(1 - truncation_threshold * rho_all_inv)
* action_distrib.probs[0]
)
correction_advantage = action_value.q_values[0] - v
# Categorical.logits is already normalized, i.e., logits[i] = log(probs[i])
return -(correction_weight * action_distrib.logits * correction_advantage).sum(1)
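# Single-sample variant of the bias correction term, used when the action
# space cannot be enumerated (e.g., continuous actions): one action a' is
# drawn from pi and weighted by [1 - c / rho']_+ with rho' = pi(a') / mu(a').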
def compute_policy_gradient_sample_correction(
action_distrib, action_distrib_mu, action_value, v, truncation_threshold
):
"""Compute off-policy bias correction term wrt a sampled action."""
assert np.isscalar(v)
assert truncation_threshold is not None
with torch.no_grad():
sample_action = action_distrib.sample()
rho_dash_inv = compute_importance(
action_distrib_mu, action_distrib, sample_action
)
if truncation_threshold > 0 and rho_dash_inv >= 1 / truncation_threshold:
return torch.as_tensor(0, dtype=torch.float)
correction_weight = max(0, 1 - truncation_threshold * rho_dash_inv)
assert correction_weight <= 1
q = float(action_value.evaluate_actions(sample_action))
correction_advantage = q - v
return -(
correction_weight
* action_distrib.log_prob(sample_action)
* correction_advantage
)
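# ACER policy gradient loss. Off-policy, it combines the truncated importance
# sampling term min(c, rho) * log pi(a) * advantage with one of the bias
# correction terms above (full correction for Categorical policies, sampled
# correction otherwise). On-policy, it reduces to the vanilla policy gradient.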
def compute_policy_gradient_loss(
action,
advantage,
action_distrib,
action_distrib_mu,
action_value,
v,
truncation_threshold,
):
"""Compute policy gradient loss with off-policy bias correction."""
assert np.isscalar(advantage)
assert np.isscalar(v)
log_prob = action_distrib.log_prob(action)
if action_distrib_mu is not None:
# Off-policy
rho = compute_importance(action_distrib, action_distrib_mu, action)
g_loss = 0
if truncation_threshold is None:
g_loss -= rho * log_prob * advantage
else:
# Truncated off-policy policy gradient term
g_loss -= min(truncation_threshold, rho) * log_prob * advantage
# Bias correction term
if isinstance(action_distrib, torch.distributions.Categorical):
g_loss += compute_policy_gradient_full_correction(
action_distrib=action_distrib,
action_distrib_mu=action_distrib_mu,
action_value=action_value,
v=v,
truncation_threshold=truncation_threshold,
)
else:
g_loss += compute_policy_gradient_sample_correction(
action_distrib=action_distrib,
action_distrib_mu=action_distrib_mu,
action_value=action_value,
v=v,
truncation_threshold=truncation_threshold,
)
else:
# On-policy
g_loss = -log_prob * advantage
return g_loss
class ACERDiscreteActionHead(nn.Module):
"""ACER model that consists of a separate policy and V-function.
Args:
pi (Policy): Policy.
q (QFunction): Q-function.
"""
def __init__(self, pi, q):
super().__init__()
self.pi = pi
self.q = q
def forward(self, obs):
action_distrib = self.pi(obs)
action_value = self.q(obs)
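# State value as the expectation of Q under the current policy:
# V(s) = sum_a pi(a|s) * Q(s, a)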
v = (action_distrib.probs * action_value.q_values).sum(1)
return action_distrib, action_value, v
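# A minimal usage sketch for the discrete-action head. This is a hedged
# example: the layer sizes and the use of SoftmaxCategoricalHead /
# DiscreteActionValueHead are assumptions; see the PFRL ACER examples for
# maintained code.
#
#     head = ACERDiscreteActionHead(
#         pi=nn.Sequential(
#             nn.Linear(obs_size, 64),
#             nn.ReLU(),
#             nn.Linear(64, n_actions),
#             pfrl.policies.SoftmaxCategoricalHead(),
#         ),
#         q=nn.Sequential(
#             nn.Linear(obs_size, 64),
#             nn.ReLU(),
#             nn.Linear(64, n_actions),
#             pfrl.q_functions.DiscreteActionValueHead(),
#         ),
#     )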
class ACERContinuousActionHead(nn.Module):
"""ACER model that consists of a separate policy and V-function.
Args:
pi (Policy): Policy.
v (torch.nn.Module): V-function, a callable mapping from a batch of
observations to a (batch_size, 1)-shaped `torch.Tensor`.
adv (StateActionQFunction): Advantage function.
n (int): Number of samples used to evaluate Q-values.
"""
def __init__(self, pi, v, adv, n=5):
super().__init__()
self.pi = pi
self.v = v
self.adv = adv
self.n = n
def forward(self, obs):
action_distrib = self.pi(obs)
v = self.v(obs)
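# Q-values are estimated with a stochastic dueling network:
# Q(s, a) ~= V(s) + A(s, a) - (1/n) * sum_i A(s, a_i) with a_i ~ pi(.|s),
# so that advantage estimates are approximately centered around zero.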
def evaluator(action):
adv_mean = (
sum(self.adv((obs, action_distrib.sample())) for _ in range(self.n))
/ self.n
)
return v + self.adv((obs, action)) - adv_mean
action_value = SingleActionValue(evaluator)
return action_distrib, action_value, v
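# Return the leaf parameter tensors of a distribution (logits for Categorical,
# loc and scale for Normal). These are the tensors that the trust-region
# update in compute_loss_with_kl_constraint differentiates with respect to.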
def get_params_of_distribution(distrib):
if isinstance(distrib, torch.distributions.Independent):
return get_params_of_distribution(distrib.base_dist)
elif isinstance(distrib, torch.distributions.Categorical):
return (distrib._param,)
elif isinstance(distrib, torch.distributions.Normal):
return distrib.loc, distrib.scale
else:
raise NotImplementedError("{} is not supported by ACER".format(type(distrib)))
def deepcopy_distribution(distrib):
"""Deepcopy a PyTorch distribution.
PyTorch distributions cannot be deepcopied as-is unless their tensors are
graph leaves.
"""
if isinstance(distrib, torch.distributions.Independent):
return torch.distributions.Independent(
deepcopy_distribution(distrib.base_dist),
distrib.reinterpreted_batch_ndims,
)
elif isinstance(distrib, torch.distributions.Categorical):
return torch.distributions.Categorical(
logits=distrib.logits.clone().detach(),
)
elif isinstance(distrib, torch.distributions.Normal):
return torch.distributions.Normal(
loc=distrib.loc.clone().detach(),
scale=distrib.scale.clone().detach(),
)
else:
raise NotImplementedError("{} is not supported by ACER".format(type(distrib)))
def compute_loss_with_kl_constraint(distrib, another_distrib, original_loss, delta):
"""Compute loss considering a KL constraint.
Args:
distrib (Distribution): Distribution to optimize
another_distrib (Distribution): Distribution used to compute KL
original_loss (torch.Tensor): Loss to minimize
delta (float): Threshold of the trust-region KL constraint
Returns:
tuple of (torch.Tensor, float): new loss to minimize and the KL
divergence between the two distributions
"""
distrib_params = get_params_of_distribution(distrib)
for param in distrib_params:
assert param.shape[0] == 1
assert param.requires_grad
# Compute g: a direction to minimize the original loss
g = [
grad[0]
for grad in torch.autograd.grad(
[original_loss], distrib_params, retain_graph=True
)
]
# Compute k: a direction to increase KL div.
kl = torch.distributions.kl_divergence(another_distrib, distrib)
k = [
grad[0]
for grad in torch.autograd.grad([-kl], distrib_params, retain_graph=True)
]
# Compute z: combination of g and k to keep small KL div.
kg_dot = sum(torch.dot(kp.flatten(), gp.flatten()) for kp, gp in zip(k, g))
kk_dot = sum(torch.dot(kp.flatten(), kp.flatten()) for kp in k)
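# Closed-form solution of the linearized trust-region problem: when
# kg_dot > delta, subtract just enough of the direction k from g to bring the
# constraint back to the boundary, which limits how much a single update can
# increase the KL divergence from the average model.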
if kk_dot > 0:
k_factor = max(0, ((kg_dot - delta) / kk_dot))
else:
k_factor = 0
z = [gp - k_factor * kp for kp, gp in zip(k, g)]
loss = 0
for p, zp in zip(distrib_params, z):
loss += (p * zp).sum()
return loss.reshape(original_loss.shape), float(kl)
class ACER(agent.AttributeSavingMixin, agent.AsyncAgent):
"""ACER (Actor-Critic with Experience Replay).
See http://arxiv.org/abs/1611.01224
Args:
model (ACERModel): Model to train. It must be a callable that accepts
observations as input and returns three values: action distributions
(Distribution), Q values (ActionValue), and state values
(torch.Tensor).
optimizer (torch.optim.Optimizer): optimizer used to train the model
t_max (int): The model is updated after every t_max local steps
gamma (float): Discount factor [0,1]
replay_buffer (EpisodicReplayBuffer): Replay buffer to use. If set
None, this agent won't use experience replay.
beta (float): Weight coefficient for the entropy regularization term.
phi (callable): Feature extractor function
pi_loss_coef (float): Weight coefficient for the loss of the policy
Q_loss_coef (float): Weight coefficient for the loss of the value
function
use_trust_region (bool): If set true, use efficient TRPO.
trust_region_alpha (float): Decay rate of the average model used for
efficient TRPO.
trust_region_delta (float): Threshold used for efficient TRPO.
truncation_threshold (float or None): Threshold used to truncate large
importance weights. If set None, importance weights are not
truncated.
disable_online_update (bool): If set true, disable online on-policy
update and rely only on experience replay.
n_times_replay (int): Number of times experience replay is repeated per
online update.
replay_start_size (int): Experience replay is disabled if the number of
transitions in the replay buffer is lower than this value.
normalize_loss_by_steps (bool): If set true, losses are normalized by
the number of steps taken to accumulate the losses
act_deterministically (bool): If set true, choose the most probable
action in the act method.
max_grad_norm (float or None): Maximum L2 norm of the gradient used for
gradient clipping. If set to None, the gradient is not clipped.
recurrent (bool): If set to True, `model` is assumed to implement
`pfrl.nn.StatelessRecurrent`.
use_Q_opc (bool): If set true, Q_opc, a Q-value estimate without
importance sampling, is used to compute advantage values for policy
gradients. The original paper recommends this for continuous
action spaces.
average_entropy_decay (float): Decay rate of average entropy. Used only
to record statistics.
average_value_decay (float): Decay rate of average value. Used only
to record statistics.
average_kl_decay (float): Decay rate of average KL divergence. Used only
to record statistics.
"""
process_idx = None
saved_attributes = ("model", "optimizer")
def __init__(
self,
model,
optimizer,
t_max,
gamma,
replay_buffer,
beta=1e-2,
phi=lambda x: x,
pi_loss_coef=1.0,
Q_loss_coef=0.5,
use_trust_region=True,
trust_region_alpha=0.99,
trust_region_delta=1,
truncation_threshold=10,
disable_online_update=False,
n_times_replay=8,
replay_start_size=10**4,
normalize_loss_by_steps=True,
act_deterministically=False,
max_grad_norm=None,
recurrent=False,
use_Q_opc=False,
average_entropy_decay=0.999,
average_value_decay=0.999,
average_kl_decay=0.999,
logger=None,
):
# Globally shared model
self.shared_model = model
# Globally shared average model used to compute trust regions
self.shared_average_model = copy.deepcopy(self.shared_model)
# Thread specific model
self.model = copy.deepcopy(self.shared_model)
self.optimizer = optimizer
self.replay_buffer = replay_buffer
self.t_max = t_max
self.gamma = gamma
self.beta = beta
self.phi = phi
self.pi_loss_coef = pi_loss_coef
self.Q_loss_coef = Q_loss_coef
self.normalize_loss_by_steps = normalize_loss_by_steps
self.act_deterministically = act_deterministically
self.max_grad_norm = max_grad_norm
self.recurrent = recurrent
self.use_trust_region = use_trust_region
self.trust_region_alpha = trust_region_alpha
self.truncation_threshold = truncation_threshold
self.trust_region_delta = trust_region_delta
self.disable_online_update = disable_online_update
self.n_times_replay = n_times_replay
self.use_Q_opc = use_Q_opc
self.replay_start_size = replay_start_size
self.average_value_decay = average_value_decay
self.average_entropy_decay = average_entropy_decay
self.average_kl_decay = average_kl_decay
self.logger = logger if logger else getLogger(__name__)
self.device = torch.device("cpu")
self.t = 0
self.last_state = None
self.last_action = None
# Recurrent states of the model
self.train_recurrent_states = None
self.shared_recurrent_states = None
self.test_recurrent_states = None
# Stats
self.average_value = 0
self.average_entropy = 0
self.average_kl = 0
self.init_history_data_for_online_update()
def init_history_data_for_online_update(self):
self.past_actions = {}
self.past_rewards = {}
self.past_values = {}
self.past_action_distrib = {}
self.past_action_values = {}
self.past_avg_action_distrib = {}
self.t_start = self.t
def sync_parameters(self):
copy_param.copy_param(target_link=self.model, source_link=self.shared_model)
copy_param.soft_copy_param(
target_link=self.shared_average_model,
source_link=self.model,
tau=1 - self.trust_region_alpha,
)
def assert_shared_memory(self):
# Shared model must have tensors in shared memory
for k, v in self.shared_model.state_dict().items():
assert v.is_shared(), "{} is not in shared memory".format(k)
# Local model must not have tensors in shared memory
for k, v in self.model.state_dict().items():
assert not v.is_shared(), "{} is in shared memory".format(k)
# Optimizer must have tensors in shared memory
for param_state in self.optimizer.state_dict()["state"].values():
for k, v in param_state.items():
if isinstance(v, torch.Tensor):
assert v.is_shared(), "{} is not in shared memory".format(k)
@property
def shared_attributes(self):
return ("shared_model", "shared_average_model", "optimizer")
def compute_one_step_pi_loss(
self,
action,
advantage,
action_distrib,
action_distrib_mu,
action_value,
v,
avg_action_distrib,
):
assert np.isscalar(advantage)
assert np.isscalar(v)
g_loss = compute_policy_gradient_loss(
action=action,
advantage=advantage,
action_distrib=action_distrib,
action_distrib_mu=action_distrib_mu,
action_value=action_value,
v=v,
truncation_threshold=self.truncation_threshold,
)
if self.use_trust_region:
pi_loss, kl = compute_loss_with_kl_constraint(
action_distrib,
avg_action_distrib,
g_loss,
delta=self.trust_region_delta,
)
self.average_kl += (1 - self.average_kl_decay) * (kl - self.average_kl)
else:
pi_loss = g_loss
# Entropy is maximized
pi_loss -= self.beta * action_distrib.entropy()
return pi_loss
def compute_loss(
self,
t_start,
t_stop,
R,
actions,
rewards,
values,
action_values,
action_distribs,
action_distribs_mu,
avg_action_distribs,
):
assert np.isscalar(R)
pi_loss = 0
Q_loss = 0
Q_ret = R
Q_opc = R
discrete = isinstance(action_distribs[t_start], torch.distributions.Categorical)
del R
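# Q_ret is the Retrace target computed with truncated importance weights;
# Q_opc follows the same backward recursion but with the truncation
# coefficient replaced by 1 (the off-policy corrected target of the ACER
# paper).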
for i in reversed(range(t_start, t_stop)):
r = rewards[i]
v = values[i]
action_distrib = action_distribs[i]
action_distrib_mu = action_distribs_mu[i] if action_distribs_mu else None
avg_action_distrib = avg_action_distribs[i]
action_value = action_values[i]
ba = torch.as_tensor(actions[i]).unsqueeze(0)
if action_distrib_mu is not None:
# Off-policy
rho = compute_importance(action_distrib, action_distrib_mu, ba)
else:
# On-policy
rho = 1
Q_ret = r + self.gamma * Q_ret
Q_opc = r + self.gamma * Q_opc
assert np.isscalar(Q_ret)
assert np.isscalar(Q_opc)
if self.use_Q_opc:
advantage = Q_opc - float(v)
else:
advantage = Q_ret - float(v)
pi_loss += self.compute_one_step_pi_loss(
action=ba,
advantage=advantage,
action_distrib=action_distrib,
action_distrib_mu=action_distrib_mu,
action_value=action_value,
v=float(v),
avg_action_distrib=avg_action_distrib,
)
# Accumulate gradients of value function
Q = action_value.evaluate_actions(ba)
assert Q.requires_grad, "Q must be backprop-able"
Q_loss += nn.functional.mse_loss(torch.tensor(Q_ret), Q) / 2
if not discrete:
assert v.requires_grad, "v must be backprop-able"
v_target = min(1, rho) * (Q_ret - float(Q)) + float(v)
Q_loss += nn.functional.mse_loss(torch.tensor(v_target), v) / 2
if self.process_idx == 0:
self.logger.debug(
"t:%s v:%s Q:%s Q_ret:%s Q_opc:%s",
i,
float(v),
float(Q),
Q_ret,
Q_opc,
)
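# Backward recursion for the targets:
#   Q_ret <- c_i * (Q_ret - Q(s_i, a_i)) + V(s_i)
#   Q_opc <- (Q_opc - Q(s_i, a_i)) + V(s_i)
# with c_i = min(1, rho_i) for discrete actions and
# c_i = min(1, rho_i ** (1/d)) for d-dimensional continuous actions.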
if discrete:
c = min(1, rho)
else:
c = min(1, rho ** (1 / ba.numel()))
Q_ret = c * (Q_ret - float(Q)) + float(v)
Q_opc = Q_opc - float(Q) + float(v)
pi_loss *= self.pi_loss_coef
Q_loss *= self.Q_loss_coef
if self.normalize_loss_by_steps:
pi_loss /= t_stop - t_start
Q_loss /= t_stop - t_start
if self.process_idx == 0:
self.logger.debug("pi_loss:%s Q_loss:%s", float(pi_loss), float(Q_loss))
return pi_loss + Q_loss.reshape(*pi_loss.shape)
def update(
self,
t_start,
t_stop,
R,
actions,
rewards,
values,
action_values,
action_distribs,
action_distribs_mu,
avg_action_distribs,
):
assert np.isscalar(R)
self.assert_shared_memory()
total_loss = self.compute_loss(
t_start=t_start,
t_stop=t_stop,
R=R,
actions=actions,
rewards=rewards,
values=values,
action_values=action_values,
action_distribs=action_distribs,
action_distribs_mu=action_distribs_mu,
avg_action_distribs=avg_action_distribs,
)
# Compute gradients using thread-specific model
self.model.zero_grad()
total_loss.squeeze().backward()
if self.max_grad_norm is not None:
clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
# Copy the gradients to the globally shared model
copy_param.copy_grad(target_link=self.shared_model, source_link=self.model)
self.optimizer.step()
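# optimizer.step() updates the globally shared parameters with the gradients
# copied above; sync_parameters then copies them back into the local model and
# updates the exponential moving-average model used for the trust region.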
self.sync_parameters()
def update_from_replay(self):
if self.replay_buffer is None:
return
if len(self.replay_buffer) < self.replay_start_size:
return
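# Sample one stored episode (at most t_max transitions), re-evaluate it with
# the current local model, and use the stored behavior distributions ("mu")
# for off-policy corrections.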
episode = self.replay_buffer.sample_episodes(1, self.t_max)[0]
model_recurrent_state = None
shared_recurrent_state = None
rewards = {}
actions = {}
action_distribs = {}
action_distribs_mu = {}
avg_action_distribs = {}
action_values = {}
values = {}
for t, transition in enumerate(episode):
bs = batch_states([transition["state"]], self.device, self.phi)
if self.recurrent:
(
(action_distrib, action_value, v),
model_recurrent_state,
) = one_step_forward(self.model, bs, model_recurrent_state)
else:
action_distrib, action_value, v = self.model(bs)
with torch.no_grad():
if self.recurrent:
(
(avg_action_distrib, _, _),
shared_recurrent_state,
) = one_step_forward(
self.shared_average_model,
bs,
shared_recurrent_state,
)
else:
avg_action_distrib, _, _ = self.shared_average_model(bs)
actions[t] = transition["action"]
values[t] = v
action_distribs[t] = action_distrib
avg_action_distribs[t] = avg_action_distrib
rewards[t] = transition["reward"]
action_distribs_mu[t] = transition["mu"]
action_values[t] = action_value
last_transition = episode[-1]
if last_transition["is_state_terminal"]:
R = 0
else:
with torch.no_grad():
last_s = batch_states(
[last_transition["next_state"]], self.device, self.phi
)
if self.recurrent:
(_, _, last_v), _ = one_step_forward(
self.model, last_s, model_recurrent_state
)
else:
_, _, last_v = self.model(last_s)
R = float(last_v)
return self.update(
R=R,
t_start=0,
t_stop=len(episode),
rewards=rewards,
actions=actions,
values=values,
action_distribs=action_distribs,
action_distribs_mu=action_distribs_mu,
avg_action_distribs=avg_action_distribs,
action_values=action_values,
)
def update_on_policy(self, statevar):
assert self.t_start < self.t
if not self.disable_online_update:
if statevar is None:
R = 0
else:
with torch.no_grad():
if self.recurrent:
(_, _, v), _ = one_step_forward(
self.model, statevar, self.train_recurrent_states
)
else:
_, _, v = self.model(statevar)
R = float(v)
self.update(
t_start=self.t_start,
t_stop=self.t,
R=R,
actions=self.past_actions,
rewards=self.past_rewards,
values=self.past_values,
action_values=self.past_action_values,
action_distribs=self.past_action_distrib,
action_distribs_mu=None,
avg_action_distribs=self.past_avg_action_distrib,
)
self.init_history_data_for_online_update()
self.train_recurrent_states = detach_recurrent_state(
self.train_recurrent_states
)
def act(self, obs):
if self.training:
return self._act_train(obs)
else:
return self._act_eval(obs)
def observe(self, obs, reward, done, reset):
if self.training:
self._observe_train(obs, reward, done, reset)
else:
self._observe_eval(obs, reward, done, reset)
def _act_train(self, obs):
statevar = batch_states([obs], self.device, self.phi)
if self.recurrent:
(
(action_distrib, action_value, v),
self.train_recurrent_states,
) = one_step_forward(self.model, statevar, self.train_recurrent_states)
else:
action_distrib, action_value, v = self.model(statevar)
self.past_action_values[self.t] = action_value
action = action_distrib.sample()[0]
# Save values for a later update
self.past_values[self.t] = v
self.past_action_distrib[self.t] = action_distrib
with torch.no_grad():
if self.recurrent:
(
(avg_action_distrib, _, _),
self.shared_recurrent_states,
) = one_step_forward(
self.shared_average_model,
statevar,
self.shared_recurrent_states,
)
else:
avg_action_distrib, _, _ = self.shared_average_model(statevar)
self.past_avg_action_distrib[self.t] = avg_action_distrib
self.past_actions[self.t] = action
# Update stats
self.average_value += (1 - self.average_value_decay) * (
float(v) - self.average_value
)
self.average_entropy += (1 - self.average_entropy_decay) * (
float(action_distrib.entropy()) - self.average_entropy
)
self.last_state = obs
self.last_action = action.numpy()
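# Keep a detached copy of the action distribution; it is stored in the replay
# buffer as the behavior policy "mu" for later off-policy corrections.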
self.last_action_distrib = deepcopy_distribution(action_distrib)
return self.last_action
def _act_eval(self, obs):
# Use the process-local model for acting
with torch.no_grad():
statevar = batch_states([obs], self.device, self.phi)
if self.recurrent:
(action_distrib, _, _), self.test_recurrent_states = one_step_forward(
self.model, statevar, self.test_recurrent_states
)
else:
action_distrib, _, _ = self.model(statevar)
if self.act_deterministically:
return mode_of_distribution(action_distrib).numpy()[0]
else:
return action_distrib.sample().numpy()[0]
def _observe_train(self, state, reward, done, reset):
assert self.last_state is not None
assert self.last_action is not None
# Add a transition to the replay buffer
if self.replay_buffer is not None:
self.replay_buffer.append(
state=self.last_state,
action=self.last_action,
reward=reward,
next_state=state,
is_state_terminal=done,
mu=self.last_action_distrib,
)
if done or reset:
self.replay_buffer.stop_current_episode()
self.t += 1
self.past_rewards[self.t - 1] = reward
if self.process_idx == 0:
self.logger.debug(
"t:%s r:%s a:%s",
self.t,
reward,
self.last_action,
)
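# Run an on-policy update every t_max steps or at episode end, followed by
# n_times_replay off-policy updates from the replay buffer.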
if self.t - self.t_start == self.t_max or done or reset:
if done:
statevar = None
else:
statevar = batch_states([state], self.device, self.phi)
self.update_on_policy(statevar)
for _ in range(self.n_times_replay):
self.update_from_replay()
if done or reset:
self.train_recurrent_states = None
self.shared_recurrent_states = None
self.last_state = None
self.last_action = None
self.last_action_distrib = None
def _observe_eval(self, obs, reward, done, reset):
if done or reset:
self.test_recurrent_states = None
def load(self, dirname):
super().load(dirname)
copy_param.copy_param(target_link=self.shared_model, source_link=self.model)
def get_statistics(self):
return [
("average_value", self.average_value),
("average_entropy", self.average_entropy),
("average_kl", self.average_kl),
]