Source code for pfrl.agents.categorical_double_dqn

import pfrl
from pfrl.agents import categorical_dqn
from pfrl.agents.categorical_dqn import _apply_categorical_projection
from pfrl.utils.recurrent import pack_and_forward


class CategoricalDoubleDQN(categorical_dqn.CategoricalDQN):
    """Categorical Double DQN."""

    def _compute_target_values(self, exp_batch):
        """Compute a batch of target return distributions."""

        batch_next_state = exp_batch["next_state"]
        batch_rewards = exp_batch["reward"]
        batch_terminal = exp_batch["is_state_terminal"]

        with pfrl.utils.evaluating(self.target_model), pfrl.utils.evaluating(
            self.model
        ):
            if self.recurrent:
                target_next_qout, _ = pack_and_forward(
                    self.target_model,
                    batch_next_state,
                    exp_batch["next_recurrent_state"],
                )
                next_qout, _ = pack_and_forward(
                    self.model,
                    batch_next_state,
                    exp_batch["next_recurrent_state"],
                )
            else:
                target_next_qout = self.target_model(batch_next_state)
                next_qout = self.model(batch_next_state)

        batch_size = batch_rewards.shape[0]
        z_values = target_next_qout.z_values
        n_atoms = z_values.numel()

        # next_q_max: (batch_size, n_atoms)
        next_q_max = target_next_qout.evaluate_actions_as_distribution(
            next_qout.greedy_actions.detach()
        )
        assert next_q_max.shape == (batch_size, n_atoms), next_q_max.shape

        # Tz: (batch_size, n_atoms)
        Tz = (
            batch_rewards[..., None]
            + (1.0 - batch_terminal[..., None])
            * exp_batch["discount"][..., None]
            * z_values[None]
        )
        return _apply_categorical_projection(Tz, next_q_max, z_values)
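
The class above only overrides _compute_target_values: greedy actions for the next state are selected by the online model (next_qout.greedy_actions) while their return distributions are evaluated under the target model, and the Bellman-shifted support Tz is projected back onto the fixed atoms by _apply_categorical_projection. Construction is therefore identical to CategoricalDQN. Below is a minimal usage sketch, not part of this module, assuming a Gym-style environment with a flat observation vector and discrete actions; the environment dimensions and hyperparameter values are illustrative only.

# Usage sketch (assumptions: CartPole-like dimensions, illustrative hyperparameters).
import numpy as np
import torch
import pfrl

obs_size, n_actions = 4, 2            # assumed environment dimensions
n_atoms, v_min, v_max = 51, -10, 10   # support of the return distribution

# Distributional Q-function over a fixed support of n_atoms atoms.
q_func = pfrl.q_functions.DistributionalFCStateQFunctionWithDiscreteAction(
    obs_size, n_actions, n_atoms, v_min, v_max,
    n_hidden_channels=64, n_hidden_layers=2,
)
opt = torch.optim.Adam(q_func.parameters(), lr=1e-3)
rbuf = pfrl.replay_buffers.ReplayBuffer(capacity=10 ** 5)
explorer = pfrl.explorers.ConstantEpsilonGreedy(
    epsilon=0.1, random_action_func=lambda: np.random.randint(n_actions)
)

agent = pfrl.agents.CategoricalDoubleDQN(
    q_func,
    opt,
    rbuf,
    gamma=0.99,
    explorer=explorer,
    replay_start_size=1000,
    target_update_interval=1000,
    phi=lambda x: x.astype(np.float32, copy=False),
)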