from typing import List, Dict, Any, Tuple, Union, Optional
from collections import namedtuple, deque
import sys
import os
import copy
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Normal, Independent
from ding.torch_utils import Adam, to_device
from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
from ding.model import model_wrap
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate, default_decollate
from .base_policy import Policy
from .common_utils import default_preprocess_learn
@POLICY_REGISTRY.register('sac')
class SACPolicy(Policy):
    r"""
    Overview:
        Policy class of SAC algorithm (Soft Actor-Critic, arXiv 1801.01290). The actor outputs ``(mu, sigma)`` \
        of a Gaussian whose samples are squashed by ``tanh``; the critic is one soft Q network, or two when \
        ``model.twin_critic`` is True (clipped double-Q). The temperature ``alpha`` is fixed, or learned when \
        ``learn.auto_alpha`` is True.
    Config:
        == ==================== ======== ============= ================================= =======================
        ID Symbol               Type     Default Value Description                       Other(Shape)
        == ==================== ======== ============= ================================= =======================
        1  ``type``             str      sac           | RL policy register name, refer  | this arg is optional,
                                                       | to registry ``POLICY_REGISTRY`` | a placeholder
        2  ``cuda``             bool     False         | Whether to use cuda for network |
        3  | ``random_``        int      10000         | Number of randomly collected    | Default to 10000 for
           | ``collect_size``                          | training samples in replay      | SAC, 25000 for DDPG/
                                                       | buffer when training starts.    | TD3.
        4  | ``model.policy_``  int      256           | Linear layer size for policy    |
           | ``embedding_size``                        | network.                        |
        5  | ``model.soft_q_``  int      256           | Linear layer size for soft q    |
           | ``embedding_size``                        | network.                        |
        6  | ``model.value_``   int      256           | Linear layer size for value     | Default to None when
           | ``embedding_size``                        | network.                        | model.value_network
                                                       |                                 | is False.
        7  | ``learn.learning`` float    3e-4          | Learning rate for soft q        | Default to 1e-3, when
           | ``_rate_q``                               | network.                        | model.value_network
                                                       |                                 | is True.
        8  | ``learn.learning`` float    3e-4          | Learning rate for policy        | Default to 1e-3, when
           | ``_rate_policy``                          | network.                        | model.value_network
                                                       |                                 | is True.
        9  | ``learn.learning`` float    3e-4          | Learning rate for value         | Default to None when
           | ``_rate_value``                           | network.                        | model.value_network
                                                       |                                 | is False.
        10 ``learn.alpha``      float    0.2           | Entropy regularization          | alpha is initiali-
                                                       | coefficient.                    | zation for auto
                                                       |                                 | `\alpha`, when
                                                       |                                 | auto_alpha is True
        11 | ``learn.repara_``  bool     True          | Determine whether to use        |
           | ``meterization``                          | reparameterization trick.       |
        12 | ``learn.``         bool     True          | Determine whether to use        | Temperature parameter
           | ``auto_alpha``                            | auto temperature parameter      | determines the
                                                       | `\alpha`.                       | relative importance
                                                       |                                 | of the entropy term
                                                       |                                 | against the reward.
        13 | ``learn.``         bool     False         | Determine whether to ignore     | Use ignore_done only
           | ``ignore_done``                           | done flag.                      | in halfcheetah env.
        14 | ``learn.``         float    0.005         | Used for soft update of the     | aka. Interpolation
           | ``target_theta``                          | target network.                 | factor in polyak aver
                                                       |                                 | aging for target
                                                       |                                 | networks.
        == ==================== ======== ============= ================================= =======================
    """

    config = dict(
        # (str) RL policy register name (refer to function "POLICY_REGISTRY").
        type='sac',
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (bool type) on_policy: Determine whether on-policy or off-policy.
        # on-policy setting influences the behaviour of buffer.
        # Default False in SAC.
        on_policy=False,
        # (bool type) priority: Determine whether to use priority in buffer sample.
        # Default False in SAC.
        priority=False,
        # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
        priority_IS_weight=False,
        # (int) Number of training samples(randomly collected) in replay buffer when training starts.
        # Default 10000 in SAC.
        random_collect_size=10000,
        model=dict(
            # (bool type) twin_critic: Determine whether to use double-soft-q-net for target q computation.
            # Please refer to TD3 about Clipped Double-Q Learning trick, which learns two Q-functions instead of one.
            # Default to True.
            twin_critic=True,
            # (bool type) value_network: Determine whether to use value network as the
            # original SAC paper (arXiv 1801.01290).
            # using value_network needs to set learning_rate_value, learning_rate_q,
            # and learning_rate_policy in `cfg.policy.learn`.
            # Default to False.
            # NOTE: currently not read; ``_init_learn`` hard-codes the value network to off.
            # value_network=False,
            # (str) Actor head type passed to the model.
            actor_head_type='reparameterization',
        ),
        learn=dict(
            # (bool) Whether to use multi gpu
            multi_gpu=False,
            # How many updates(iterations) to train after collector's one collection.
            # Bigger "update_per_collect" means bigger off-policy.
            # collect data -> update policy-> collect data -> ...
            update_per_collect=1,
            # (int) Minibatch size for gradient descent.
            batch_size=256,
            # (float type) learning_rate_q: Learning rate for soft q network.
            # Default to 3e-4.
            # Please set to 1e-3, when model.value_network is True.
            learning_rate_q=3e-4,
            # (float type) learning_rate_policy: Learning rate for policy network.
            # Default to 3e-4.
            # Please set to 1e-3, when model.value_network is True.
            learning_rate_policy=3e-4,
            # (float type) learning_rate_value: Learning rate for value network.
            # `learning_rate_value` should be initialized, when model.value_network is True.
            # Please set to 3e-4, when model.value_network is True.
            learning_rate_value=3e-4,
            # (float type) learning_rate_alpha: Learning rate for auto temperature parameter `\alpha`.
            # Default to 3e-4.
            learning_rate_alpha=3e-4,
            # (float type) target_theta: Used for soft update of the target network,
            # aka. Interpolation factor in polyak averaging for target networks.
            # Default to 0.005.
            target_theta=0.005,
            # (float) discount factor for the discounted sum of rewards, aka. gamma.
            discount_factor=0.99,
            # (float type) alpha: Entropy regularization coefficient.
            # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
            # If auto_alpha is set to `True`, alpha is initialization for auto `\alpha`.
            # Default to 0.2.
            alpha=0.2,
            # (bool type) auto_alpha: Determine whether to use auto temperature parameter `\alpha` .
            # Temperature parameter determines the relative importance of the entropy term against the reward.
            # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
            # Default to True.
            # Note that: Using auto alpha needs to set learning_rate_alpha in `cfg.policy.learn`.
            auto_alpha=True,
            # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum)
            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
            # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
            # HalfCheetah never terminates on its own, so a done==True there only signals the time limit.
            # Ignoring it (treating done as False) keeps the TD target
            # (``gamma * (1 - done) * next_v + reward``) accurate when the episode step
            # reaches the max episode step.
            ignore_done=False,
            # (float) Weight uniform initialization range in the last output layer
            init_w=3e-3,
        ),
        collect=dict(
            # You can use either "n_sample" or "n_episode" in actor.collect.
            # Get "n_sample" samples per collect.
            # Default n_sample to 1.
            n_sample=1,
            # (int) Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
        ),
        eval=dict(),
        other=dict(
            replay_buffer=dict(
                # (int type) replay_buffer_size: Max size of replay buffer.
                replay_buffer_size=1000000,
                # (int type) max_use: Max use times of one data in the buffer.
                # Data will be removed once used for too many times.
                # Default to infinite.
                # max_use=256,
            ),
        ),
    )

    def _init_learn(self) -> None:
        r"""
        Overview:
            Learn mode init method. Called by ``self.__init__``.
            Re-initialize the output layers of actor and critic(s), create the q / policy (and optional value \
            and alpha) optimizers, read the algorithm config, and build the learn model and the \
            momentum-updated target model.
        """
        # Init
        self._priority = self._cfg.priority
        self._priority_IS_weight = self._cfg.priority_IS_weight
        self._value_network = False  # TODO self._cfg.model.value_network
        self._twin_critic = self._cfg.model.twin_critic

        # Weight init: uniform in [-init_w, init_w] for the last layers of the actor and critic(s).
        # NOTE(review): the ``[2]`` indices assume the layout of the default ``qac`` model -- confirm when
        # plugging in a custom model.
        init_w = self._cfg.learn.init_w
        actor_head = self._model.actor[2]
        for layer in (actor_head.mu, actor_head.log_sigma_layer):
            layer.weight.data.uniform_(-init_w, init_w)
            layer.bias.data.uniform_(-init_w, init_w)
        if self._twin_critic:
            critic_heads = [self._model.critic[0][2], self._model.critic[1][2]]
        else:
            critic_heads = [self._model.critic[2]]
        for head in critic_heads:
            head.last.weight.data.uniform_(-init_w, init_w)
            head.last.bias.data.uniform_(-init_w, init_w)

        # Optimizers
        if self._value_network:
            self._optimizer_value = Adam(
                self._model.value_critic.parameters(),
                lr=self._cfg.learn.learning_rate_value,
            )
        self._optimizer_q = Adam(
            self._model.critic.parameters(),
            lr=self._cfg.learn.learning_rate_q,
        )
        self._optimizer_policy = Adam(
            self._model.actor.parameters(),
            lr=self._cfg.learn.learning_rate_policy,
        )

        # Algorithm config
        self._gamma = self._cfg.learn.discount_factor

        # Temperature: learned through log(alpha) when auto_alpha, otherwise a constant tensor.
        if self._cfg.learn.auto_alpha:
            # Heuristic target entropy: minus the action dimensionality.
            self._target_entropy = -np.prod(self._cfg.model.action_shape)
            self._log_alpha = torch.log(torch.FloatTensor([self._cfg.learn.alpha]))
            self._log_alpha = self._log_alpha.to(self._device).requires_grad_()
            self._alpha_optim = torch.optim.Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha)
            self._auto_alpha = True
            assert self._log_alpha.shape == torch.Size([1]) and self._log_alpha.requires_grad
            self._alpha = self._log_alpha.detach().exp()
        else:
            self._alpha = torch.tensor(
                [self._cfg.learn.alpha], requires_grad=False, device=self._device, dtype=torch.float32
            )
            self._auto_alpha = False

        # Main and target models
        self._target_model = copy.deepcopy(self._model)
        self._target_model = model_wrap(
            self._target_model,
            wrapper_name='target',
            update_type='momentum',
            update_kwargs={'theta': self._cfg.learn.target_theta}
        )
        self._learn_model = model_wrap(self._model, wrapper_name='base')
        self._learn_model.reset()
        self._target_model.reset()

        self._forward_learn_cnt = 0

    @staticmethod
    def _sample_tanh_gaussian(mu: torch.Tensor, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Overview:
            Draw a reparameterized sample from ``Independent(Normal(mu, sigma), 1)``, squash it with ``tanh`` \
            and return it with its log-probability, corrected for the squashing.
        Arguments:
            - mu (:obj:`torch.Tensor`): Mean of the pre-squash Gaussian.
            - sigma (:obj:`torch.Tensor`): Standard deviation of the pre-squash Gaussian.
        Returns:
            - action (:obj:`torch.Tensor`): ``tanh`` of the sample.
            - log_prob (:obj:`torch.Tensor`): Log-probability of ``action`` with a trailing singleton dim.
        """
        dist = Independent(Normal(mu, sigma), 1)
        pred = dist.rsample()
        action = torch.tanh(pred)
        # Change of variables for the tanh squashing: subtract sum(log(1 - tanh(u)^2)).
        # The 1e-6 keeps the log finite when |action| saturates at 1.
        y = 1 - action.pow(2) + 1e-6
        log_prob = dist.log_prob(pred).unsqueeze(-1) - torch.log(y).sum(-1, keepdim=True)
        return action, log_prob

    def _forward_learn(self, data: dict) -> Dict[str, Any]:
        r"""
        Overview:
            Forward and backward function of learn mode. It runs one gradient step on the soft Q critic(s), \
            then on the (optional) value network, the policy and (if ``auto_alpha``) the temperature. It \
            finishes with a momentum update of the target model.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs'] \
                ('done' is also read).
        Returns:
            - info_dict (:obj:`Dict[str, Any]`): Current learning rates, losses, per-sample priorities and \
                the statistics listed by ``_monitor_vars_learn``.
        """
        loss_dict = {}
        data = default_preprocess_learn(
            data,
            use_priority=self._priority,
            use_priority_IS_weight=self._priority_IS_weight,
            ignore_done=self._cfg.learn.ignore_done,
            use_nstep=False
        )
        if self._cuda:
            data = to_device(data, self._device)

        self._learn_model.train()
        self._target_model.train()
        obs = data.get('obs')
        next_obs = data.get('next_obs')
        reward = data.get('reward')
        done = data.get('done')

        # predict q value(s) of the stored (obs, action) pairs
        q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']

        # bootstrap target, depending on whether a separate value network is used
        if self._value_network:
            # predict v value
            v_value = self._learn_model.forward(obs, mode='compute_value_critic')['v_value']
            with torch.no_grad():
                next_v_value = self._target_model.forward(next_obs, mode='compute_value_critic')['v_value']
            target_value = next_v_value
        else:
            # SARSA-style target: first sample the next action from the current policy, then evaluate it
            # with the target critic(s).
            with torch.no_grad():
                (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit']
                next_action, next_log_prob = self._sample_tanh_gaussian(mu, sigma)
                next_data = {'obs': next_obs, 'action': next_action}
                target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
                if self._twin_critic:
                    # clipped double-Q: take the smaller of the two target estimates
                    target_q_value = torch.min(target_q_value[0], target_q_value[1])
                # the value of a policy according to the maximum entropy objective
                target_value = target_q_value - self._alpha * next_log_prob.squeeze(-1)

        # =================
        # q network
        # =================
        if self._twin_critic:
            q_data0 = v_1step_td_data(q_value[0], target_value, reward, done, data['weight'])
            loss_dict['critic_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
            q_data1 = v_1step_td_data(q_value[1], target_value, reward, done, data['weight'])
            loss_dict['twin_critic_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma)
            td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2
            # One backward pass over the summed losses gives the same gradients as two separate passes,
            # without relying on the two critics sharing no autograd nodes.
            critic_total_loss = loss_dict['critic_loss'] + loss_dict['twin_critic_loss']
        else:
            q_data = v_1step_td_data(q_value, target_value, reward, done, data['weight'])
            loss_dict['critic_loss'], td_error_per_sample = v_1step_td_error(q_data, self._gamma)
            critic_total_loss = loss_dict['critic_loss']
        self._optimizer_q.zero_grad()
        critic_total_loss.backward()
        self._optimizer_q.step()

        # evaluate the current policy on obs (with the freshly updated critic)
        (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
        action, log_prob = self._sample_tanh_gaussian(mu, sigma)
        eval_data = {'obs': obs, 'action': action}
        new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
        if self._twin_critic:
            new_q_value = torch.min(new_q_value[0], new_q_value[1])

        # =================
        # value network
        # =================
        if self._value_network:
            # new_q_value: (bs, ), log_prob: (bs, act_shape) -> target_v_value: (bs, )
            target_v_value = (new_q_value.unsqueeze(-1) - self._alpha * log_prob).mean(dim=-1)
            loss_dict['value_loss'] = F.mse_loss(v_value, target_v_value.detach())
            self._optimizer_value.zero_grad()
            loss_dict['value_loss'].backward()
            self._optimizer_value.step()

        # =================
        # policy network
        # =================
        policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean()
        loss_dict['policy_loss'] = policy_loss
        self._optimizer_policy.zero_grad()
        loss_dict['policy_loss'].backward()
        self._optimizer_policy.step()

        # =================
        # temperature
        # =================
        if self._auto_alpha:
            log_prob = log_prob.detach() + self._target_entropy
            loss_dict['alpha_loss'] = -(self._log_alpha * log_prob).mean()
            self._alpha_optim.zero_grad()
            loss_dict['alpha_loss'].backward()
            self._alpha_optim.step()
            self._alpha = self._log_alpha.detach().exp()

        loss_dict['total_loss'] = sum(loss_dict.values())

        # monitoring statistics (every key listed by ``_monitor_vars_learn`` is produced here or in loss_dict)
        info_dict = {'target_q_value': target_value.detach().mean().item()}
        if self._twin_critic:
            info_dict['q_value_1'] = q_value[0].detach().mean().item()
            info_dict['q_value_2'] = q_value[1].detach().mean().item()
        else:
            info_dict['q_value_1'] = q_value.detach().mean().item()
        if self._value_network:
            info_dict['cur_lr_v'] = self._optimizer_value.defaults['lr']

        # =============
        # after update
        # =============
        self._forward_learn_cnt += 1
        # target update (momentum / polyak averaging with ``target_theta``)
        self._target_model.update(self._learn_model.state_dict())
        return {
            'cur_lr_q': self._optimizer_q.defaults['lr'],
            'cur_lr_p': self._optimizer_policy.defaults['lr'],
            'priority': td_error_per_sample.abs().tolist(),
            'td_error': td_error_per_sample.detach().mean().item(),
            'alpha': self._alpha.item(),
            'target_value': target_value.detach().mean().item(),
            **info_dict,
            **loss_dict
        }

    def _state_dict_learn(self) -> Dict[str, Any]:
        r"""
        Overview:
            Return the learn-mode state needed to resume training. It includes the main and target models, \
            the optimizers, and (when ``auto_alpha``) the learned ``log_alpha`` value.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): The state dict of learn mode.
        """
        ret = {
            'model': self._learn_model.state_dict(),
            'target_model': self._target_model.state_dict(),
            'optimizer_q': self._optimizer_q.state_dict(),
            'optimizer_policy': self._optimizer_policy.state_dict(),
        }
        if self._value_network:
            ret['optimizer_value'] = self._optimizer_value.state_dict()
        if self._auto_alpha:
            ret['optimizer_alpha'] = self._alpha_optim.state_dict()
            # The optimizer state alone does not contain the parameter value itself.
            ret['log_alpha'] = self._log_alpha.detach().clone()
        return ret

    def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
        r"""
        Overview:
            Restore the learn-mode state produced by ``_state_dict_learn``. Checkpoints written without \
            'target_model' or 'log_alpha' are accepted. The target model is then re-synced from the loaded \
            main model, and ``log_alpha`` keeps its current value.
        Arguments:
            - state_dict (:obj:`Dict[str, Any]`): The state dict of learn mode.
        """
        self._learn_model.load_state_dict(state_dict['model'])
        self._target_model.load_state_dict(state_dict.get('target_model', state_dict['model']))
        self._optimizer_q.load_state_dict(state_dict['optimizer_q'])
        # The value optimizer only exists when the value network is enabled.
        if self._value_network:
            self._optimizer_value.load_state_dict(state_dict['optimizer_value'])
        self._optimizer_policy.load_state_dict(state_dict['optimizer_policy'])
        if self._auto_alpha:
            self._alpha_optim.load_state_dict(state_dict['optimizer_alpha'])
            if 'log_alpha' in state_dict:
                # In-place copy so ``self._alpha_optim`` keeps optimizing the same tensor.
                with torch.no_grad():
                    self._log_alpha.copy_(state_dict['log_alpha'])
                self._alpha = self._log_alpha.detach().exp()

    def _init_collect(self) -> None:
        r"""
        Overview:
            Collect mode init method. Called by ``self.__init__``.
            Init unroll length and the collect model. No extra action noise is added: exploration comes \
            from sampling the stochastic policy in ``_forward_collect``.
        """
        self._unroll_len = self._cfg.collect.unroll_len
        self._collect_model = model_wrap(self._model, wrapper_name='base')
        self._collect_model.reset()

    def _forward_collect(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function of collect mode: sample ``tanh(N(mu, sigma))`` for each env.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs'].
        Returns:
            - output (:obj:`dict`): Dict type data, including at least inferred action according to input obs.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._collect_model.eval()
        with torch.no_grad():
            (mu, sigma) = self._collect_model.forward(data, mode='compute_actor')['logit']
            dist = Independent(Normal(mu, sigma), 1)
            action = torch.tanh(dist.rsample())
            output = {'logit': (mu, sigma), 'action': action}
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
        r"""
        Overview:
            Generate dict type transition data from inputs.
        Arguments:
            - obs (:obj:`Any`): Env observation
            - model_output (:obj:`dict`): Output of collect model, including at least ['action']
            - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
                (here 'obs' indicates obs after env step, i.e. next_obs).
        Return:
            - transition (:obj:`Dict[str, Any]`): Dict type transition data.
        """
        transition = {
            'obs': obs,
            'next_obs': timestep.obs,
            'action': model_output['action'],
            'reward': timestep.reward,
            'done': timestep.done,
        }
        return transition

    def _get_train_sample(self, data: deque) -> Union[None, List[Any]]:
        r"""
        Overview:
            Split collected transitions into training samples of length ``unroll_len`` via ``get_train_sample``.
        """
        return get_train_sample(data, self._unroll_len)

    def _init_eval(self) -> None:
        r"""
        Overview:
            Evaluate mode init method. Called by ``self.__init__``.
            Init eval model. Unlike learn and collect model, eval model acts deterministically.
        """
        self._eval_model = model_wrap(self._model, wrapper_name='base')
        self._eval_model.reset()

    def _forward_eval(self, data: dict) -> dict:
        r"""
        Overview:
            Forward function for eval mode, similar to ``self._forward_collect`` but deterministic: \
            the action is ``tanh(mu)``.
        Arguments:
            - data (:obj:`dict`): Dict type data, including at least ['obs'].
        Returns:
            - output (:obj:`dict`): Dict type data, including at least inferred action according to input obs.
        """
        data_id = list(data.keys())
        data = default_collate(list(data.values()))
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            (mu, sigma) = self._eval_model.forward(data, mode='compute_actor')['logit']
            action = torch.tanh(mu)  # deterministic_eval
            output = {'action': action}
        if self._cuda:
            output = to_device(output, 'cpu')
        output = default_decollate(output)
        return {i: d for i, d in zip(data_id, output)}

    def default_model(self) -> Tuple[str, List[str]]:
        r"""
        Overview:
            Return the registered name of the default model and the module path to import it from.
        """
        return 'qac', ['ding.model.template.qac']

    def _monitor_vars_learn(self) -> List[str]:
        r"""
        Overview:
            Return variables' name if variables are to used in monitor. Every SAC-specific name listed here \
            is a key of the dict returned by ``_forward_learn``.
        Returns:
            - vars (:obj:`List[str]`): Variables' name list.
        """
        variables = [
            'policy_loss', 'critic_loss', 'cur_lr_q', 'cur_lr_p', 'target_q_value', 'q_value_1', 'alpha', 'td_error',
            'target_value'
        ]
        if self._auto_alpha:
            variables = ['alpha_loss'] + variables
        if self._twin_critic:
            variables += ['q_value_2', 'twin_critic_loss']
        if self._value_network:
            variables += ['value_loss', 'cur_lr_v']
        return super()._monitor_vars_learn() + variables