# Source code for pfrl.agents.double_dqn

from pfrl.agents import dqn
from pfrl.utils import evaluating
from pfrl.utils.recurrent import pack_and_forward


class DoubleDQN(dqn.DQN):
    """Double DQN agent.

    The online model chooses the greedy next action and the target model
    supplies the value of that action when building the learning target.

    See: http://arxiv.org/abs/1509.06461.
    """

    def _compute_target_values(self, exp_batch):
        """Compute the Double DQN target value for each entry of a batch.

        Args:
            exp_batch: Mapping read for the keys ``"next_state"``,
                ``"reward"``, ``"is_state_terminal"`` and ``"discount"``.
                ``"next_recurrent_state"`` is additionally read when
                ``self.recurrent`` is truthy.

        Returns:
            ``reward + discount * (1.0 - is_state_terminal) * q_next``, where
            ``q_next`` is the target model's evaluation of the online
            model's greedy actions on the next states.
        """
        next_states = exp_batch["next_state"]

        def forward(network):
            # Non-recurrent networks are called directly on the next states;
            # recurrent ones go through pack_and_forward together with the
            # stored recurrent state (its second return value is discarded).
            if not self.recurrent:
                return network(next_states)
            output, _ = pack_and_forward(
                network,
                next_states,
                exp_batch["next_recurrent_state"],
            )
            return output

        # Only the online model's forward pass runs inside the
        # `evaluating` context (presumably eval mode -- confirm against
        # pfrl.utils.evaluating).
        with evaluating(self.model):
            online_qout = forward(self.model)
        target_qout = forward(self.target_model)

        # Action selection by the online model, evaluation by the target model.
        bootstrap_q = target_qout.evaluate_actions(online_qout.greedy_actions)

        reward = exp_batch["reward"]
        not_terminal = 1.0 - exp_batch["is_state_terminal"]
        gamma = exp_batch["discount"]
        return reward + gamma * not_terminal * bootstrap_q