import collections
import itertools
import random
import numpy as np
import torch
import torch.nn.functional as F
import pfrl
from pfrl import agent
from pfrl.utils.batch_states import batch_states
from pfrl.utils.mode_of_distribution import mode_of_distribution
from pfrl.utils.recurrent import (
concatenate_recurrent_states,
flatten_sequences_time_first,
get_recurrent_state_at,
mask_recurrent_state_at,
one_step_forward,
pack_and_forward,
)
def _mean_or_nan(xs):
"""Return its mean a non-empty sequence, numpy.nan for a empty one."""
return np.mean(xs) if xs else np.nan
def _elementwise_clip(x, x_min, x_max):
"""Elementwise clipping
Note: torch.clamp supports clipping to constant intervals
"""
return torch.min(torch.max(x, x_min), x_max)
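
# Illustrative use of _elementwise_clip with tensor-valued bounds (the
# values below are hypothetical):
#
#   >>> x = torch.tensor([0.0, 5.0, -5.0])
#   >>> lo = torch.full((3,), -1.0)
#   >>> hi = torch.tensor([1.0, 2.0, 3.0])
#   >>> _elementwise_clip(x, lo, hi)
#   tensor([ 0.,  2., -1.])
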
def _add_advantage_and_value_target_to_episode(episode, gamma, lambd):
"""Add advantage and value target values to an episode."""
adv = 0.0
for transition in reversed(episode):
td_err = (
transition["reward"]
+ (gamma * transition["nonterminal"] * transition["next_v_pred"])
- transition["v_pred"]
)
adv = td_err + gamma * lambd * adv
transition["adv"] = adv
transition["v_teacher"] = adv + transition["v_pred"]
def _add_advantage_and_value_target_to_episodes(episodes, gamma, lambd):
"""Add advantage and value target values to a list of episodes."""
for episode in episodes:
_add_advantage_and_value_target_to_episode(episode, gamma=gamma, lambd=lambd)
def _add_log_prob_and_value_to_episodes_recurrent(
episodes,
model,
phi,
batch_states,
obs_normalizer,
device,
):
# Sort desc by lengths so that pack_sequence does not change the order
episodes = sorted(episodes, key=len, reverse=True)
# Prepare data for a recurrent model
seqs_states = []
seqs_next_states = []
for ep in episodes:
states = batch_states([transition["state"] for transition in ep], device, phi)
next_states = batch_states(
[transition["next_state"] for transition in ep], device, phi
)
if obs_normalizer:
states = obs_normalizer(states, update=False)
next_states = obs_normalizer(next_states, update=False)
seqs_states.append(states)
seqs_next_states.append(next_states)
flat_transitions = flatten_sequences_time_first(episodes)
# Predict values using a recurrent model
with torch.no_grad(), pfrl.utils.evaluating(model):
rs = concatenate_recurrent_states([ep[0]["recurrent_state"] for ep in episodes])
next_rs = concatenate_recurrent_states(
[ep[0]["next_recurrent_state"] for ep in episodes]
)
assert (rs is None) or (next_rs is None) or (len(rs) == len(next_rs))
(flat_distribs, flat_vs), _ = pack_and_forward(model, seqs_states, rs)
(_, flat_next_vs), _ = pack_and_forward(model, seqs_next_states, next_rs)
flat_actions = torch.tensor(
[b["action"] for b in flat_transitions], device=device
)
flat_log_probs = flat_distribs.log_prob(flat_actions).cpu().numpy()
flat_vs = flat_vs.cpu().numpy()
flat_next_vs = flat_next_vs.cpu().numpy()
# Add predicted values to transitions
for transition, log_prob, v, next_v in zip(
flat_transitions, flat_log_probs, flat_vs, flat_next_vs
):
transition["log_prob"] = float(log_prob)
transition["v_pred"] = float(v)
transition["next_v_pred"] = float(next_v)
def _add_log_prob_and_value_to_episodes(
episodes,
model,
phi,
batch_states,
obs_normalizer,
device,
):
dataset = list(itertools.chain.from_iterable(episodes))
# Compute v_pred and next_v_pred
states = batch_states([b["state"] for b in dataset], device, phi)
next_states = batch_states([b["next_state"] for b in dataset], device, phi)
if obs_normalizer:
states = obs_normalizer(states, update=False)
next_states = obs_normalizer(next_states, update=False)
with torch.no_grad(), pfrl.utils.evaluating(model):
distribs, vs_pred = model(states)
_, next_vs_pred = model(next_states)
actions = torch.tensor([b["action"] for b in dataset], device=device)
log_probs = distribs.log_prob(actions).cpu().numpy()
vs_pred = vs_pred.cpu().numpy().ravel()
next_vs_pred = next_vs_pred.cpu().numpy().ravel()
for transition, log_prob, v_pred, next_v_pred in zip(
dataset, log_probs, vs_pred, next_vs_pred
):
transition["log_prob"] = log_prob
transition["v_pred"] = v_pred
transition["next_v_pred"] = next_v_pred
def _limit_sequence_length(sequences, max_len):
assert max_len > 0
new_sequences = []
for sequence in sequences:
while len(sequence) > max_len:
new_sequences.append(sequence[:max_len])
sequence = sequence[max_len:]
assert 0 < len(sequence) <= max_len
new_sequences.append(sequence)
return new_sequences
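
# Example behavior (illustrative): sequences longer than max_len are split
# into chunks of at most max_len, preserving order:
#
#   >>> _limit_sequence_length([[1, 2, 3, 4, 5], [6]], max_len=2)
#   [[1, 2], [3, 4], [5], [6]]
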
def _yield_subset_of_sequences_with_fixed_number_of_items(sequences, n_items):
assert n_items > 0
stack = list(reversed(sequences))
while stack:
subset = []
count = 0
while count < n_items and stack:
sequence = stack.pop()
subset.append(sequence)
count += len(sequence)
if count > n_items:
# Split last sequence
sequence_to_split = subset[-1]
n_exceeds = count - n_items
assert n_exceeds > 0
subset[-1] = sequence_to_split[:-n_exceeds]
stack.append(sequence_to_split[-n_exceeds:])
if sum(len(seq) for seq in subset) == n_items:
yield subset
else:
            # Too few items remained to fill a whole subset; this only
            # happens when the stack is exhausted, so the loop ends here.
assert len(stack) == 0
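
# Example behavior (illustrative): sequences are consumed in order and the
# last one in each subset is split so that every yielded subset holds
# exactly n_items transitions:
#
#   >>> gen = _yield_subset_of_sequences_with_fixed_number_of_items(
#   ...     [[1, 2, 3], [4, 5], [6]], n_items=2)
#   >>> list(gen)
#   [[[1, 2]], [[3], [4]], [[5], [6]]]
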
def _compute_explained_variance(transitions):
"""Compute 1 - Var[return - v]/Var[return].
This function computes the fraction of variance that value predictions can
explain about returns.
"""
t = np.array([tr["v_teacher"] for tr in transitions])
y = np.array([tr["v_pred"] for tr in transitions])
vart = np.var(t)
if vart == 0:
return np.nan
else:
return float(1 - np.var(t - y) / vart)
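
# Sanity check (illustrative): perfect value predictions give 1.0, since
# Var[return - v] is then zero:
#
#   >>> _compute_explained_variance(
#   ...     [{"v_teacher": 1.0, "v_pred": 1.0}, {"v_teacher": 2.0, "v_pred": 2.0}])
#   1.0
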
def _make_dataset_recurrent(
episodes,
model,
phi,
batch_states,
obs_normalizer,
gamma,
lambd,
max_recurrent_sequence_len,
device,
):
"""Make a list of sequences with necessary information."""
_add_log_prob_and_value_to_episodes_recurrent(
episodes=episodes,
model=model,
phi=phi,
batch_states=batch_states,
obs_normalizer=obs_normalizer,
device=device,
)
_add_advantage_and_value_target_to_episodes(episodes, gamma=gamma, lambd=lambd)
if max_recurrent_sequence_len is not None:
dataset = _limit_sequence_length(episodes, max_recurrent_sequence_len)
else:
dataset = list(episodes)
return dataset
def _make_dataset(
episodes, model, phi, batch_states, obs_normalizer, gamma, lambd, device
):
"""Make a list of transitions with necessary information."""
_add_log_prob_and_value_to_episodes(
episodes=episodes,
model=model,
phi=phi,
batch_states=batch_states,
obs_normalizer=obs_normalizer,
device=device,
)
_add_advantage_and_value_target_to_episodes(episodes, gamma=gamma, lambd=lambd)
return list(itertools.chain.from_iterable(episodes))
def _yield_minibatches(dataset, minibatch_size, num_epochs):
assert dataset
buf = []
n = 0
while n < len(dataset) * num_epochs:
while len(buf) < minibatch_size:
buf = random.sample(dataset, k=len(dataset)) + buf
assert len(buf) >= minibatch_size
yield buf[-minibatch_size:]
n += minibatch_size
buf = buf[:-minibatch_size]
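
# Example behavior (illustrative): with a dataset of 4 transitions,
# minibatch_size=2, and num_epochs=1, this yields 2 minibatches that
# together cover the dataset exactly once, in a random order. The buffer
# is refilled with a fresh permutation whenever it runs low, so epochs
# blur into each other but no item is starved.
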
class PPO(agent.AttributeSavingMixin, agent.BatchAgent):
"""Proximal Policy Optimization
See https://arxiv.org/abs/1707.06347
Args:
model (torch.nn.Module): Model to train (including recurrent models)
state s |-> (pi(s, _), v(s))
        optimizer (torch.optim.Optimizer): Optimizer used to train the model
        obs_normalizer (pfrl.nn.EmpiricalNormalization or None): If set, it
            is used to normalize observations based on their empirical mean
            and standard deviation. It is saved and loaded together with the
            model (see ``saved_attributes``).
        gpu (int): GPU device id if not None nor negative
gamma (float): Discount factor [0, 1]
lambd (float): Lambda-return factor [0, 1]
phi (callable): Feature extractor function
value_func_coef (float): Weight coefficient for loss of
value function (0, inf)
entropy_coef (float): Weight coefficient for entropy bonus [0, inf)
update_interval (int): Model update interval in step
minibatch_size (int): Minibatch size
epochs (int): Training epochs in an update
clip_eps (float): Epsilon for pessimistic clipping of likelihood ratio
to update policy
clip_eps_vf (float): Epsilon for pessimistic clipping of value
to update value function. If it is ``None``, value function is not
clipped on updates.
        standardize_advantages (bool): Use standardized advantages on updates
        batch_states (callable): Method which makes a batch of observations.
            The default is ``pfrl.utils.batch_states.batch_states``.
recurrent (bool): If set to True, `model` is assumed to implement
`pfrl.nn.Recurrent` and update in a recurrent
manner.
max_recurrent_sequence_len (int): Maximum length of consecutive
sequences of transitions in a minibatch for updating the model.
This value is used only when `recurrent` is True. A smaller value
will encourage a minibatch to contain more and shorter sequences.
act_deterministically (bool): If set to True, choose most probable
actions in the act method instead of sampling from distributions.
max_grad_norm (float or None): Maximum L2 norm of the gradient used for
gradient clipping. If set to None, the gradient is not clipped.
value_stats_window (int): Window size used to compute statistics
of value predictions.
entropy_stats_window (int): Window size used to compute statistics
of entropy of action distributions.
value_loss_stats_window (int): Window size used to compute statistics
of loss values regarding the value function.
policy_loss_stats_window (int): Window size used to compute statistics
of loss values regarding the policy.
Statistics:
average_value: Average of value predictions on non-terminal states.
It's updated on (batch_)act_and_train.
average_entropy: Average of entropy of action distributions on
non-terminal states. It's updated on (batch_)act_and_train.
average_value_loss: Average of losses regarding the value function.
It's updated after the model is updated.
average_policy_loss: Average of losses regarding the policy.
It's updated after the model is updated.
n_updates: Number of model updates so far.
explained_variance: Explained variance computed from the last batch.
"""
saved_attributes = ("model", "optimizer", "obs_normalizer")
def __init__(
self,
model,
optimizer,
obs_normalizer=None,
gpu=None,
gamma=0.99,
lambd=0.95,
phi=lambda x: x,
value_func_coef=1.0,
entropy_coef=0.01,
update_interval=2048,
minibatch_size=64,
epochs=10,
clip_eps=0.2,
clip_eps_vf=None,
standardize_advantages=True,
batch_states=batch_states,
recurrent=False,
max_recurrent_sequence_len=None,
act_deterministically=False,
max_grad_norm=None,
value_stats_window=1000,
entropy_stats_window=1000,
value_loss_stats_window=100,
policy_loss_stats_window=100,
):
self.model = model
self.optimizer = optimizer
self.obs_normalizer = obs_normalizer
if gpu is not None and gpu >= 0:
assert torch.cuda.is_available()
self.device = torch.device("cuda:{}".format(gpu))
self.model.to(self.device)
if self.obs_normalizer is not None:
self.obs_normalizer.to(self.device)
else:
self.device = torch.device("cpu")
self.gamma = gamma
self.lambd = lambd
self.phi = phi
self.value_func_coef = value_func_coef
self.entropy_coef = entropy_coef
self.update_interval = update_interval
self.minibatch_size = minibatch_size
self.epochs = epochs
self.clip_eps = clip_eps
self.clip_eps_vf = clip_eps_vf
self.standardize_advantages = standardize_advantages
self.batch_states = batch_states
self.recurrent = recurrent
self.max_recurrent_sequence_len = max_recurrent_sequence_len
self.act_deterministically = act_deterministically
self.max_grad_norm = max_grad_norm
# Contains episodes used for next update iteration
self.memory = []
# Contains transitions of the last episode not moved to self.memory yet
self.last_episode = []
self.last_state = None
self.last_action = None
# Batch versions of last_episode, last_state, and last_action
self.batch_last_episode = None
self.batch_last_state = None
self.batch_last_action = None
# Recurrent states of the model
self.train_recurrent_states = None
self.train_prev_recurrent_states = None
self.test_recurrent_states = None
self.value_record = collections.deque(maxlen=value_stats_window)
self.entropy_record = collections.deque(maxlen=entropy_stats_window)
self.value_loss_record = collections.deque(maxlen=value_loss_stats_window)
self.policy_loss_record = collections.deque(maxlen=policy_loss_stats_window)
self.explained_variance = np.nan
self.n_updates = 0
def _initialize_batch_variables(self, num_envs):
self.batch_last_episode = [[] for _ in range(num_envs)]
self.batch_last_state = [None] * num_envs
self.batch_last_action = [None] * num_envs
def _update_if_dataset_is_ready(self):
dataset_size = (
sum(len(episode) for episode in self.memory)
+ len(self.last_episode)
+ (
0
if self.batch_last_episode is None
else sum(len(episode) for episode in self.batch_last_episode)
)
)
if dataset_size >= self.update_interval:
self._flush_last_episode()
if self.recurrent:
dataset = _make_dataset_recurrent(
episodes=self.memory,
model=self.model,
phi=self.phi,
batch_states=self.batch_states,
obs_normalizer=self.obs_normalizer,
gamma=self.gamma,
lambd=self.lambd,
max_recurrent_sequence_len=self.max_recurrent_sequence_len,
device=self.device,
)
self._update_recurrent(dataset)
else:
dataset = _make_dataset(
episodes=self.memory,
model=self.model,
phi=self.phi,
batch_states=self.batch_states,
obs_normalizer=self.obs_normalizer,
gamma=self.gamma,
lambd=self.lambd,
device=self.device,
)
assert len(dataset) == dataset_size
self._update(dataset)
self.explained_variance = _compute_explained_variance(
list(itertools.chain.from_iterable(self.memory))
)
self.memory = []
def _flush_last_episode(self):
if self.last_episode:
self.memory.append(self.last_episode)
self.last_episode = []
if self.batch_last_episode:
for i, episode in enumerate(self.batch_last_episode):
if episode:
self.memory.append(episode)
self.batch_last_episode[i] = []
def _update_obs_normalizer(self, dataset):
assert self.obs_normalizer
states = self.batch_states([b["state"] for b in dataset], self.device, self.phi)
self.obs_normalizer.experience(states)
def _update(self, dataset):
"""Update both the policy and the value function."""
device = self.device
if self.obs_normalizer:
self._update_obs_normalizer(dataset)
assert "state" in dataset[0]
assert "v_teacher" in dataset[0]
if self.standardize_advantages:
all_advs = torch.tensor([b["adv"] for b in dataset], device=device)
std_advs, mean_advs = torch.std_mean(all_advs, unbiased=False)
for batch in _yield_minibatches(
dataset, minibatch_size=self.minibatch_size, num_epochs=self.epochs
):
states = self.batch_states(
[b["state"] for b in batch], self.device, self.phi
)
if self.obs_normalizer:
states = self.obs_normalizer(states, update=False)
actions = torch.tensor([b["action"] for b in batch], device=device)
distribs, vs_pred = self.model(states)
advs = torch.tensor(
[b["adv"] for b in batch], dtype=torch.float32, device=device
)
if self.standardize_advantages:
advs = (advs - mean_advs) / (std_advs + 1e-8)
log_probs_old = torch.tensor(
[b["log_prob"] for b in batch],
dtype=torch.float,
device=device,
)
vs_pred_old = torch.tensor(
[b["v_pred"] for b in batch],
dtype=torch.float,
device=device,
)
vs_teacher = torch.tensor(
[b["v_teacher"] for b in batch],
dtype=torch.float,
device=device,
)
# Same shape as vs_pred: (batch_size, 1)
vs_pred_old = vs_pred_old[..., None]
vs_teacher = vs_teacher[..., None]
self.model.zero_grad()
loss = self._lossfun(
distribs.entropy(),
vs_pred,
distribs.log_prob(actions),
vs_pred_old=vs_pred_old,
log_probs_old=log_probs_old,
advs=advs,
vs_teacher=vs_teacher,
)
loss.backward()
if self.max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.max_grad_norm
)
self.optimizer.step()
self.n_updates += 1
def _update_once_recurrent(self, episodes, mean_advs, std_advs):
assert std_advs is None or std_advs > 0
device = self.device
# Sort desc by lengths so that pack_sequence does not change the order
episodes = sorted(episodes, key=len, reverse=True)
flat_transitions = flatten_sequences_time_first(episodes)
# Prepare data for a recurrent model
seqs_states = []
for ep in episodes:
states = self.batch_states(
[transition["state"] for transition in ep],
self.device,
self.phi,
)
if self.obs_normalizer:
states = self.obs_normalizer(states, update=False)
seqs_states.append(states)
flat_actions = torch.tensor(
[transition["action"] for transition in flat_transitions],
device=device,
)
flat_advs = torch.tensor(
[transition["adv"] for transition in flat_transitions],
dtype=torch.float,
device=device,
)
if self.standardize_advantages:
flat_advs = (flat_advs - mean_advs) / (std_advs + 1e-8)
flat_log_probs_old = torch.tensor(
[transition["log_prob"] for transition in flat_transitions],
dtype=torch.float,
device=device,
)
flat_vs_pred_old = torch.tensor(
[[transition["v_pred"]] for transition in flat_transitions],
dtype=torch.float,
device=device,
)
flat_vs_teacher = torch.tensor(
[[transition["v_teacher"]] for transition in flat_transitions],
dtype=torch.float,
device=device,
)
with torch.no_grad(), pfrl.utils.evaluating(self.model):
rs = concatenate_recurrent_states(
[ep[0]["recurrent_state"] for ep in episodes]
)
(flat_distribs, flat_vs_pred), _ = pack_and_forward(self.model, seqs_states, rs)
flat_log_probs = flat_distribs.log_prob(flat_actions)
flat_entropy = flat_distribs.entropy()
self.model.zero_grad()
loss = self._lossfun(
entropy=flat_entropy,
vs_pred=flat_vs_pred,
log_probs=flat_log_probs,
vs_pred_old=flat_vs_pred_old,
log_probs_old=flat_log_probs_old,
advs=flat_advs,
vs_teacher=flat_vs_teacher,
)
loss.backward()
if self.max_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
self.optimizer.step()
self.n_updates += 1
def _update_recurrent(self, dataset):
"""Update both the policy and the value function."""
device = self.device
flat_dataset = list(itertools.chain.from_iterable(dataset))
if self.obs_normalizer:
self._update_obs_normalizer(flat_dataset)
assert "state" in flat_dataset[0]
assert "v_teacher" in flat_dataset[0]
if self.standardize_advantages:
all_advs = torch.tensor([b["adv"] for b in flat_dataset], device=device)
std_advs, mean_advs = torch.std_mean(all_advs, unbiased=False)
else:
mean_advs = None
std_advs = None
for _ in range(self.epochs):
random.shuffle(dataset)
for minibatch in _yield_subset_of_sequences_with_fixed_number_of_items(
dataset, self.minibatch_size
):
self._update_once_recurrent(minibatch, mean_advs, std_advs)
def _lossfun(
self, entropy, vs_pred, log_probs, vs_pred_old, log_probs_old, advs, vs_teacher
):
prob_ratio = torch.exp(log_probs - log_probs_old)
loss_policy = -torch.mean(
torch.min(
prob_ratio * advs,
torch.clamp(prob_ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advs,
),
)
if self.clip_eps_vf is None:
loss_value_func = F.mse_loss(vs_pred, vs_teacher)
else:
clipped_vs_pred = _elementwise_clip(
vs_pred,
vs_pred_old - self.clip_eps_vf,
vs_pred_old + self.clip_eps_vf,
)
loss_value_func = torch.mean(
torch.max(
F.mse_loss(vs_pred, vs_teacher, reduction="none"),
F.mse_loss(clipped_vs_pred, vs_teacher, reduction="none"),
)
)
loss_entropy = -torch.mean(entropy)
self.value_loss_record.append(float(loss_value_func))
self.policy_loss_record.append(float(loss_policy))
loss = (
loss_policy
+ self.value_func_coef * loss_value_func
+ self.entropy_coef * loss_entropy
)
return loss
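
    # Illustrative numbers for the clipped surrogate above, with
    # clip_eps=0.2 (prob_ratio = pi_new / pi_old):
    #   ratio=1.5, adv=+1: min(1.5 * 1, clamp(1.5, 0.8, 1.2) * 1) = 1.2
    #   ratio=0.5, adv=-1: min(0.5 * -1, clamp(0.5, 0.8, 1.2) * -1) = -0.8
    # The objective always takes the more pessimistic of the unclipped and
    # clipped terms, which caps the incentive to move the ratio far from 1.
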
def batch_act(self, batch_obs):
if self.training:
return self._batch_act_train(batch_obs)
else:
return self._batch_act_eval(batch_obs)
def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
if self.training:
self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset)
else:
self._batch_observe_eval(batch_obs, batch_reward, batch_done, batch_reset)
def _batch_act_eval(self, batch_obs):
assert not self.training
b_state = self.batch_states(batch_obs, self.device, self.phi)
if self.obs_normalizer:
b_state = self.obs_normalizer(b_state, update=False)
with torch.no_grad(), pfrl.utils.evaluating(self.model):
if self.recurrent:
(action_distrib, _), self.test_recurrent_states = one_step_forward(
self.model, b_state, self.test_recurrent_states
)
else:
action_distrib, _ = self.model(b_state)
if self.act_deterministically:
action = mode_of_distribution(action_distrib).cpu().numpy()
else:
action = action_distrib.sample().cpu().numpy()
return action
def _batch_act_train(self, batch_obs):
assert self.training
b_state = self.batch_states(batch_obs, self.device, self.phi)
if self.obs_normalizer:
b_state = self.obs_normalizer(b_state, update=False)
num_envs = len(batch_obs)
if self.batch_last_episode is None:
self._initialize_batch_variables(num_envs)
assert len(self.batch_last_episode) == num_envs
assert len(self.batch_last_state) == num_envs
assert len(self.batch_last_action) == num_envs
# action_distrib will be recomputed when computing gradients
with torch.no_grad(), pfrl.utils.evaluating(self.model):
if self.recurrent:
assert self.train_prev_recurrent_states is None
self.train_prev_recurrent_states = self.train_recurrent_states
(
(action_distrib, batch_value),
self.train_recurrent_states,
) = one_step_forward(
self.model, b_state, self.train_prev_recurrent_states
)
else:
action_distrib, batch_value = self.model(b_state)
batch_action = action_distrib.sample().cpu().numpy()
self.entropy_record.extend(action_distrib.entropy().cpu().numpy())
self.value_record.extend(batch_value.cpu().numpy())
self.batch_last_state = list(batch_obs)
self.batch_last_action = list(batch_action)
return batch_action
def _batch_observe_eval(self, batch_obs, batch_reward, batch_done, batch_reset):
assert not self.training
if self.recurrent:
# Reset recurrent states when episodes end
indices_that_ended = [
i
for i, (done, reset) in enumerate(zip(batch_done, batch_reset))
if done or reset
]
if indices_that_ended:
self.test_recurrent_states = mask_recurrent_state_at(
self.test_recurrent_states, indices_that_ended
)
def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset):
assert self.training
for i, (state, action, reward, next_state, done, reset) in enumerate(
zip(
self.batch_last_state,
self.batch_last_action,
batch_reward,
batch_obs,
batch_done,
batch_reset,
)
):
if state is not None:
assert action is not None
transition = {
"state": state,
"action": action,
"reward": reward,
"next_state": next_state,
"nonterminal": 0.0 if done else 1.0,
}
if self.recurrent:
transition["recurrent_state"] = get_recurrent_state_at(
self.train_prev_recurrent_states, i, detach=True
)
transition["next_recurrent_state"] = get_recurrent_state_at(
self.train_recurrent_states, i, detach=True
)
self.batch_last_episode[i].append(transition)
if done or reset:
assert self.batch_last_episode[i]
self.memory.append(self.batch_last_episode[i])
self.batch_last_episode[i] = []
self.batch_last_state[i] = None
self.batch_last_action[i] = None
self.train_prev_recurrent_states = None
if self.recurrent:
# Reset recurrent states when episodes end
indices_that_ended = [
i
for i, (done, reset) in enumerate(zip(batch_done, batch_reset))
if done or reset
]
if indices_that_ended:
self.train_recurrent_states = mask_recurrent_state_at(
self.train_recurrent_states, indices_that_ended
)
self._update_if_dataset_is_ready()
def get_statistics(self):
return [
("average_value", _mean_or_nan(self.value_record)),
("average_entropy", _mean_or_nan(self.entropy_record)),
("average_value_loss", _mean_or_nan(self.value_loss_record)),
("average_policy_loss", _mean_or_nan(self.policy_loss_record)),
("n_updates", self.n_updates),
("explained_variance", self.explained_variance),
]
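
# A minimal construction sketch (hypothetical sizes; Branched and
# SoftmaxCategoricalHead come from pfrl.nn and pfrl.policies, as used in
# pfrl's example scripts):
#
#   import torch.nn as nn
#   from pfrl.nn import Branched
#   from pfrl.policies import SoftmaxCategoricalHead
#
#   obs_size, n_actions = 4, 2
#   model = nn.Sequential(
#       nn.Linear(obs_size, 64),
#       nn.Tanh(),
#       Branched(
#           # Policy head: state -> pi(s, _)
#           nn.Sequential(nn.Linear(64, n_actions), SoftmaxCategoricalHead()),
#           # Value head: state -> v(s)
#           nn.Linear(64, 1),
#       ),
#   )
#   optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
#   agent = PPO(model, optimizer, gamma=0.99, update_interval=2048,
#               minibatch_size=64, epochs=10)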