Source code for pfrl.agents.dqn

import collections
import copy
import ctypes
import multiprocessing as mp
import multiprocessing.synchronize
import os
import time
import typing
from logging import Logger, getLogger
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn.functional as F

import pfrl
from pfrl import agent
from pfrl.action_value import ActionValue
from pfrl.explorer import Explorer
from pfrl.replay_buffer import (
    AbstractEpisodicReplayBuffer,
    ReplayUpdater,
    batch_experiences,
    batch_recurrent_experiences,
)
from pfrl.replay_buffers import PrioritizedReplayBuffer
from pfrl.utils.batch_states import batch_states
from pfrl.utils.contexts import evaluating
from pfrl.utils.copy_param import synchronize_parameters
from pfrl.utils.recurrent import (
    get_recurrent_state_at,
    mask_recurrent_state_at,
    one_step_forward,
    pack_and_forward,
    recurrent_state_as_numpy,
)


def _mean_or_nan(xs: Sequence[float]) -> float:
    """Return its mean a non-empty sequence, numpy.nan for a empty one."""
    return typing.cast(float, np.mean(xs)) if xs else np.nan


def compute_value_loss(
    y: torch.Tensor,
    t: torch.Tensor,
    clip_delta: bool = True,
    batch_accumulator: str = "mean",
) -> torch.Tensor:
    """Compute a loss for value prediction problem.

    Args:
        y (torch.Tensor): Predicted values.
        t (torch.Tensor): Target values.
        clip_delta (bool): Use the Huber loss function with delta=1 if set to True.
        batch_accumulator (str): 'mean' or 'sum'. 'mean' will use the mean of
            the loss values in a batch. 'sum' will use the sum.
    Returns:
        (torch.Tensor) scalar loss
    """
    assert batch_accumulator in ("mean", "sum")
    y = y.reshape(-1, 1)
    t = t.reshape(-1, 1)
    if clip_delta:
        return F.smooth_l1_loss(y, t, reduction=batch_accumulator)
    else:
        return F.mse_loss(y, t, reduction=batch_accumulator) / 2

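# Illustrative usage sketch (added for exposition, not part of the original
# module): shows how compute_value_loss behaves with and without delta
# clipping. The tensor values below are arbitrary.
def _example_value_loss_usage() -> None:
    y = torch.tensor([1.0, 2.0, 4.0])  # predicted values
    t = torch.tensor([1.5, 2.0, 1.0])  # target values
    huber = compute_value_loss(y, t, clip_delta=True, batch_accumulator="mean")
    half_mse = compute_value_loss(y, t, clip_delta=False, batch_accumulator="mean")
    # Element-wise, the two losses coincide while |y - t| <= 1; for larger
    # errors the Huber loss grows linearly instead of quadratically.
    print(float(huber), float(half_mse))
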

def compute_weighted_value_loss(
    y: torch.Tensor,
    t: torch.Tensor,
    weights: torch.Tensor,
    clip_delta: bool = True,
    batch_accumulator: str = "mean",
) -> torch.Tensor:
    """Compute a loss for value prediction problem.

    Args:
        y (torch.Tensor): Predicted values.
        t (torch.Tensor): Target values.
        weights (torch.Tensor): Weights for y, t.
        clip_delta (bool): Use the Huber loss function with delta=1 if set to True.
        batch_accumulator (str): 'mean' or 'sum'. 'mean' will divide the sum of
            weighted losses by the batch size. 'sum' will use the sum as is.
    Returns:
        (torch.Tensor) scalar loss
    """
    assert batch_accumulator in ("mean", "sum")
    y = y.reshape(-1, 1)
    t = t.reshape(-1, 1)
    if clip_delta:
        losses = F.smooth_l1_loss(y, t, reduction="none")
    else:
        losses = F.mse_loss(y, t, reduction="none") / 2
    losses = losses.reshape(-1)
    weights = weights.to(losses.device)
    loss_sum = torch.sum(losses * weights)
    if batch_accumulator == "mean":
        loss = loss_sum / y.shape[0]
    elif batch_accumulator == "sum":
        loss = loss_sum
    return loss

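# Illustrative usage sketch (added for exposition, not part of the original
# module): per-element losses are scaled by importance-sampling weights,
# e.g. those produced by a prioritized replay buffer, before reduction.
def _example_weighted_value_loss_usage() -> None:
    y = torch.tensor([1.0, 2.0, 4.0])
    t = torch.tensor([1.5, 2.0, 1.0])
    weights = torch.tensor([0.5, 1.0, 0.25])  # hypothetical IS weights
    loss = compute_weighted_value_loss(
        y, t, weights, clip_delta=True, batch_accumulator="mean"
    )
    # With batch_accumulator="mean", the weighted sum is divided by the
    # batch size (3 here), not by the sum of the weights.
    print(float(loss))
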

def _batch_reset_recurrent_states_when_episodes_end(
    batch_done: Sequence[bool], batch_reset: Sequence[bool], recurrent_states: Any
) -> Any:
    """Reset recurrent states when episodes end.

    Args:
        batch_done (array-like of bool): True iff episodes are terminal.
        batch_reset (array-like of bool): True iff episodes will be reset.
        recurrent_states (object): Recurrent state.

    Returns:
        object: New recurrent states.
    """
    indices_that_ended = [
        i
        for i, (done, reset) in enumerate(zip(batch_done, batch_reset))
        if done or reset
    ]
    if indices_that_ended:
        return mask_recurrent_state_at(recurrent_states, indices_that_ended)
    else:
        return recurrent_states

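# Illustrative usage sketch (added for exposition, not part of the original
# module): assumes the recurrent state is the (h, c) tuple of a single-layer
# torch.nn.LSTM run over a batch of 4 environments, which is the kind of
# structure pfrl's recurrent utilities operate on.
def _example_reset_recurrent_states() -> None:
    lstm = torch.nn.LSTM(input_size=3, hidden_size=2, batch_first=True)
    _, (h, c) = lstm(torch.randn(4, 1, 3))
    # Environments 1 and 3 finished (done or reset), so their entries in the
    # recurrent state are masked out; environments 0 and 2 keep theirs.
    new_states = _batch_reset_recurrent_states_when_episodes_end(
        batch_done=[False, True, False, False],
        batch_reset=[False, False, False, True],
        recurrent_states=(h, c),
    )
    print(type(new_states))
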

def make_target_model_as_copy(model: torch.nn.Module) -> torch.nn.Module:
    target_model = copy.deepcopy(model)

    def flatten_parameters(mod):
        if isinstance(mod, torch.nn.RNNBase):
            mod.flatten_parameters()

    # RNNBase.flatten_parameters must be called again after deep-copy.
    # See: https://discuss.pytorch.org/t/why-do-we-need-flatten-parameters-when-using-rnn-with-dataparallel/46506  # NOQA
    target_model.apply(flatten_parameters)
    # Set the target network to evaluation mode (inference only).
    target_model.eval()
    return target_model

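# Illustrative usage sketch (added for exposition, not part of the original
# module): the target network is a deep copy put into eval mode, so in-place
# updates to the online network do not affect it until it is explicitly
# synchronized (see DQN.sync_target_network below).
def _example_target_model_copy() -> None:
    q_func = torch.nn.Sequential(torch.nn.Linear(4, 2))
    target = make_target_model_as_copy(q_func)
    assert not target.training  # eval() was called on the copy
    with torch.no_grad():
        q_func[0].weight.add_(1.0)  # modify the online network in place
    assert not torch.equal(q_func[0].weight, target[0].weight)
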

class DQN(agent.AttributeSavingMixin, agent.BatchAgent):
    """Deep Q-Network algorithm.

    Args:
        q_function (StateQFunction): Q-function
        optimizer (Optimizer): Optimizer that is already set up
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id, used unless it is None or negative.
        replay_start_size (int): If the replay buffer's size is less than
            replay_start_size, updates are skipped.
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in steps
        target_update_interval (int): Target model update interval in steps
        clip_delta (bool): Clip delta if set to True
        phi (callable): Feature extractor applied to observations
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetitions of update
        batch_accumulator (str): 'mean' or 'sum'
        episodic_update_len (int or None): Subsequences of this length are used
            for update if set to an int and recurrent=True
        logger (Logger): Logger used
        batch_states (callable): Method which makes a batch of observations.
            The default is `pfrl.utils.batch_states.batch_states`.
        recurrent (bool): If set to True, `model` is assumed to implement
            `pfrl.nn.Recurrent` and is updated in a recurrent manner.
        max_grad_norm (float or None): Maximum L2 norm of the gradient used for
            gradient clipping. If set to None, the gradient is not clipped.
    """

    saved_attributes = ("model", "target_model", "optimizer")

    def __init__(
        self,
        q_function: torch.nn.Module,
        optimizer: torch.optim.Optimizer,  # type: ignore  # somehow mypy complains
        replay_buffer: pfrl.replay_buffer.AbstractReplayBuffer,
        gamma: float,
        explorer: Explorer,
        gpu: Optional[int] = None,
        replay_start_size: int = 50000,
        minibatch_size: int = 32,
        update_interval: int = 1,
        target_update_interval: int = 10000,
        clip_delta: bool = True,
        phi: Callable[[Any], Any] = lambda x: x,
        target_update_method: str = "hard",
        soft_update_tau: float = 1e-2,
        n_times_update: int = 1,
        batch_accumulator: str = "mean",
        episodic_update_len: Optional[int] = None,
        logger: Logger = getLogger(__name__),
        batch_states: Callable[
            [Sequence[Any], torch.device, Callable[[Any], Any]], Any
        ] = batch_states,
        recurrent: bool = False,
        max_grad_norm: Optional[float] = None,
    ):
        self.model = q_function

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.model.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.clip_delta = clip_delta
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.batch_accumulator = batch_accumulator
        assert batch_accumulator in ("mean", "sum")
        self.logger = logger
        self.batch_states = batch_states
        self.recurrent = recurrent
        update_func: Callable[..., None]
        if self.recurrent:
            assert isinstance(self.replay_buffer, AbstractEpisodicReplayBuffer)
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=recurrent,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )
        self.minibatch_size = minibatch_size
        self.episodic_update_len = episodic_update_len
        self.replay_start_size = replay_start_size
        self.update_interval = update_interval
        self.max_grad_norm = max_grad_norm

        assert (
            target_update_interval % update_interval == 0
        ), "target_update_interval should be a multiple of update_interval"

        self.t = 0
        self.optim_t = 0  # Compensate for PyTorch optimizers not having `t`
        self._cumulative_steps = 0
        self.target_model = make_target_model_as_copy(self.model)

        # Statistics
        self.q_record: collections.deque = collections.deque(maxlen=1000)
        self.loss_record: collections.deque = collections.deque(maxlen=100)

        # Recurrent states of the model
        self.train_recurrent_states: Any = None
        self.train_prev_recurrent_states: Any = None
        self.test_recurrent_states: Any = None

        # Error checking
        if (
            self.replay_buffer.capacity is not None
            and self.replay_buffer.capacity < self.replay_updater.replay_start_size
        ):
            raise ValueError("Replay start size cannot exceed replay buffer capacity.")

    @property
    def cumulative_steps(self) -> int:
        # cumulative_steps counts the overall steps during the training.
        return self._cumulative_steps

    def _setup_actor_learner_training(
        self,
        n_actors: int,
        actor_update_interval: int,
        update_counter: Any,
    ) -> Tuple[
        torch.nn.Module,
        Sequence[mp.connection.Connection],
        Sequence[mp.connection.Connection],
    ]:
        assert actor_update_interval > 0

        self.actor_update_interval = actor_update_interval
        self.update_counter = update_counter

        # Make a copy on shared memory and share among actors and the poller
        shared_model = copy.deepcopy(self.model).cpu()
        shared_model.share_memory()

        # Pipes are used for infrequent communication
        learner_pipes, actor_pipes = list(zip(*[mp.Pipe() for _ in range(n_actors)]))

        return (shared_model, learner_pipes, actor_pipes)

    def sync_target_network(self) -> None:
        """Synchronize target network with current network."""
        synchronize_parameters(
            src=self.model,
            dst=self.target_model,
            method=self.target_update_method,
            tau=self.soft_update_tau,
        )

    def update(
        self, experiences: List[List[Dict[str, Any]]], errors_out: Optional[list] = None
    ) -> None:
        """Update the model from experiences

        Args:
            experiences (list): List of lists of dicts.
                For DQN, each dict must contain:
                  - state (object): State
                  - action (object): Action
                  - reward (float): Reward
                  - is_state_terminal (bool): True iff next state is terminal
                  - next_state (object): Next state
                  - weight (float, optional): Weight coefficient. It can be
                    used for importance sampling.
            errors_out (list or None): If set to a list, then TD-errors
                computed from the given experiences are appended to the list.

        Returns:
            None
        """
        has_weight = "weight" in experiences[0][0]
        exp_batch = batch_experiences(
            experiences,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        if has_weight:
            exp_batch["weights"] = torch.tensor(
                [elem[0]["weight"] for elem in experiences],
                device=self.device,
                dtype=torch.float32,
            )
            if errors_out is None:
                errors_out = []
        loss = self._compute_loss(exp_batch, errors_out=errors_out)
        if has_weight:
            assert isinstance(self.replay_buffer, PrioritizedReplayBuffer)
            self.replay_buffer.update_errors(errors_out)

        self.loss_record.append(float(loss.detach().cpu().numpy()))

        self.optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            pfrl.utils.clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self.optim_t += 1

    def update_from_episodes(
        self, episodes: List[List[Dict[str, Any]]], errors_out: Optional[list] = None
    ) -> None:
        assert errors_out is None, "Recurrent DQN does not support PrioritizedBuffer"
        episodes = sorted(episodes, key=len, reverse=True)
        exp_batch = batch_recurrent_experiences(
            episodes,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        loss = self._compute_loss(exp_batch, errors_out=None)
        self.loss_record.append(float(loss.detach().cpu().numpy()))
        self.optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            pfrl.utils.clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self.optim_t += 1

    def _compute_target_values(self, exp_batch: Dict[str, Any]) -> torch.Tensor:
        batch_next_state = exp_batch["next_state"]

        if self.recurrent:
            target_next_qout, _ = pack_and_forward(
                self.target_model,
                batch_next_state,
                exp_batch["next_recurrent_state"],
            )
        else:
            target_next_qout = self.target_model(batch_next_state)
        next_q_max = target_next_qout.max

        batch_rewards = exp_batch["reward"]
        batch_terminal = exp_batch["is_state_terminal"]
        discount = exp_batch["discount"]

        return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max

    def _compute_y_and_t(
        self, exp_batch: Dict[str, Any]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = exp_batch["reward"].shape[0]

        # Compute Q-values for current states
        batch_state = exp_batch["state"]

        if self.recurrent:
            qout, _ = pack_and_forward(
                self.model, batch_state, exp_batch["recurrent_state"]
            )
        else:
            qout = self.model(batch_state)

        batch_actions = exp_batch["action"]
        batch_q = torch.reshape(qout.evaluate_actions(batch_actions), (batch_size, 1))

        with torch.no_grad():
            batch_q_target = torch.reshape(
                self._compute_target_values(exp_batch), (batch_size, 1)
            )

        return batch_q, batch_q_target

    def _compute_loss(
        self, exp_batch: Dict[str, Any], errors_out: Optional[list] = None
    ) -> torch.Tensor:
        """Compute the Q-learning loss for a batch of experiences

        Args:
            exp_batch (dict): A dict of batched arrays of transitions
        Returns:
            Computed loss from the minibatch of experiences
        """
        y, t = self._compute_y_and_t(exp_batch)

        self.q_record.extend(y.detach().cpu().numpy().ravel())

        if errors_out is not None:
            del errors_out[:]
            delta = torch.abs(y - t)
            if delta.ndim == 2:
                delta = torch.sum(delta, dim=1)
            delta = delta.detach().cpu().numpy()
            for e in delta:
                errors_out.append(e)

        if "weights" in exp_batch:
            return compute_weighted_value_loss(
                y,
                t,
                exp_batch["weights"],
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator,
            )
        else:
            return compute_value_loss(
                y,
                t,
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator,
            )

    def _evaluate_model_and_update_recurrent_states(
        self, batch_obs: Sequence[Any]
    ) -> ActionValue:
        batch_xs = self.batch_states(batch_obs, self.device, self.phi)
        if self.recurrent:
            if self.training:
                self.train_prev_recurrent_states = self.train_recurrent_states
                batch_av, self.train_recurrent_states = one_step_forward(
                    self.model, batch_xs, self.train_recurrent_states
                )
            else:
                batch_av, self.test_recurrent_states = one_step_forward(
                    self.model, batch_xs, self.test_recurrent_states
                )
        else:
            batch_av = self.model(batch_xs)
        return batch_av

    def batch_act(self, batch_obs: Sequence[Any]) -> Sequence[Any]:
        with torch.no_grad(), evaluating(self.model):
            batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs)
            batch_argmax = batch_av.greedy_actions.detach().cpu().numpy()
        if self.training:
            batch_action = [
                self.explorer.select_action(
                    self.t,
                    lambda: batch_argmax[i],
                    action_value=batch_av[i : i + 1],
                )
                for i in range(len(batch_obs))
            ]
            self.batch_last_obs = list(batch_obs)
            self.batch_last_action = list(batch_action)
        else:
            batch_action = batch_argmax
        return batch_action

    def _batch_observe_train(
        self,
        batch_obs: Sequence[Any],
        batch_reward: Sequence[float],
        batch_done: Sequence[bool],
        batch_reset: Sequence[bool],
    ) -> None:
        for i in range(len(batch_obs)):
            self.t += 1
            self._cumulative_steps += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                transition = {
                    "state": self.batch_last_obs[i],
                    "action": self.batch_last_action[i],
                    "reward": batch_reward[i],
                    "next_state": batch_obs[i],
                    "next_action": None,
                    "is_state_terminal": batch_done[i],
                }
                if self.recurrent:
                    transition["recurrent_state"] = recurrent_state_as_numpy(
                        get_recurrent_state_at(
                            self.train_prev_recurrent_states, i, detach=True
                        )
                    )
                    transition["next_recurrent_state"] = recurrent_state_as_numpy(
                        get_recurrent_state_at(
                            self.train_recurrent_states, i, detach=True
                        )
                    )
                self.replay_buffer.append(env_id=i, **transition)
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

        if self.recurrent:
            # Reset recurrent states when episodes end
            self.train_prev_recurrent_states = None
            self.train_recurrent_states = (
                _batch_reset_recurrent_states_when_episodes_end(  # NOQA
                    batch_done=batch_done,
                    batch_reset=batch_reset,
                    recurrent_states=self.train_recurrent_states,
                )
            )

    def _batch_observe_eval(
        self,
        batch_obs: Sequence[Any],
        batch_reward: Sequence[float],
        batch_done: Sequence[bool],
        batch_reset: Sequence[bool],
    ) -> None:
        if self.recurrent:
            # Reset recurrent states when episodes end
            self.test_recurrent_states = (
                _batch_reset_recurrent_states_when_episodes_end(  # NOQA
                    batch_done=batch_done,
                    batch_reset=batch_reset,
                    recurrent_states=self.test_recurrent_states,
                )
            )

    def batch_observe(
        self,
        batch_obs: Sequence[Any],
        batch_reward: Sequence[float],
        batch_done: Sequence[bool],
        batch_reset: Sequence[bool],
    ) -> None:
        if self.training:
            return self._batch_observe_train(
                batch_obs, batch_reward, batch_done, batch_reset
            )
        else:
            return self._batch_observe_eval(
                batch_obs, batch_reward, batch_done, batch_reset
            )

    def _can_start_replay(self) -> bool:
        if len(self.replay_buffer) < self.replay_start_size:
            return False
        if self.recurrent:
            assert isinstance(self.replay_buffer, AbstractEpisodicReplayBuffer)
            if self.replay_buffer.n_episodes < self.minibatch_size:
                return False
        return True

    def _poll_pipe(
        self,
        actor_idx: int,
        pipe: mp.connection.Connection,
        replay_buffer_lock: mp.synchronize.Lock,
        exception_event: mp.synchronize.Event,
    ) -> None:
        if pipe.closed:
            return
        try:
            while pipe.poll() and not exception_event.is_set():
                cmd, data = pipe.recv()
                if cmd == "get_statistics":
                    assert data is None
                    with replay_buffer_lock:
                        stats = self.get_statistics()
                    pipe.send(stats)
                elif cmd == "load":
                    self.load(data)
                    pipe.send(None)
                elif cmd == "save":
                    self.save(data)
                    pipe.send(None)
                elif cmd == "transition":
                    with replay_buffer_lock:
                        if "env_id" not in data:
                            data["env_id"] = actor_idx
                        self.replay_buffer.append(**data)
                        self._cumulative_steps += 1
                elif cmd == "stop_episode":
                    idx = actor_idx if data is None else data
                    with replay_buffer_lock:
                        self.replay_buffer.stop_current_episode(env_id=idx)
                        stats = self.get_statistics()
                        pipe.send(stats)
                else:
                    raise RuntimeError("Unknown command from actor: {}".format(cmd))
        except EOFError:
            pipe.close()
        except Exception:
            self.logger.exception("Poller loop failed. Exiting")
            exception_event.set()

    def _learner_loop(
        self,
        shared_model: torch.nn.Module,
        pipes: Sequence[mp.connection.Connection],
        replay_buffer_lock: mp.synchronize.Lock,
        stop_event: mp.synchronize.Event,
        exception_event: mp.synchronize.Event,
        n_updates: Optional[int] = None,
        step_hooks: List[Callable[[None, agent.Agent, int], Any]] = [],
        optimizer_step_hooks: List[Callable[[None, agent.Agent, int], Any]] = [],
    ) -> None:
        try:
            update_counter = 0
            # To stop this loop, call stop_event.set()
            while not stop_event.is_set():
                # Update model if possible
                if not self._can_start_replay():
                    continue
                if n_updates is not None:
                    assert self.optim_t <= n_updates
                    if self.optim_t == n_updates:
                        stop_event.set()
                        break

                if self.recurrent:
                    assert isinstance(self.replay_buffer, AbstractEpisodicReplayBuffer)
                    with replay_buffer_lock:
                        episodes = self.replay_buffer.sample_episodes(
                            self.minibatch_size, self.episodic_update_len
                        )
                    self.update_from_episodes(episodes)
                else:
                    with replay_buffer_lock:
                        transitions = self.replay_buffer.sample(self.minibatch_size)
                    self.update(transitions)

                # Update the shared model. This can be expensive if GPU is used
                # since this is a DtoH copy, so it is updated only at regular
                # intervals.
                update_counter += 1
                if update_counter % self.actor_update_interval == 0:
                    with self.update_counter.get_lock():
                        self.update_counter.value += 1
                        shared_model.load_state_dict(self.model.state_dict())

                # To keep the ratio of target updates to model updates,
                # here we calculate back the effective current timestep
                # from update_interval and number of updates so far.
                effective_timestep = self.optim_t * self.update_interval
                # We can safely assign self.t since in the learner
                # it isn't updated by any other method
                self.t = effective_timestep
                for hook in optimizer_step_hooks:
                    hook(None, self, self.optim_t)
                for hook in step_hooks:
                    hook(None, self, effective_timestep)
                if effective_timestep % self.target_update_interval == 0:
                    self.sync_target_network()
        except Exception:
            self.logger.exception("Learner loop failed. Exiting")
            exception_event.set()

    def _poller_loop(
        self,
        shared_model: torch.nn.Module,
        pipes: Sequence[mp.connection.Connection],
        replay_buffer_lock: mp.synchronize.Lock,
        stop_event: mp.synchronize.Event,
        exception_event: mp.synchronize.Event,
    ) -> None:
        # To stop this loop, call stop_event.set()
        while not stop_event.is_set() and not exception_event.is_set():
            time.sleep(1e-6)
            # Poll actors for messages
            for i, pipe in enumerate(pipes):
                self._poll_pipe(i, pipe, replay_buffer_lock, exception_event)

    def setup_actor_learner_training(
        self,
        n_actors: int,
        update_counter: Optional[Any] = None,
        n_updates: Optional[int] = None,
        actor_update_interval: int = 8,
        step_hooks: List[Callable[[None, agent.Agent, int], Any]] = [],
        optimizer_step_hooks: List[Callable[[None, agent.Agent, int], Any]] = [],
    ):
        if update_counter is None:
            update_counter = mp.Value(ctypes.c_ulong)

        (
            shared_model,
            learner_pipes,
            actor_pipes,
        ) = self._setup_actor_learner_training(
            n_actors, actor_update_interval, update_counter
        )
        exception_event = mp.Event()

        def make_actor(i):
            return pfrl.agents.StateQFunctionActor(
                pipe=actor_pipes[i],
                model=shared_model,
                explorer=self.explorer,
                phi=self.phi,
                batch_states=self.batch_states,
                logger=self.logger,
                recurrent=self.recurrent,
            )

        replay_buffer_lock = mp.Lock()

        poller_stop_event = mp.Event()
        poller = pfrl.utils.StoppableThread(
            target=self._poller_loop,
            kwargs=dict(
                shared_model=shared_model,
                pipes=learner_pipes,
                replay_buffer_lock=replay_buffer_lock,
                stop_event=poller_stop_event,
                exception_event=exception_event,
            ),
            stop_event=poller_stop_event,
        )

        learner_stop_event = mp.Event()
        learner = pfrl.utils.StoppableThread(
            target=self._learner_loop,
            kwargs=dict(
                shared_model=shared_model,
                pipes=learner_pipes,
                replay_buffer_lock=replay_buffer_lock,
                stop_event=learner_stop_event,
                n_updates=n_updates,
                exception_event=exception_event,
                step_hooks=step_hooks,
                optimizer_step_hooks=optimizer_step_hooks,
            ),
            stop_event=learner_stop_event,
        )

        return make_actor, learner, poller, exception_event

    def stop_episode(self) -> None:
        if self.recurrent:
            self.test_recurrent_states = None

    def save_snapshot(self, dirname: str) -> None:
        self.save(dirname)
        torch.save(self.t, os.path.join(dirname, "t.pt"))
        torch.save(self.optim_t, os.path.join(dirname, "optim_t.pt"))
        torch.save(
            self._cumulative_steps, os.path.join(dirname, "_cumulative_steps.pt")
        )
        self.replay_buffer.save(os.path.join(dirname, "replay_buffer.pkl"))

    def load_snapshot(self, dirname: str) -> None:
        self.load(dirname)
        self.t = torch.load(os.path.join(dirname, "t.pt"))
        self.optim_t = torch.load(os.path.join(dirname, "optim_t.pt"))
        self._cumulative_steps = torch.load(
            os.path.join(dirname, "_cumulative_steps.pt")
        )
        self.replay_buffer.load(os.path.join(dirname, "replay_buffer.pkl"))

    def get_statistics(self):
        return [
            ("average_q", _mean_or_nan(self.q_record)),
            ("average_loss", _mean_or_nan(self.loss_record)),
            ("cumulative_steps", self.cumulative_steps),
            ("n_updates", self.optim_t),
            ("rlen", len(self.replay_buffer)),
        ]