PyTorch - Optimizer

A summary of the optimizers in PyTorch.

Adam family#

The Adam algorithm:

\begin{aligned}
g_{t} &\gets \nabla_{\theta} f_t(\theta_{t-1})\\
m_{t} &\gets \beta_{1} m_{t-1} + (1-\beta_{1}) g_t\\
v_{t} &\gets \beta_{2} v_{t-1} + (1-\beta_{2}) g_t^2\\
\widehat{m}_{t} &\gets \frac{m_{t}}{1-\beta_1^t},\ \widehat{v}_{t} \gets \frac{v_{t}}{1-\beta_2^t}\\
\theta_{t} &\gets \theta_{t-1} - \eta\frac{\widehat{m}_t}{\sqrt{\widehat{v}_t} + \epsilon}
\end{aligned}
  • $\widehat{m}_t$: the bias-corrected first moment estimate, i.e., the exponential moving average of the gradients.
  • $\widehat{v}_t$: the bias-corrected second moment estimate, i.e., the exponential moving average of the squared gradients;
    • an estimate of the uncentered variance of the gradients.

Uncentered variance (the second raw moment, $\mathbb{E}[g^2]$) measures the spread of a set of values around zero rather than around the mean, and satisfies $\mathbb{E}[g^2] = \mathrm{Var}(g) + \mathbb{E}[g]^2$. In Adam, the exponential moving average of the squared gradients is a running estimate of this uncentered variance.
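As a quick sanity check, here is a minimal NumPy sketch (the gradient distribution and its parameters are purely illustrative) showing that Adam's second-moment moving average approaches $\mathbb{E}[g^2] = \mathrm{Var}(g) + \mathbb{E}[g]^2$:

import numpy as np

rng = np.random.default_rng(0)
beta2 = 0.999
v = 0.0

# Simulated gradient stream with mean 0.5 and std 0.2 (illustrative values only).
for t in range(1, 20001):
    g = rng.normal(loc=0.5, scale=0.2)
    v = beta2 * v + (1 - beta2) * g**2   # Adam's second-moment EMA
v_hat = v / (1 - beta2**t)               # bias correction

print(v_hat)              # ≈ 0.29
print(0.2**2 + 0.5**2)    # Var(g) + E[g]^2 = 0.04 + 0.25 = 0.29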

Note: the implementation sketches below omit bias correction to keep them easy to read. With bias correction, the estimates would be computed as follows:

# bias_correction = 1 - beta**t
m_t_hat = m_t / (1 - beta1**t)
v_t_hat = v_t / (1 - beta2**t)

For full details, see the _single_tensor_adam function in PyTorch's source code.
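For intuition, here is a rough, self-contained sketch of one bias-corrected Adam step for a single scalar parameter (adam_step is a hypothetical helper for illustration, not PyTorch's actual function):

import math

def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # One bias-corrected Adam update for a single scalar parameter (illustrative only).
    m = beta1 * m + (1 - beta1) * grad      # first moment EMA
    v = beta2 * v + (1 - beta2) * grad**2   # second moment EMA
    m_hat = m / (1 - beta1**t)              # bias correction
    v_hat = v / (1 - beta2**t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# Toy usage: minimize f(x) = x^2, whose gradient is 2x.
param, m, v = 1.0, 0.0, 0.0
for t in range(1, 4):
    param, m, v = adam_step(param, grad=2.0 * param, m=m, v=v, t=t)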

Adam#

Equation:

\begin{aligned}
g_{t} \gets \nabla_{\theta} f_t(\theta_{t-1}) + \lambda\theta_{t-1},\footnotesize\text{ if with L2 regularization}
\end{aligned}

Implementation code:

if weight_decay != 0:
    gradient = gradient + weight_decay * param   # L2 regularization: decay enters the gradient
m_t = beta1 * m_t + (1 - beta1) * gradient       # m_t still holds m_{t-1} on the right-hand side
v_t = beta2 * v_t + (1 - beta2) * gradient**2    # v_t still holds v_{t-1} on the right-hand side
param = param - lr * m_t / (sqrt(v_t) + epsilon)
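In practice this corresponds to passing weight_decay to torch.optim.Adam; the model and hyperparameter values below are just placeholders:

import torch

model = torch.nn.Linear(10, 1)   # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-2)

x, y = torch.randn(4, 10), torch.randn(4, 1)   # dummy batch
loss = torch.nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()   # weight_decay is added to the gradients internally (L2 regularization)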

AdamW#

Equation:

\theta_{t} \gets \theta_{t-1} - \eta\left(\frac{\widehat{m}_t}{\sqrt{\widehat{v}_t} + \epsilon} + \lambda\theta_{t-1}\right)

Implementation code:

# The original paper's formulation: decay is added to the update term, outside the moments
m_t = beta1 * m_t + (1 - beta1) * gradient
v_t = beta2 * v_t + (1 - beta2) * gradient**2
param = param - lr * (m_t / (sqrt(v_t) + epsilon) + weight_decay * param)

# PyTorch's formulation: the parameter is decayed first, then the plain Adam step is applied
param *= 1 - lr * weight_decay
m_t = beta1 * m_t + (1 - beta1) * gradient
v_t = beta2 * v_t + (1 - beta2) * gradient**2
param = param - lr * m_t / (sqrt(v_t) + epsilon)
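The decoupled form is what torch.optim.AdamW provides; its interface is the same as Adam's, only the way weight_decay is applied differs (the model and values below are placeholders):

import torch

model = torch.nn.Linear(10, 1)   # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
# optimizer.step() decays the parameters directly before applying the Adam update.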

Difference of weight_decay between Adam and AdamW#

  • In Adam, the weight decay term is added to the gradient (classic L2 regularization), so it also flows into the first and second moment estimates.
  • In AdamW, weight decay is applied directly to the model parameters, decoupled from the gradient, so the momentum and second-moment estimates are unaffected by it; see the toy comparison below.
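A toy comparison (single parameter, arbitrary hyperparameters) illustrating that the two decay styles produce different parameter trajectories:

import torch

# One parameter per optimizer, identical starting point and toy loss f(p) = p^2.
p_adam = torch.nn.Parameter(torch.tensor([1.0]))
p_adamw = torch.nn.Parameter(torch.tensor([1.0]))

opt_adam = torch.optim.Adam([p_adam], lr=0.1, weight_decay=0.1)     # decay goes through the gradient
opt_adamw = torch.optim.AdamW([p_adamw], lr=0.1, weight_decay=0.1)  # decay applied to the parameter

for _ in range(10):
    for p, opt in ((p_adam, opt_adam), (p_adamw, opt_adamw)):
        opt.zero_grad()
        (p ** 2).sum().backward()
        opt.step()

print(p_adam.item(), p_adamw.item())   # the values differ because decay is applied differently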