SGD

Plain Stochastic Gradient Descent:
Move the weight in the direction opposite the gradient with respect to that parameter, scaled by a learning rate coefficient.

w_{t+1} = w_t - \eta \frac{\partial}{\partial w}

where η is the learning rate.

w -= lr * dw

SGD with Momentum

Update the weight according to the velocity, which is an exponential moving average of past gradients.

w_{t+1} = w_t + \eta v_t \\ v_t = \beta v_{t-1} + (1 - \beta)\left(-\frac{\partial}{\partial w}\right)

where β is a value between 0 (normal SGD) and 1 (perfect momentum, velocity never updates).
β = 0.995 is typical
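
A minimal sketch of one update step, assuming w, dw (the gradient), and velocity are NumPy arrays (or floats) of the same shape; the function name and defaults are mine, not standard:

def momentum_step(w, dw, velocity, lr=0.01, beta=0.995):
    # velocity: EMA of the negative gradient
    velocity = beta * velocity + (1 - beta) * (-dw)
    # move the weight along the velocity, scaled by the learning rate
    w = w + lr * velocity
    return w, velocity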

Momentum Decay

These schedules decrease the momentum coefficient β as a function of the training iteration t.
Decay functions include:

Nesterov’s

\beta_t = 1 - \dfrac{3}{t + 5}

Sutskever’s

\beta_t = \min\left(1 - 2^{-1 - \log_2(\lfloor t/250 \rfloor + 1)}, \beta_0\right)

Demon

\beta_t = \frac{\beta_0\left(1 - \frac{t}{T}\right)}{(1 - \beta_0) + \beta_0\left(1 - \frac{t}{T}\right)}

where T is the maximum timestep.
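
The three schedules as Python functions, a sketch assuming t counts training iterations from 0 and T is the total number of steps; the function names and the β₀ defaults are mine:

import numpy as np

def nesterov_beta(t):
    return 1 - 3 / (t + 5)

def sutskever_beta(t, beta_0=0.995):
    return min(1 - 2 ** (-1 - np.log2(t // 250 + 1)), beta_0)

def demon_beta(t, T, beta_0=0.9):
    # decays from beta_0 at t = 0 toward 0 at t = T
    frac = 1 - t / T
    return beta_0 * frac / ((1 - beta_0) + beta_0 * frac)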

Nesterov Momentum

aka Nesterov’s Accelerated Gradient or NAG

\phi_{t+1} = w_t - \eta \frac{\partial}{\partial w} \\ w_{t+1} = \phi_{t+1} + \beta(\phi_{t+1} - \phi_t)
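
A sketch of the two-step recurrence above, assuming dw is the gradient evaluated at the current iterate w and phi_prev is the previous φ; names and defaults are mine:

def nag_step(w, dw, phi_prev, lr=0.01, beta=0.9):
    # gradient step from the current (lookahead) point
    phi = w - lr * dw
    # extrapolate along the direction phi just moved
    w_next = phi + beta * (phi - phi_prev)
    return w_next, phi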

AdaGrad

Like SGD, but totals up the squared gradient for each parameter, and lowers the learning rate as the total accumulates.

w_{t+1} = w_t - \frac{\eta}{\sqrt{G_t} + \epsilon} \frac{\partial}{\partial w} \\ \quad \\ G_{t+1} = G_t + \left(\frac{\partial}{\partial w}\right)^2

where ϵ is a small value to prevent division by zero.
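
A minimal sketch, assuming G is the running total of squared gradients (initialized to zeros of the same shape as w); names and defaults are assumptions:

import numpy as np

def adagrad_step(w, dw, G, lr=0.01, eps=1e-8):
    # step size shrinks as the accumulated squared gradient grows
    w = w - lr * dw / (np.sqrt(G) + eps)
    # add the current squared gradient to the running total
    G = G + dw ** 2
    return w, G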

RMSProp

Maintains an exponential moving average of the squared gradient for each parameter. Unlike AdaGrad, it can continue to learn after a big early update.

w_{t+1} = w_t - \frac{\eta}{\sqrt{G_t} + \epsilon} \frac{\partial}{\partial w} \\ \quad \\ G_t = \beta G_{t-1} + (1 - \beta)\left(\frac{\partial}{\partial w}\right)^2
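
Same shape as the AdaGrad sketch, but G is now an EMA rather than a running total (names and defaults assumed):

import numpy as np

def rmsprop_step(w, dw, G, lr=0.001, beta=0.9, eps=1e-8):
    # EMA of the squared gradient
    G = beta * G + (1 - beta) * dw ** 2
    w = w - lr * dw / (np.sqrt(G) + eps)
    return w, G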

AdaDelta

Similar to RMSProp, with unwarranted additional complexity…

\begin{aligned} w_{t+1} & = w_t + v_t \\ V_t & = \beta V_{t-1} + (1 - \beta) v_{t-1}^2 \\ G_t & = \beta G_{t-1} + (1 - \beta)\left(\frac{\partial}{\partial w}\right)^2 \\ v_t & = -\frac{\sqrt{V_t}}{\sqrt{G_t} + \epsilon}\frac{\partial}{\partial w} \end{aligned}
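
A sketch following the equations above, carrying EMAs of the squared gradient (G) and the squared update (V), plus the previous update v_prev; names and defaults are mine:

import numpy as np

def adadelta_step(w, dw, G, V, v_prev, beta=0.95, eps=1e-6):
    G = beta * G + (1 - beta) * dw ** 2
    V = beta * V + (1 - beta) * v_prev ** 2
    # step is scaled by the ratio of RMS(update) to RMS(gradient)
    v = -np.sqrt(V) / (np.sqrt(G) + eps) * dw
    w = w + v
    return w, G, V, v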

Adam

Maintains both an EMA of the squared gradient and an EMA of the unsquared gradient for each parameter.

\begin{aligned} w_{t+1} & = w_t + \frac{\hat{v}_t \, \eta}{\sqrt{\hat{G}_t} + \epsilon} \\ v_t & = \beta_1 v_{t-1} + (1 - \beta_1)\left(-\frac{\partial}{\partial w}\right) \\ \hat{v}_t & = v_t / (1 - \beta_1^{\,t+1}) \\ G_t & = \beta_2 G_{t-1} + (1 - \beta_2)\left(\frac{\partial}{\partial w}\right)^2 \\ \hat{G}_t & = G_t / (1 - \beta_2^{\,t+1}) \end{aligned}

With reasonable hyperparameters: \eta = 0.001 \\ \beta_1 = 0.9 \\ \beta_2 = 0.999 \\ \epsilon = 10^{-8}
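
A sketch of one Adam step, assuming t starts at 0 and v, G start as zeros; the bias correction compensates for that zero initialization. The function name and signature are mine:

import numpy as np

def adam_step(w, dw, v, G, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    # EMA of the negative gradient and of the squared gradient
    v = beta1 * v + (1 - beta1) * (-dw)
    G = beta2 * G + (1 - beta2) * dw ** 2
    # bias-corrected estimates
    v_hat = v / (1 - beta1 ** (t + 1))
    G_hat = G / (1 - beta2 ** (t + 1))
    w = w + lr * v_hat / (np.sqrt(G_hat) + eps)
    return w, v, G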

AdamW

Adam with decoupled weight decay: the decay shrinks the weights directly each step instead of being folded into the gradient as L2 regularization.
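
A sketch of the decoupled decay, reusing the adam_step sketch above; the weight_decay parameter and its default are assumptions:

def adamw_step(w, dw, v, G, t, lr=0.001, weight_decay=0.01, **adam_kwargs):
    # decay the weights directly, separate from the adaptive gradient step,
    # rather than adding an L2 term to the gradient
    w = w - lr * weight_decay * w
    return adam_step(w, dw, v, G, t, lr=lr, **adam_kwargs)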

AdaFactor

Doesn’t keep a first-moment (momentum) estimate, and stores only a factorized version of the second-moment estimate, so it takes much less VRAM to train.

\begin{aligned} \alpha_t & = \max(\epsilon_2, \text{RMS}(X_{t-1}))\, p_t \\ G_t & = \nabla f_t(X_{t-1}) \\ R_t & = \hat\beta_t R_{t-1} + (1 - \hat\beta_t)(G_t^2 + \epsilon_1 1_n 1_m^\top)1_m \\ C_t & = \hat\beta_t C_{t-1} + (1 - \hat\beta_t)1_n^\top(G_t^2 + \epsilon_1 1_n 1_m^\top) \\ V_t & = \frac{R_t C_t}{1_n^\top R_t} \\ U_t & = \frac{G_t}{\sqrt{V_t}} \\ \hat U_t & = \frac{U_t}{\max(1, \text{RMS}(U_t)/d)} \\ X_t & = X_{t-1} - \alpha_t \hat U_t \end{aligned}

Proposed hyperparameters: \begin{aligned} \epsilon_1 & = 10^{-30} \\ \epsilon_2 & = 10^{-3} \\ d & = 1 \\ p_t & = \min\left(10^{-2}, \frac{1}{\sqrt{t}}\right) \\ \hat\beta_t & = 1 - t^{-0.8} \end{aligned}
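
A rough sketch of the factored update for a single 2-D weight matrix X with gradient G, following the equations above; it assumes t starts at 1, and the function names (adafactor_step, rms) and keeping R, C as 1-D arrays are my choices:

import numpy as np

def rms(x):
    return np.sqrt(np.mean(x ** 2))

def adafactor_step(X, G, R, C, t, d=1.0, eps1=1e-30, eps2=1e-3):
    p_t = min(1e-2, 1 / np.sqrt(t))                  # relative step-size schedule
    beta_t = 1 - t ** -0.8                           # decay for the second-moment stats
    alpha_t = max(eps2, rms(X)) * p_t
    sq = G ** 2 + eps1
    R = beta_t * R + (1 - beta_t) * sq.sum(axis=1)   # per-row statistics
    C = beta_t * C + (1 - beta_t) * sq.sum(axis=0)   # per-column statistics
    V = np.outer(R, C) / R.sum()                     # rank-1 second-moment estimate
    U = G / np.sqrt(V)
    U_hat = U / max(1.0, rms(U) / d)                 # update clipping
    X = X - alpha_t * U_hat
    return X, R, C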