Source code for pfrl.agents.a2c

import warnings
from logging import getLogger

import torch

from pfrl import agent
from pfrl.utils import clip_l2_grad_norm_
from pfrl.utils.batch_states import batch_states
from pfrl.utils.mode_of_distribution import mode_of_distribution

logger = getLogger(__name__)


class A2C(agent.AttributeSavingMixin, agent.BatchAgent):
    """A2C: Advantage Actor-Critic.

    A2C is a synchronous, deterministic variant of Asynchronous Advantage
    Actor Critic (A3C).

    See https://arxiv.org/abs/1708.05144

    Args:
        model (nn.Module): Model to train.
        optimizer (torch.optim.Optimizer): Optimizer used to train the model.
        gamma (float): Discount factor in [0, 1].
        num_processes (int): The number of parallel environments (processes).
        gpu (int): GPU device id if not None nor negative.
        update_steps (int): The number of environment steps collected between
            updates.
        phi (callable): Feature extractor function.
        pi_loss_coef (float): Weight coefficient for the loss of the policy.
        v_loss_coef (float): Weight coefficient for the loss of the value
            function.
        entropy_coeff (float): Weight coefficient for the entropy bonus.
        use_gae (bool): If True, use generalized advantage estimation (GAE).
        tau (float): GAE decay parameter (lambda).
        act_deterministically (bool): If set to True, choose the most probable
            action when acting in evaluation mode.
        max_grad_norm (float or None): Maximum L2 norm of the gradient used for
            gradient clipping. If set to None, the gradient is not clipped.
        average_actor_loss_decay (float): Decay rate of average actor loss.
            Used only to record statistics.
        average_entropy_decay (float): Decay rate of average entropy. Used only
            to record statistics.
        average_value_decay (float): Decay rate of average value. Used only to
            record statistics.
        batch_states (callable): Method that makes a batch of observations.
            Default is `pfrl.utils.batch_states.batch_states`.
    """

    process_idx = None
    saved_attributes = ("model", "optimizer")

    def __init__(
        self,
        model,
        optimizer,
        gamma,
        num_processes,
        gpu=None,
        update_steps=5,
        phi=lambda x: x,
        pi_loss_coef=1.0,
        v_loss_coef=0.5,
        entropy_coeff=0.01,
        use_gae=False,
        tau=0.95,
        act_deterministically=False,
        max_grad_norm=None,
        average_actor_loss_decay=0.999,
        average_entropy_decay=0.999,
        average_value_decay=0.999,
        batch_states=batch_states,
    ):
        self.model = model
        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.model.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.optimizer = optimizer
        self.update_steps = update_steps
        self.num_processes = num_processes
        self.gamma = gamma
        self.use_gae = use_gae
        self.tau = tau
        self.act_deterministically = act_deterministically
        self.max_grad_norm = max_grad_norm
        self.phi = phi
        self.pi_loss_coef = pi_loss_coef
        self.v_loss_coef = v_loss_coef
        self.entropy_coeff = entropy_coeff

        self.average_actor_loss_decay = average_actor_loss_decay
        self.average_value_decay = average_value_decay
        self.average_entropy_decay = average_entropy_decay
        self.batch_states = batch_states

        self.t = 0
        self.t_start = 0

        # Stats
        self.average_actor_loss = 0
        self.average_value = 0
        self.average_entropy = 0

    def _flush_storage(self, obs_shape, action):
        obs_shape = obs_shape[1:]
        action_shape = action.shape[1:]

        self.states = torch.zeros(
            self.update_steps + 1,
            self.num_processes,
            *obs_shape,
            device=self.device,
            dtype=torch.float
        )
        self.actions = torch.zeros(
            self.update_steps,
            self.num_processes,
            *action_shape,
            device=self.device,
            dtype=torch.float
        )
        self.rewards = torch.zeros(
            self.update_steps,
            self.num_processes,
            device=self.device,
            dtype=torch.float,
        )
        self.value_preds = torch.zeros(
            self.update_steps + 1,
            self.num_processes,
            device=self.device,
            dtype=torch.float,
        )
        self.returns = torch.zeros(
            self.update_steps + 1,
            self.num_processes,
            device=self.device,
            dtype=torch.float,
        )
        self.masks = torch.ones(
            self.update_steps,
            self.num_processes,
            device=self.device,
            dtype=torch.float,
        )

        self.obs_shape = obs_shape
        self.action_shape = action_shape
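
    # _compute_returns fills self.returns with bootstrapped return targets.
    # With masks[i] = 0 whenever the episode ended at step i (which cuts off
    # bootstrapping across episode boundaries), the recursion implemented
    # below is:
    #   without GAE:  R_i = r_i + gamma * masks[i] * R_{i+1}
    #   with GAE:     delta_i = r_i + gamma * masks[i] * V(s_{i+1}) - V(s_i)
    #                 A_i = delta_i + gamma * tau * masks[i] * A_{i+1}
    #                 R_i = A_i + V(s_i)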
    def _compute_returns(self, next_value):
        if self.use_gae:
            self.value_preds[-1] = next_value
            gae = 0
            for i in reversed(range(self.update_steps)):
                delta = (
                    self.rewards[i]
                    + self.gamma * self.value_preds[i + 1] * self.masks[i]
                    - self.value_preds[i]
                )
                gae = delta + self.gamma * self.tau * self.masks[i] * gae
                self.returns[i] = gae + self.value_preds[i]
        else:
            self.returns[-1] = next_value
            for i in reversed(range(self.update_steps)):
                self.returns[i] = (
                    self.rewards[i] + self.gamma * self.returns[i + 1] * self.masks[i]
                )

    def update(self):
        with torch.no_grad():
            _, next_value = self.model(self.states[-1])
            next_value = next_value[:, 0]

        self._compute_returns(next_value)
        pout, values = self.model(self.states[:-1].reshape(-1, *self.obs_shape))

        actions = self.actions.reshape(-1, *self.action_shape)
        dist_entropy = pout.entropy().mean()
        action_log_probs = pout.log_prob(actions)

        values = values.reshape((self.update_steps, self.num_processes))
        action_log_probs = action_log_probs.reshape(
            (self.update_steps, self.num_processes)
        )

        advantages = self.returns[:-1] - values
        value_loss = (advantages * advantages).mean()
        action_loss = -(advantages.detach() * action_log_probs).mean()

        self.optimizer.zero_grad()

        (
            value_loss * self.v_loss_coef
            + action_loss * self.pi_loss_coef
            - dist_entropy * self.entropy_coeff
        ).backward()

        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()

        self.states[0] = self.states[-1]
        self.t_start = self.t

        # Update stats
        self.average_actor_loss += (1 - self.average_actor_loss_decay) * (
            float(action_loss) - self.average_actor_loss
        )
        self.average_value += (1 - self.average_value_decay) * (
            float(value_loss) - self.average_value
        )
        self.average_entropy += (1 - self.average_entropy_decay) * (
            float(dist_entropy) - self.average_entropy
        )
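
    # Interaction flow during training: batch_act() receives a batch of
    # observations, one per parallel environment, and returns sampled actions;
    # batch_observe() then records the resulting rewards and terminal masks.
    # Once `update_steps` transitions have been collected, update() runs and
    # the rollout storage is reused for the next segment.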
    def batch_act(self, batch_obs):
        if self.training:
            return self._batch_act_train(batch_obs)
        else:
            return self._batch_act_eval(batch_obs)

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.training:
            self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset)

    def _batch_act_train(self, batch_obs):
        assert self.training
        statevar = self.batch_states(batch_obs, self.device, self.phi)

        if self.t == 0:
            with torch.no_grad():
                pout, _ = self.model(statevar)
                action = pout.sample()
            self._flush_storage(statevar.shape, action)

        self.states[self.t - self.t_start] = statevar

        if self.t - self.t_start == self.update_steps:
            self.update()

        with torch.no_grad():
            pout, value = self.model(statevar)
            action = pout.sample()
        self.actions[self.t - self.t_start] = action.reshape(-1, *self.action_shape)
        self.value_preds[self.t - self.t_start] = value[:, 0]

        return action.cpu().numpy()

    def _batch_act_eval(self, batch_obs):
        assert not self.training
        statevar = self.batch_states(batch_obs, self.device, self.phi)
        with torch.no_grad():
            pout, _ = self.model(statevar)
            if self.act_deterministically:
                action = mode_of_distribution(pout)
            else:
                action = pout.sample()
        return action.cpu().numpy()

    def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset):
        assert self.training

        self.t += 1

        if any(batch_reset):
            warnings.warn(
                "A2C currently does not support resetting an env without reaching a"
                " terminal state during training. When receiving True in batch_reset,"
                " A2C considers it as True in batch_done instead."
            )  # NOQA
            batch_done = list(batch_done)
            for i, reset in enumerate(batch_reset):
                if reset:
                    batch_done[i] = True

        statevar = self.batch_states(batch_obs, self.device, self.phi)

        self.masks[self.t - self.t_start - 1] = torch.as_tensor(
            [0.0 if done else 1.0 for done in batch_done], device=self.device
        )
        self.rewards[self.t - self.t_start - 1] = torch.as_tensor(
            batch_reward, device=self.device, dtype=torch.float
        )
        self.states[self.t - self.t_start] = statevar

        if self.t - self.t_start == self.update_steps:
            self.update()

    def get_statistics(self):
        return [
            ("average_actor", self.average_actor_loss),
            ("average_value", self.average_value),
            ("average_entropy", self.average_entropy),
        ]
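

The listing above is only the agent; the caller must supply a model that maps a batch of observations to an (action distribution, state value) pair, plus an optimizer. Below is a minimal, illustrative usage sketch that is not part of the module: it assumes a small discrete-action task with placeholder sizes obs_size, n_actions, and n_envs, and builds the two-headed model with pfrl.nn.Branched and pfrl.policies.SoftmaxCategoricalHead in the style of PFRL's A2C examples.

import numpy as np
import torch
from torch import nn

import pfrl
from pfrl.agents import A2C

obs_size, n_actions, n_envs = 4, 2, 8  # placeholder sizes for illustration

# Two-headed model: pfrl.nn.Branched returns a tuple
# (action distribution, state value), which is what A2C expects.
model = nn.Sequential(
    nn.Linear(obs_size, 64),
    nn.Tanh(),
    pfrl.nn.Branched(
        nn.Sequential(
            nn.Linear(64, n_actions),
            pfrl.policies.SoftmaxCategoricalHead(),  # -> Categorical over actions
        ),
        nn.Linear(64, 1),  # -> state value of shape (batch, 1)
    ),
)
optimizer = torch.optim.Adam(model.parameters(), lr=7e-4)

agent = A2C(
    model,
    optimizer,
    gamma=0.99,
    num_processes=n_envs,  # must match the number of parallel envs
    update_steps=5,
    use_gae=True,
)

# One training step with dummy observations; in practice these come from
# n_envs parallel environments (e.g. a vectorized env wrapper).
batch_obs = [np.zeros(obs_size, dtype=np.float32) for _ in range(n_envs)]
actions = agent.batch_act(batch_obs)
agent.batch_observe(batch_obs, [0.0] * n_envs, [False] * n_envs, [False] * n_envs)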