import collections
import copy
from logging import getLogger
import numpy as np
import torch
from torch.nn import functional as F
import pfrl
from pfrl.agent import AttributeSavingMixin, BatchAgent
from pfrl.replay_buffer import ReplayUpdater, batch_experiences
from pfrl.utils import clip_l2_grad_norm_
from pfrl.utils.batch_states import batch_states
from pfrl.utils.copy_param import synchronize_parameters
def _mean_or_nan(xs):
"""Return its mean a non-empty sequence, numpy.nan for a empty one."""
return np.mean(xs) if xs else np.nan
def default_target_policy_smoothing_func(batch_action):
"""Add noises to actions for target policy smoothing."""
noise = torch.clamp(0.2 * torch.randn_like(batch_action), -0.5, 0.5)
return torch.clamp(batch_action + noise, -1, 1)
class TD3(AttributeSavingMixin, BatchAgent):
"""Twin Delayed Deep Deterministic Policy Gradients (TD3).
See http://arxiv.org/abs/1802.09477
Args:
policy (Policy): Policy.
q_func1 (Module): First Q-function that takes state-action pairs as input
and outputs predicted Q-values.
q_func2 (Module): Second Q-function that takes state-action pairs as
input and outputs predicted Q-values.
        policy_optimizer (Optimizer): Optimizer set up with the policy.
        q_func1_optimizer (Optimizer): Optimizer set up with the first
            Q-function.
        q_func2_optimizer (Optimizer): Optimizer set up with the second
            Q-function.
        replay_buffer (ReplayBuffer): Replay buffer.
        gamma (float): Discount factor.
explorer (Explorer): Explorer that specifies an exploration strategy.
gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): If the replay buffer's size is less than
            replay_start_size, updates are skipped.
        minibatch_size (int): Minibatch size.
        update_interval (int): Model update interval in steps.
        phi (callable): Feature extractor applied to observations.
soft_update_tau (float): Tau of soft target update.
        logger (Logger): Logger used.
        batch_states (callable): Method which makes a batch of observations.
            The default is `pfrl.utils.batch_states.batch_states`.
        burnin_action_func (callable or None): If not None, this callable
            object is used to select actions until the model has been
            updated one or more times during training.
policy_update_delay (int): Delay of policy updates. Policy is updated
once in `policy_update_delay` times of Q-function updates.
target_policy_smoothing_func (callable): Callable that takes a batch of
actions as input and outputs a noisy version of it. It is used for
target policy smoothing when computing target Q-values.
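
    Example:
        A minimal construction sketch, assuming a continuous-control task with
        flat observations of size ``obs_size`` and actions of size
        ``action_size`` scaled to [-1, 1]; the network architectures,
        optimizers, and hyperparameters below are illustrative only::

            import torch
            from torch import nn
            import pfrl

            policy = nn.Sequential(
                nn.Linear(obs_size, 400),
                nn.ReLU(),
                nn.Linear(400, 300),
                nn.ReLU(),
                nn.Linear(300, action_size),
                nn.Tanh(),
                pfrl.policies.DeterministicHead(),
            )

            def make_q_func():
                # Q-network that takes a concatenated (state, action) pair.
                return nn.Sequential(
                    pfrl.nn.ConcatObsAndAction(),
                    nn.Linear(obs_size + action_size, 400),
                    nn.ReLU(),
                    nn.Linear(400, 300),
                    nn.ReLU(),
                    nn.Linear(300, 1),
                )

            q_func1, q_func2 = make_q_func(), make_q_func()
            agent = TD3(
                policy,
                q_func1,
                q_func2,
                torch.optim.Adam(policy.parameters()),
                torch.optim.Adam(q_func1.parameters()),
                torch.optim.Adam(q_func2.parameters()),
                pfrl.replay_buffers.ReplayBuffer(10 ** 6),
                gamma=0.99,
                explorer=pfrl.explorers.AdditiveGaussian(
                    scale=0.1, low=-1.0, high=1.0
                ),
            )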
"""
saved_attributes = (
"policy",
"q_func1",
"q_func2",
"target_policy",
"target_q_func1",
"target_q_func2",
"policy_optimizer",
"q_func1_optimizer",
"q_func2_optimizer",
)
def __init__(
self,
policy,
q_func1,
q_func2,
policy_optimizer,
q_func1_optimizer,
q_func2_optimizer,
replay_buffer,
gamma,
explorer,
gpu=None,
replay_start_size=10000,
minibatch_size=100,
update_interval=1,
phi=lambda x: x,
soft_update_tau=5e-3,
n_times_update=1,
max_grad_norm=None,
logger=getLogger(__name__),
batch_states=batch_states,
burnin_action_func=None,
policy_update_delay=2,
target_policy_smoothing_func=default_target_policy_smoothing_func,
):
self.policy = policy
self.q_func1 = q_func1
self.q_func2 = q_func2
if gpu is not None and gpu >= 0:
assert torch.cuda.is_available()
self.device = torch.device("cuda:{}".format(gpu))
self.policy.to(self.device)
self.q_func1.to(self.device)
self.q_func2.to(self.device)
else:
self.device = torch.device("cpu")
self.replay_buffer = replay_buffer
self.gamma = gamma
self.explorer = explorer
self.gpu = gpu
self.phi = phi
self.soft_update_tau = soft_update_tau
self.logger = logger
self.policy_optimizer = policy_optimizer
self.q_func1_optimizer = q_func1_optimizer
self.q_func2_optimizer = q_func2_optimizer
self.replay_updater = ReplayUpdater(
replay_buffer=replay_buffer,
update_func=self.update,
batchsize=minibatch_size,
n_times_update=1,
replay_start_size=replay_start_size,
update_interval=update_interval,
episodic_update=False,
)
self.max_grad_norm = max_grad_norm
self.batch_states = batch_states
self.burnin_action_func = burnin_action_func
self.policy_update_delay = policy_update_delay
self.target_policy_smoothing_func = target_policy_smoothing_func
self.t = 0
self.policy_n_updates = 0
self.q_func_n_updates = 0
self.last_state = None
self.last_action = None
# Target model
self.target_policy = copy.deepcopy(self.policy).eval().requires_grad_(False)
self.target_q_func1 = copy.deepcopy(self.q_func1).eval().requires_grad_(False)
self.target_q_func2 = copy.deepcopy(self.q_func2).eval().requires_grad_(False)
# Statistics
self.q1_record = collections.deque(maxlen=1000)
self.q2_record = collections.deque(maxlen=1000)
self.q_func1_loss_record = collections.deque(maxlen=100)
self.q_func2_loss_record = collections.deque(maxlen=100)
self.policy_loss_record = collections.deque(maxlen=100)
def sync_target_network(self):
"""Synchronize target network with current network."""
synchronize_parameters(
src=self.policy,
dst=self.target_policy,
method="soft",
tau=self.soft_update_tau,
)
synchronize_parameters(
src=self.q_func1,
dst=self.target_q_func1,
method="soft",
tau=self.soft_update_tau,
)
synchronize_parameters(
src=self.q_func2,
dst=self.target_q_func2,
method="soft",
tau=self.soft_update_tau,
)
def update_q_func(self, batch):
"""Compute loss for a given Q-function."""
batch_next_state = batch["next_state"]
batch_rewards = batch["reward"]
batch_terminal = batch["is_state_terminal"]
batch_state = batch["state"]
batch_actions = batch["action"]
batch_discount = batch["discount"]
with torch.no_grad(), pfrl.utils.evaluating(
self.target_policy
), pfrl.utils.evaluating(self.target_q_func1), pfrl.utils.evaluating(
self.target_q_func2
):
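            # Target policy smoothing: perturb the target policy's actions
            # with clipped noise before evaluating the target critics.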
next_actions = self.target_policy_smoothing_func(
self.target_policy(batch_next_state).sample()
)
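            # Clipped double Q-learning: take the elementwise minimum of the
            # two target critics to curb Q-value overestimation.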
next_q1 = self.target_q_func1((batch_next_state, next_actions))
next_q2 = self.target_q_func2((batch_next_state, next_actions))
next_q = torch.min(next_q1, next_q2)
target_q = batch_rewards + batch_discount * (
1.0 - batch_terminal
) * torch.flatten(next_q)
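        # Current Q estimates from the online critics.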
predict_q1 = torch.flatten(self.q_func1((batch_state, batch_actions)))
predict_q2 = torch.flatten(self.q_func2((batch_state, batch_actions)))
loss1 = F.mse_loss(target_q, predict_q1)
loss2 = F.mse_loss(target_q, predict_q2)
# Update stats
self.q1_record.extend(predict_q1.detach().cpu().numpy())
self.q2_record.extend(predict_q2.detach().cpu().numpy())
self.q_func1_loss_record.append(float(loss1))
self.q_func2_loss_record.append(float(loss2))
self.q_func1_optimizer.zero_grad()
loss1.backward()
if self.max_grad_norm is not None:
clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
self.q_func1_optimizer.step()
self.q_func2_optimizer.zero_grad()
loss2.backward()
if self.max_grad_norm is not None:
clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
self.q_func2_optimizer.step()
self.q_func_n_updates += 1
def update_policy(self, batch):
"""Compute loss for actor."""
batch_state = batch["state"]
onpolicy_actions = self.policy(batch_state).rsample()
q = self.q_func1((batch_state, onpolicy_actions))
        # Since we want to maximize Q, the loss is the negation of Q.
loss = -torch.mean(q)
self.policy_loss_record.append(float(loss))
self.policy_optimizer.zero_grad()
loss.backward()
if self.max_grad_norm is not None:
clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy_optimizer.step()
self.policy_n_updates += 1
def update(self, experiences, errors_out=None):
"""Update the model from experiences"""
batch = batch_experiences(experiences, self.device, self.phi, self.gamma)
self.update_q_func(batch)
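        # Delayed policy updates: the actor and the target networks are only
        # updated once every `policy_update_delay` critic updates.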
if self.q_func_n_updates % self.policy_update_delay == 0:
self.update_policy(batch)
self.sync_target_network()
def batch_select_onpolicy_action(self, batch_obs):
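        # With the deterministic policy head typically used for TD3, sampling
        # from the action distribution simply returns the greedy action.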
with torch.no_grad(), pfrl.utils.evaluating(self.policy):
batch_xs = self.batch_states(batch_obs, self.device, self.phi)
batch_action = self.policy(batch_xs).sample().cpu().numpy()
return list(batch_action)
def batch_act(self, batch_obs):
if self.training:
return self._batch_act_train(batch_obs)
else:
return self._batch_act_eval(batch_obs)
def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
if self.training:
self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset)
def _batch_act_eval(self, batch_obs):
assert not self.training
return self.batch_select_onpolicy_action(batch_obs)
def _batch_act_train(self, batch_obs):
assert self.training
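        # Until the policy has been updated at least once, fall back to
        # burnin_action_func (e.g. uniform random actions) if it is provided.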
if self.burnin_action_func is not None and self.policy_n_updates == 0:
batch_action = [self.burnin_action_func() for _ in range(len(batch_obs))]
else:
batch_onpolicy_action = self.batch_select_onpolicy_action(batch_obs)
batch_action = [
self.explorer.select_action(self.t, lambda: batch_onpolicy_action[i])
for i in range(len(batch_onpolicy_action))
]
self.batch_last_obs = list(batch_obs)
self.batch_last_action = list(batch_action)
return batch_action
def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset):
assert self.training
for i in range(len(batch_obs)):
self.t += 1
if self.batch_last_obs[i] is not None:
assert self.batch_last_action[i] is not None
# Add a transition to the replay buffer
self.replay_buffer.append(
state=self.batch_last_obs[i],
action=self.batch_last_action[i],
reward=batch_reward[i],
next_state=batch_obs[i],
next_action=None,
is_state_terminal=batch_done[i],
env_id=i,
)
if batch_reset[i] or batch_done[i]:
self.batch_last_obs[i] = None
self.batch_last_action[i] = None
self.replay_buffer.stop_current_episode(env_id=i)
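            # Trigger an update if the replay buffer holds enough samples and
            # the update interval has elapsed.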
self.replay_updater.update_if_necessary(self.t)
def get_statistics(self):
return [
("average_q1", _mean_or_nan(self.q1_record)),
("average_q2", _mean_or_nan(self.q2_record)),
("average_q_func1_loss", _mean_or_nan(self.q_func1_loss_record)),
("average_q_func2_loss", _mean_or_nan(self.q_func2_loss_record)),
("average_policy_loss", _mean_or_nan(self.policy_loss_record)),
("policy_n_updates", self.policy_n_updates),
("q_func_n_updates", self.q_func_n_updates),
]