AdamW

Status: Public API in netcl.optim.adamw.AdamW (re-exported from netcl.optim)

AdamW (Loshchilov and Hutter, 2019) is a variant of Adam that decouples weight decay from the gradient update. With plain Adam, weight decay is usually implemented as an L2 penalty added to the gradient before the moment updates, so the penalty is rescaled by the adaptive per-parameter step size along with the gradient itself. AdamW instead applies the decay directly to the parameter, outside the adaptive ratio:

m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t ** 2
m_hat   = m_t / (1 - beta1 ** t)
v_hat   = v_t / (1 - beta2 ** t)
theta_t = theta_{t-1} - lr * (m_hat / (sqrt(v_hat) + eps)
                              + weight_decay * theta_{t-1})

The result is a regularizer whose decay term multiplies each weight by (1 - lr * weight_decay) per step, so the shrink rate depends on the learning rate and not on the gradient scale or the adaptive denominator sqrt(v_hat). In practice this often generalises better than the coupled L2 penalty, and it is the default in most modern training recipes (vision Transformers, language models, etc.).
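To make the difference concrete, here is a minimal pure-Python sketch of one update for a single scalar parameter, once with decoupled decay (the equations above) and once with a coupled L2 penalty. The function names adamw_step and adam_l2_step are hypothetical and are not part of netcl; the defaults in the signatures are illustrative only.

```python
import math


def adamw_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One decoupled (AdamW) update for a single scalar parameter."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # The decay sits outside the adaptive ratio, so it is scaled by lr only.
    theta = theta - lr * (m_hat / (math.sqrt(v_hat) + eps)
                          + weight_decay * theta)
    return theta, m, v


def adam_l2_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
                 eps=1e-8, weight_decay=0.01):
    """One coupled (Adam + L2 penalty) update, shown for comparison."""
    g = g + weight_decay * theta  # the penalty enters the gradient ...
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # ... and is then divided by sqrt(v_hat) like the rest of the gradient.
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# First step from theta = 1.0 with gradient 1.0 and lr = 0.1.
coupled_no_decay, _, _ = adam_l2_step(1.0, 1.0, 0.0, 0.0, t=1,
                                      lr=0.1, weight_decay=0.0)
coupled_decay, _, _ = adam_l2_step(1.0, 1.0, 0.0, 0.0, t=1,
                                   lr=0.1, weight_decay=0.1)
decoupled_decay, _, _ = adamw_step(1.0, 1.0, 0.0, 0.0, t=1,
                                   lr=0.1, weight_decay=0.1)

print(coupled_no_decay, coupled_decay, decoupled_decay)
```

On the first step the bias-corrected ratio m_hat / sqrt(v_hat) is approximately the sign of the gradient, so the coupled penalty is almost entirely absorbed by the normalisation: both coupled results land near 0.9. The decoupled update lands near 0.89, because the extra lr * weight_decay * theta term is applied regardless of how the gradient is scaled.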

Overview

AdamW is a drop-in replacement for Adam in the netcl API: the constructor signature is identical, the per-parameter state is allocated lazily on the first step, and the moment updates are unchanged. The only difference is the final line of step().

The two parameters that matter most in practice are lr and weight_decay. The remaining defaults, betas=(0.9, 0.999) and eps=1e-8, match the values recommended in the original Adam paper (Kingma and Ba).

Where It Lives

  • File path: optim/adamw.py.
  • Module path: netcl.optim.adamw.
  • Public re-export: from netcl.optim import AdamW.

How It Works

The implementation is a thin layer over Adam's _step_kernel. After the moment-based update theta -= lr * m_hat / (sqrt(v_hat) + eps), an extra line is emitted:

param[i] -= lr * weight_decay * param[i];

This is fused with the rest of the parameter update into a single kernel per parameter. There is no additional state.
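The behaviour described here and in the Overview (lazy state allocation, unchanged moment updates, one extra decay line) can be sketched in pure Python as follows. TinyAdamW is a hypothetical stand-in written for illustration: it is not the netcl class, it operates on plain lists of floats rather than device buffers, and its weight_decay default is an assumption.

```python
import math


class TinyAdamW:
    """Hypothetical pure-Python stand-in; not the netcl implementation."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0.01):
        self.params = params   # list of lists of floats, updated in place
        self.lr, self.betas, self.eps = lr, betas, eps
        self.weight_decay = weight_decay
        self.state = {}        # filled lazily on the first step()

    def step(self, grads):
        beta1, beta2 = self.betas
        for idx, (param, grad) in enumerate(zip(self.params, grads)):
            if idx not in self.state:
                # Lazy allocation: step count plus m and v, same as Adam.
                self.state[idx] = {"t": 0,
                                   "m": [0.0] * len(param),
                                   "v": [0.0] * len(param)}
            st = self.state[idx]
            st["t"] += 1
            c1 = 1 - beta1 ** st["t"]
            c2 = 1 - beta2 ** st["t"]
            for i, g in enumerate(grad):
                st["m"][i] = beta1 * st["m"][i] + (1 - beta1) * g
                st["v"][i] = beta2 * st["v"][i] + (1 - beta2) * g * g
                m_hat = st["m"][i] / c1
                v_hat = st["v"][i] / c2
                # Moment-based update, identical to plain Adam.
                param[i] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)
                # The extra AdamW line, applied to the already-updated value.
                param[i] -= self.lr * self.weight_decay * param[i]


weights = [[1.0, -2.0]]
opt = TinyAdamW(weights, lr=0.1, weight_decay=0.1)
state_entries_before = len(opt.state)   # nothing allocated yet
opt.step([[0.0, 0.0]])                  # zero gradient: only decay acts
state_entries_after = len(opt.state)
print(weights)
```

Because the decay line runs after the moment-based update, it acts on the already-updated parameter rather than on theta_{t-1}. Compared with the single-expression form of the update equation, the result differs by a term of order lr^2 * weight_decay, which is negligible at typical learning rates.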

Code Example

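import netcl.optim as opt

# `model` is assumed to be an existing netcl model that exposes parameters().
optimizer = opt.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),   # common LLM setting; the default is (0.9, 0.999)
    eps=1e-8,
    weight_decay=0.1,    # strong but not crushing
)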

Performance & Trade-offs

  • Cost is effectively identical to Adam: one extra fused multiply-subtract per parameter element, and no additional optimizer state.
  • The right weight_decay value depends on the model family. Vision CNNs are typically fine with 1e-4; vision Transformers prefer 0.1; language models prefer 0.1 to 0.5.
  • For very small datasets, the decoupled decay can over-shrink weights. In that case, either lower weight_decay or switch back to Adam with a coupled penalty.
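As a rough aid for choosing weight_decay, the cumulative shrinkage from the decay term alone can be estimated directly from the update rule. The sketch below assumes a constant learning rate and ignores the gradient term entirely, so it gives only the decay's own contribution; decay_factor is a hypothetical helper, not a netcl function.

```python
def decay_factor(lr, weight_decay, steps):
    """Fraction of a weight left after `steps` updates with zero gradient."""
    return (1.0 - lr * weight_decay) ** steps


# Compare the three weight_decay regimes mentioned above at lr = 3e-4.
for wd in (1e-4, 0.1, 0.5):
    print(wd, decay_factor(lr=3e-4, weight_decay=wd, steps=10_000))
```

With lr=3e-4 and weight_decay=0.1, about 74% of a weight's magnitude survives 10,000 steps of decay alone, and with weight_decay=0.5 only about 22% does. On a small dataset with weak gradient signal, the larger setting can dominate the update, which is the over-shrinking described above.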

See also

  • Adam — the moment-estimation algorithm.
  • AdamW — the API page.
  • SGD — when you want a non-adaptive optimizer.
  • AMP — keep the optimizer state in fp32.