AMP

Status: Public API in netcl.amp (autocast, GradScaler, supports_fp16)

AMP (Automatic Mixed Precision) is the netcl module that makes half-precision training safe. It combines two mechanisms:

  • autocast — a context manager that downcasts fp32 tensors to fp16 for the forward and backward passes, while keeping the master weights in fp32.
  • GradScaler — a dynamic loss-scaling wrapper that prevents fp16 underflow during the backward pass. The scaler multiplies the loss by a large factor (init_scale = 2**16 by default), un-scales the gradients before the optimizer step, and grows or shrinks the scale based on whether the previous step saw an infinity or NaN.

The result is a training loop that runs in half-precision on the matmul / conv pipelines (where the speed-up is real) but maintains fp32 numerics for the parts where precision matters (the optimizer state, the batch-norm running statistics, the final loss).
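To make the underflow problem concrete, the sketch below round-trips a small gradient value through fp16 using Python's standard struct module (format code "e" is IEEE half precision). The helper name to_fp16 and the sample value 1e-8 are illustrative choices, not part of netcl.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

SCALE = 2.0 ** 16   # the default init_scale quoted above
grad = 1e-8         # illustrative gradient, below fp16's smallest subnormal (2**-24)

unscaled = to_fp16(grad)          # underflows to zero in fp16
scaled = to_fp16(grad * SCALE)    # the scaled value survives the fp16 round-trip
recovered = scaled / SCALE        # un-scaling is done in higher precision

print(unscaled)    # 0.0
print(recovered)   # close to 1e-8, up to fp16 rounding of the scaled value
```

Without scaling, the gradient is lost entirely; with scaling, it comes back with a relative error bounded by fp16's rounding step.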

Overview

The full AMP recipe, per step, is:

  1. Cast inputs to fp16 (in the autocast context).
  2. Run the forward in fp16. The master weights stay in fp32; inside the autocast region they are cast on the fly.
  3. Compute the loss in fp32.
  4. scaler.scale_loss(loss).backward(): multiplies the loss by the current scale factor and runs the backward in fp16.
  5. scaler.unscale_grads(params): checks the gradients for infinities or NaNs and, if all are finite, divides them by the scale factor.
  6. optimizer.step(): only called if no inf / NaN was found.
  7. Scale update: the scale grows after growth_interval consecutive steps without an inf / NaN and shrinks whenever one is found. In netcl, scaler.step(optimizer, params) performs steps 5 to 7 in one call; scaler.update() is a no-op kept for API compatibility.
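The control flow of the recipe can be sketched without netcl at all. The toy below trains a single scalar weight on one example, simulating fp16 with the standard struct module; the names to_fp16 and train, the learning rate, and the data are all hypothetical, and the growth and backoff factors simply mirror the defaults described on this page.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest fp16 value; out-of-range values become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def train(steps, scale=2.0 ** 16, lr=0.1, growth_interval=2000):
    w = 0.0                  # master weight, kept in full precision
    x, y = 2.0, 6.0          # one fixed example; the ideal weight is 3.0
    good_steps = skipped = 0
    for _ in range(steps):
        pred = to_fp16(to_fp16(w) * x)             # steps 1-2: fp16 forward
        dloss = pred - y                           # step 3: grad of 0.5*(pred-y)**2
        scaled_grad = to_fp16(scale * dloss * x)   # step 4: scaled fp16 backward
        if not math.isfinite(scaled_grad):         # step 5: inf / NaN check
            scale *= 0.5                           # step 7: back off
            good_steps = 0
            skipped += 1
            continue                               # step 6 is skipped
        w -= lr * (scaled_grad / scale)            # steps 5-6: unscale, then update
        good_steps += 1
        if good_steps >= growth_interval:          # step 7: grow after a clean run
            scale *= 2.0
            good_steps = 0
    return w, scale, skipped

w, scale, skipped = train(60)
print(w, scale, skipped)
```

With these toy inputs the first four steps overflow, because the scaled gradient (12 times the scale) exceeds the fp16 maximum of 65504 until the scale has been halved to 4096. After that the weight converges to roughly 3.0.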

netcl's AMP module is in amp.py and is small — fewer than 200 lines. The heavy lifting is in the per-op autocast dispatch: each op knows whether it has a half-precision kernel and falls back to fp32 if not.
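The per-op fallback described above might look something like the following. This is a minimal sketch under assumptions: the registry, the register / dispatch / set_autocast names, and the string-returning stand-in kernels are hypothetical and do not reflect netcl's actual internals.

```python
from typing import Callable, Dict, Optional, Tuple

# Hypothetical registry: op name -> (fp32 kernel, optional fp16 kernel).
_KERNELS: Dict[str, Tuple[Callable, Optional[Callable]]] = {}
_AUTOCAST_ENABLED = False

def set_autocast(flag: bool) -> None:
    """Stand-in for what entering or leaving an autocast region would do."""
    global _AUTOCAST_ENABLED
    _AUTOCAST_ENABLED = flag

def register(name: str, fp32_kernel: Callable,
             fp16_kernel: Optional[Callable] = None) -> None:
    _KERNELS[name] = (fp32_kernel, fp16_kernel)

def dispatch(name: str, *args):
    """Use the fp16 kernel only if autocast is on and the op provides one."""
    fp32_kernel, fp16_kernel = _KERNELS[name]
    if _AUTOCAST_ENABLED and fp16_kernel is not None:
        return fp16_kernel(*args)
    return fp32_kernel(*args)

# Stand-in kernels that only report which precision ran.
register("matmul", lambda: "fp32", lambda: "fp16")
register("softmax", lambda: "fp32")   # no half-precision kernel: always fp32
```

The point of the design is that callers never choose a precision themselves; an op without a half-precision kernel silently stays in fp32 even inside an autocast region.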

Where It Lives

  • File path: amp.py.
  • Module path: netcl.amp.
  • Public re-export: from netcl.amp import autocast, GradScaler, supports_fp16.

How It Works

autocast

class autocast:
    def __init__(self, enabled=True, device_queue=None):
        self.enabled = enabled and supports_fp16(device_queue)
        ...
    def __enter__(self):
        global _AUTOCAST_ENABLED
        _AUTOCAST_ENABLED = self.enabled
    def __exit__(self, *args):
        global _AUTOCAST_ENABLED
        _AUTOCAST_ENABLED = False

The autocast context sets a module-global flag. Each op in netcl checks the flag in its dispatch; if autocast is on and the op has a half-precision kernel, the kernel is used. Otherwise the fp32 kernel runs.
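Note that the excerpt's __exit__ resets the flag to False unconditionally, so leaving an inner autocast region would also switch off an enclosing one. If nesting matters, one possible variant saves the previous value on entry and restores it on exit. The sketch below is an illustration of that pattern, not netcl's implementation; the class name nested_autocast and the helper is_autocast_enabled are hypothetical, and the supports_fp16 check is omitted.

```python
_AUTOCAST_ENABLED = False

class nested_autocast:
    """Context manager that restores the previous flag value on exit."""

    def __init__(self, enabled: bool = True):
        self.enabled = enabled
        self._previous = False

    def __enter__(self):
        global _AUTOCAST_ENABLED
        self._previous = _AUTOCAST_ENABLED   # remember the enclosing state
        _AUTOCAST_ENABLED = self.enabled
        return self

    def __exit__(self, *exc_info):
        global _AUTOCAST_ENABLED
        _AUTOCAST_ENABLED = self._previous   # restore rather than force False
        return False                         # do not swallow exceptions

def is_autocast_enabled() -> bool:
    return _AUTOCAST_ENABLED
```

With this variant, an inner nested_autocast(False) block can temporarily force fp32 for a sensitive op, and the outer region resumes half precision afterwards.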

GradScaler

@dataclass
class GradScaler:
    init_scale: float = 2.0**16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000
    enabled: bool = True

scale_loss(loss) returns loss * scale (as a new tensor; elementwise_binary is used for the multiplication so the result is a real device op, not a host-side multiply).

unscale_grads(params) walks the parameters, pulls each gradient back to the host, checks for infinities or NaNs with np.any(~np.isfinite(g)), and if everything is finite, divides each gradient by the scale factor (as a device op).
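The check-then-divide order matters: if any gradient is non-finite, nothing is unscaled and the step will be skipped. A host-only sketch of that logic follows, using plain lists and math.isfinite as a stand-in for the NumPy check; the function name unscale_grads_sketch and its return convention are assumptions for illustration.

```python
import math
from typing import List, Tuple

def unscale_grads_sketch(grads: List[List[float]],
                         scale: float) -> Tuple[List[List[float]], bool]:
    """Return (gradients, found_inf).

    Gradients are divided by scale only when every value is finite;
    otherwise they are returned untouched and found_inf is True.
    """
    found_inf = any(not math.isfinite(g) for grad in grads for g in grad)
    if found_inf:
        return grads, True
    return [[g / scale for g in grad] for grad in grads], False

clean, found = unscale_grads_sketch([[65536.0, 131072.0]], 2.0 ** 16)
print(clean, found)   # [[1.0, 2.0]] False
```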

step(optimizer, params) is the convenience wrapper: unscale_grads first; if no inf / NaN was found, call optimizer.step() and grow the scale after growth_interval successful steps; otherwise skip the step and shrink the scale.

The update() method is a no-op kept for API compatibility with PyTorch's GradScaler; the actual scale update happens inside step().
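The grow / shrink behaviour inside step() can be pictured as a small state machine. The sketch below reuses the default values from the dataclass above; the class name ScaleSchedule, the record method, and the choice to reset the success counter after an overflow are assumptions, not a copy of netcl's code.

```python
from dataclasses import dataclass

@dataclass
class ScaleSchedule:
    scale: float = 2.0 ** 16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000
    _good_steps: int = 0   # consecutive steps without inf / NaN

    def record(self, found_inf: bool) -> bool:
        """Update the scale; return True if the optimizer step should run."""
        if found_inf:
            self.scale *= self.backoff_factor   # shrink and skip this step
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            self.scale *= self.growth_factor    # grow after a clean run
            self._good_steps = 0
        return True
```

With the defaults, the scale doubles at most once every 2000 clean steps but halves immediately on any overflow, so it backs off quickly and recovers slowly.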

Code Example

A typical AMP training step:

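import netcl as nc
import netcl.optim as opt
from netcl.amp import autocast, GradScaler

# model, dataloader, and queue (the device queue) are assumed to be defined elsewhere.
scaler = GradScaler()
optimizer = opt.AdamW(model.parameters(), lr=3e-4)

for x, y in dataloader:
    optimizer.zero_grad()
    with autocast(device_queue=queue):
        logits = model(x)              # fp16 forward where half-precision kernels exist
        loss = nc.functional.cross_entropy(logits, y)
    scaler.scale_loss(loss).backward()  # scaled backward
    # Unscales, checks for inf / NaN, steps the optimizer if finite, and updates the scale.
    scaler.step(optimizer, model.parameters())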

Detecting whether the device supports AMP at all:

from netcl.amp import supports_fp16
if not supports_fp16(queue):
    print("device does not expose cl_khr_fp16; running in fp32")

Performance & Trade-offs

  • AMP is primarily a correctness wrapper, not a performance wrapper. It aims for results close to fp32 training, up to fp16 rounding in the forward pass and the noise of the loss scaling; this is not a bit-for-bit guarantee. You may see a 1.5x to 2x speed-up on the matmul-heavy forward and a smaller speed-up elsewhere.
  • The default init_scale = 2**16 is a good starting point. If you see found_inf=True on every step, lower it; if you never see it and the loss is in the well-represented range, you can raise it to 2**20 so that smaller gradients stay representable in fp16.
  • Under AMP, the BatchNorm running statistics stay in fp32 even when the activations are fp16. Do not change this; the running statistics accumulate small per-step deltas that fp16 cannot represent.
  • AMP on its own assumes a single GPU. For multi-replica training, combine DistributedDataParallel with AMP: each replica runs its own GradScaler, and the gradients are all-reduced after unscale_grads (so the inf-check is local to each replica).
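Two of the points above can be checked numerically with the standard library alone. The sketch below estimates the range of gradient magnitudes that survive fp16 for a given scale, and shows what happens when small per-step deltas are accumulated in fp16 versus full precision. All names (to_fp16, gradient_window, accumulate) and the sample delta of 1e-4 are illustrative, not part of netcl.

```python
import struct

def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

FP16_MAX = 65504.0                 # largest finite fp16 value
FP16_MIN_SUBNORMAL = 2.0 ** -24    # smallest positive fp16 value

def gradient_window(scale: float):
    """Approximate (smallest, largest) gradient magnitude representable after scaling."""
    return FP16_MIN_SUBNORMAL / scale, FP16_MAX / scale

def accumulate(delta: float, steps: int, half: bool) -> float:
    """Add delta to a running value of 1.0, optionally rounding to fp16 each step."""
    running = 1.0
    for _ in range(steps):
        running = running + delta
        if half:
            running = to_fp16(running)
    return running

print(gradient_window(2.0 ** 16))   # upper bound just under 1.0
print(gradient_window(2.0 ** 20))   # 16x smaller on both ends
print(accumulate(1e-4, 1000, half=True))    # stays at 1.0: every delta is rounded away
print(accumulate(1e-4, 1000, half=False))   # about 1.1
```

Raising the scale shifts the whole window toward smaller gradients, which is why a larger init_scale helps against underflow but overflows sooner. The accumulation result illustrates why the BatchNorm running statistics are kept in fp32.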

See also

  • AMP — the API page.
  • cl_khr_fp16 — the OpenCL extension AMP depends on.
  • fp16 — the format AMP casts to.
  • GradScaler — the dynamic scaling wrapper.
  • Adam — keep the optimizer state in fp32.
  • BatchNorm — keep the running statistics in fp32.