from pfrl.agents import dqn
from pfrl.utils import evaluating
from pfrl.utils.recurrent import pack_and_forward
class DoubleDQN(dqn.DQN):
    """Double DQN.

    See: http://arxiv.org/abs/1509.06461.
    """

    def _compute_target_values(self, exp_batch):
        """Compute the Double DQN target values for a batch of experiences.

        The next-state actions are chosen from the output of ``self.model``
        (its ``greedy_actions``), while the values of those actions are read
        from the output of ``self.target_model`` (``evaluate_actions``).

        Args:
            exp_batch: Mapping read with the keys ``"next_state"``,
                ``"reward"``, ``"is_state_terminal"`` and ``"discount"``.
                When ``self.recurrent`` is truthy, ``"next_recurrent_state"``
                is read as well.

        Returns:
            ``reward + discount * (1.0 - is_state_terminal) * next_q_max``,
            where ``next_q_max`` is the target model's value of the online
            model's greedy next action.
        """
        batch_next_state = exp_batch["next_state"]

        # Action selection: forward the online model on the next states.
        # NOTE(review): `evaluating` is presumably a context manager that
        # temporarily puts the model in evaluation mode -- confirm against
        # pfrl.utils.
        with evaluating(self.model):
            if self.recurrent:
                next_qout, _ = pack_and_forward(
                    self.model,
                    batch_next_state,
                    exp_batch["next_recurrent_state"],
                )
            else:
                next_qout = self.model(batch_next_state)

        # Action evaluation: forward the target model on the same next states.
        if self.recurrent:
            target_next_qout, _ = pack_and_forward(
                self.target_model,
                batch_next_state,
                exp_batch["next_recurrent_state"],
            )
        else:
            target_next_qout = self.target_model(batch_next_state)

        # Value, under the target model, of the action the online model
        # considers greedy; decoupling the two is the point of Double DQN.
        next_q_max = target_next_qout.evaluate_actions(next_qout.greedy_actions)

        batch_rewards = exp_batch["reward"]
        batch_terminal = exp_batch["is_state_terminal"]
        discount = exp_batch["discount"]

        # The (1.0 - terminal) factor removes the bootstrapped term for
        # transitions flagged as terminal.
        return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max