import copy
from logging import getLogger
import torch
import torch.nn.functional as F
import pfrl
from pfrl import agent
from pfrl.utils import clip_l2_grad_norm_, copy_param
from pfrl.utils.batch_states import batch_states
from pfrl.utils.mode_of_distribution import mode_of_distribution
from pfrl.utils.recurrent import one_step_forward, pack_and_forward
logger = getLogger(__name__)
class A3C(agent.AttributeSavingMixin, agent.AsyncAgent):
"""A3C: Asynchronous Advantage Actor-Critic.
See http://arxiv.org/abs/1602.01783
Args:
model (A3CModel): Model to train
optimizer (torch.optim.Optimizer): optimizer used to train the model
t_max (int): The model is updated after every t_max local steps
gamma (float): Discount factor [0,1]
        beta (float): Weight coefficient for the entropy regularization term.
process_idx (int): Index of the process.
phi (callable): Feature extractor function
pi_loss_coef (float): Weight coefficient for the loss of the policy
v_loss_coef (float): Weight coefficient for the loss of the value
function
        act_deterministically (bool): If set to True, choose the most probable
            action in the act method instead of sampling.
max_grad_norm (float or None): Maximum L2 norm of the gradient used for
gradient clipping. If set to None, the gradient is not clipped.
recurrent (bool): If set to True, `model` is assumed to implement
`pfrl.nn.StatelessRecurrent`.
        batch_states (callable): Method which makes a batch of observations.
            The default is `pfrl.utils.batch_states.batch_states`.
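        keep_loss_scale_same (bool): If set to True, the loss of an update
            interval truncated by an episode end or reset is rescaled as if it
            covered t_max steps, keeping the loss scale constant.
        normalize_grad_by_t_max (bool): If set to True, the loss is divided by
            the number of steps in the update interval.
        use_average_reward (bool): If set to True, use the average reward
            formulation instead of discounted returns. (Stored on the agent
            but not referenced by the update rule shown here.)
        average_value_decay (float): Decay rate of the running average of
            predicted state values, recorded only as a statistic.
        average_entropy_decay (float): Decay rate of the running average of
            policy entropy, recorded only as a statistic.

    Example:
        A minimal construction and interaction sketch, not taken from the PFRL
        examples. ``MyA3CModel`` and ``env`` are illustrative assumptions: the
        model's forward pass is expected to return a
        ``(policy_distribution, state_value)`` tuple, and both the model
        parameters and the optimizer state must live in shared memory before
        asynchronous training starts (e.g. via ``share_memory()`` and a shared
        optimizer such as ``pfrl.optimizers.SharedRMSpropEpsInsideSqrt``, if
        available)::

            shared_model = MyA3CModel()
            shared_model.share_memory()
            opt = pfrl.optimizers.SharedRMSpropEpsInsideSqrt(
                shared_model.parameters(), lr=7e-4
            )
            agent = A3C(shared_model, opt, t_max=5, gamma=0.99, beta=1e-2)

            obs = env.reset()
            for _ in range(1000):
                action = agent.act(obs)
                obs, reward, done, info = env.step(action)
                agent.observe(obs, reward, done, reset=False)
                if done:
                    obs = env.reset()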
"""
process_idx = None
saved_attributes = ("model", "optimizer")
def __init__(
self,
model,
optimizer,
t_max,
gamma,
beta=1e-2,
process_idx=0,
phi=lambda x: x,
pi_loss_coef=1.0,
v_loss_coef=0.5,
keep_loss_scale_same=False,
normalize_grad_by_t_max=False,
use_average_reward=False,
act_deterministically=False,
max_grad_norm=None,
recurrent=False,
average_entropy_decay=0.999,
average_value_decay=0.999,
batch_states=batch_states,
):
# Globally shared model
self.shared_model = model
# Thread specific model
self.model = copy.deepcopy(self.shared_model)
self.optimizer = optimizer
self.t_max = t_max
self.gamma = gamma
self.beta = beta
self.phi = phi
self.pi_loss_coef = pi_loss_coef
self.v_loss_coef = v_loss_coef
self.keep_loss_scale_same = keep_loss_scale_same
self.normalize_grad_by_t_max = normalize_grad_by_t_max
self.use_average_reward = use_average_reward
self.act_deterministically = act_deterministically
self.max_grad_norm = max_grad_norm
self.recurrent = recurrent
self.average_value_decay = average_value_decay
self.average_entropy_decay = average_entropy_decay
self.batch_states = batch_states
self.device = torch.device("cpu")
self.t = 0
self.t_start = 0
self.past_obs = {}
self.past_action = {}
self.past_rewards = {}
self.past_recurrent_state = {}
self.average_reward = 0
# Recurrent states of the model
self.train_recurrent_states = None
self.test_recurrent_states = None
# Stats
self.average_value = 0
self.average_entropy = 0
def sync_parameters(self):
copy_param.copy_param(target_link=self.model, source_link=self.shared_model)
def assert_shared_memory(self):
# Shared model must have tensors in shared memory
for k, v in self.shared_model.state_dict().items():
assert v.is_shared(), "{} is not in shared memory".format(k)
# Local model must not have tensors in shared memory
for k, v in self.model.state_dict().items():
assert not v.is_shared(), "{} is in shared memory".format(k)
# Optimizer must have tensors in shared memory
for param_state in self.optimizer.state_dict()["state"].values():
for k, v in param_state.items():
if isinstance(v, torch.Tensor):
assert v.is_shared(), "{} is not in shared memory".format(k)
@property
def shared_attributes(self):
return ("shared_model", "optimizer")
def update(self, statevar):
assert self.t_start < self.t
n = self.t - self.t_start
self.assert_shared_memory()
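        # Bootstrap the return from the value of the state following the last
        # stored transition; if the episode terminated, bootstrap from 0.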
if statevar is None:
R = 0
else:
with torch.no_grad(), pfrl.utils.evaluating(self.model):
if self.recurrent:
(_, vout), _ = one_step_forward(
self.model, statevar, self.train_recurrent_states
)
else:
_, vout = self.model(statevar)
R = float(vout)
pi_loss_factor = self.pi_loss_coef
v_loss_factor = self.v_loss_coef
# Normalize the loss of sequences truncated by terminal states
if self.keep_loss_scale_same and self.t - self.t_start < self.t_max:
factor = self.t_max / (self.t - self.t_start)
pi_loss_factor *= factor
v_loss_factor *= factor
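        # Optionally normalize the loss by the number of steps in the interval
        # so the gradient magnitude does not grow with the interval length.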
if self.normalize_grad_by_t_max:
pi_loss_factor /= self.t - self.t_start
v_loss_factor /= self.t - self.t_start
# Batch re-compute for efficient backprop
batch_obs = self.batch_states(
[self.past_obs[i] for i in range(self.t_start, self.t)],
self.device,
self.phi,
)
if self.recurrent:
(batch_distrib, batch_v), _ = pack_and_forward(
self.model,
[batch_obs],
self.past_recurrent_state[self.t_start],
)
else:
batch_distrib, batch_v = self.model(batch_obs)
batch_action = torch.stack(
[self.past_action[i] for i in range(self.t_start, self.t)]
)
batch_log_prob = batch_distrib.log_prob(batch_action)
batch_entropy = batch_distrib.entropy()
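        # Compute discounted n-step returns by iterating backward over the
        # stored rewards: R_i = r_i + gamma * R_{i+1}, starting from the
        # bootstrap value R computed above.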
rev_returns = []
for i in reversed(range(self.t_start, self.t)):
R *= self.gamma
R += self.past_rewards[i]
rev_returns.append(R)
batch_return = torch.as_tensor(list(reversed(rev_returns)), dtype=torch.float)
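        # Advantage estimates: n-step returns minus detached value predictions,
        # so the policy loss does not backpropagate into the value function.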
batch_adv = batch_return - batch_v.detach().squeeze(-1)
assert batch_log_prob.shape == (n,)
assert batch_adv.shape == (n,)
assert batch_entropy.shape == (n,)
pi_loss = torch.sum(
-batch_adv * batch_log_prob - self.beta * batch_entropy, dim=0
)
assert batch_v.shape == (n, 1)
assert batch_return.shape == (n,)
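        # Value loss: half the summed squared error between predicted values
        # and the n-step returns.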
v_loss = F.mse_loss(batch_v, batch_return[..., None], reduction="sum") / 2
if pi_loss_factor != 1.0:
pi_loss *= pi_loss_factor
if v_loss_factor != 1.0:
v_loss *= v_loss_factor
if self.process_idx == 0:
logger.debug("pi_loss:%s v_loss:%s", pi_loss, v_loss)
total_loss = torch.squeeze(pi_loss) + torch.squeeze(v_loss)
# Compute gradients using thread-specific model
self.model.zero_grad()
total_loss.backward()
if self.max_grad_norm is not None:
clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
# Copy the gradients to the globally shared model
copy_param.copy_grad(target_link=self.shared_model, source_link=self.model)
# Update the globally shared model
self.optimizer.step()
if self.process_idx == 0:
logger.debug("update")
self.sync_parameters()
self.past_obs = {}
self.past_action = {}
self.past_rewards = {}
self.past_recurrent_state = {}
self.t_start = self.t
def act(self, obs):
if self.training:
return self._act_train(obs)
else:
return self._act_eval(obs)
def observe(self, obs, reward, done, reset):
if self.training:
self._observe_train(obs, reward, done, reset)
else:
self._observe_eval(obs, reward, done, reset)
def _act_train(self, obs):
self.past_obs[self.t] = obs
with torch.no_grad():
statevar = self.batch_states([obs], self.device, self.phi)
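            # For recurrent models, remember the recurrent state at the start
            # of this step so that update() can replay the interval from it.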
if self.recurrent:
self.past_recurrent_state[self.t] = self.train_recurrent_states
(pout, vout), self.train_recurrent_states = one_step_forward(
self.model, statevar, self.train_recurrent_states
)
else:
pout, vout = self.model(statevar)
# Do not backprop through sampled actions
action = pout.sample()
self.past_action[self.t] = action[0].detach()
action = action.cpu().numpy()[0]
# Update stats
self.average_value += (1 - self.average_value_decay) * (
float(vout) - self.average_value
)
self.average_entropy += (1 - self.average_entropy_decay) * (
float(pout.entropy()) - self.average_entropy
)
return action
def _observe_train(self, obs, reward, done, reset):
self.t += 1
self.past_rewards[self.t - 1] = reward
if self.process_idx == 0:
logger.debug(
"t:%s action:%s reward:%s", self.t, self.past_action[self.t - 1], reward
)
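        # Update every t_max steps, or earlier if the episode terminated
        # (done) or was externally reset.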
if self.t - self.t_start == self.t_max or done or reset:
if done:
statevar = None
else:
statevar = self.batch_states([obs], self.device, self.phi)
self.update(statevar)
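        # Reset the recurrent state at episode boundaries so the next episode
        # starts from a fresh state.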
if done or reset:
self.train_recurrent_states = None
def _act_eval(self, obs):
# Use the process-local model for acting
with torch.no_grad(), pfrl.utils.evaluating(self.model):
statevar = self.batch_states([obs], self.device, self.phi)
if self.recurrent:
(pout, _), self.test_recurrent_states = one_step_forward(
self.model, statevar, self.test_recurrent_states
)
else:
pout, _ = self.model(statevar)
if self.act_deterministically:
return mode_of_distribution(pout).cpu().numpy()[0]
else:
return pout.sample().cpu().numpy()[0]
def _observe_eval(self, obs, reward, done, reset):
if done or reset:
self.test_recurrent_states = None
def load(self, dirname):
super().load(dirname)
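        # Propagate the loaded parameters from the process-local model to the
        # globally shared model.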
copy_param.copy_param(target_link=self.shared_model, source_link=self.model)
def get_statistics(self):
return [
("average_value", self.average_value),
("average_entropy", self.average_entropy),
]