AMP
Status: Public API in netcl.amp (autocast, GradScaler, supports_fp16).
AMP (Automatic Mixed Precision) is the netcl module that makes
half-precision training safe. It combines two mechanisms:
- autocast: a context manager that downcasts fp32 tensors to fp16 for the forward and backward passes while keeping the master weights in fp32.
- GradScaler: a dynamic loss-scaling wrapper that prevents fp16 underflow during the backward pass. The scaler multiplies the loss by a large factor (init_scale = 2**16 by default), unscales the gradients before the optimizer step, and grows or shrinks the scale depending on whether the previous step produced an infinity or NaN.
The result is a training loop that runs in half-precision on the matmul / conv pipelines (where the speed-up is real) but maintains fp32 numerics for the parts where precision matters (the optimizer state, the batch-norm running statistics, the final loss).
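GradScaler guards against underflow. The NumPy snippet below reproduces that underflow for a single small gradient value. It is a standalone illustration, not netcl code. The value 1e-8 is arbitrary. Multiplying the gradient directly stands in for scaling the loss: by the chain rule, scaling the loss scales every gradient by the same factor.

```python
import numpy as np

# A small gradient value, of the kind deep layers can produce.
g = 1e-8

# Cast straight to fp16: the value is below half the smallest fp16
# subnormal (2**-24), so it rounds to zero and the update is lost.
direct = np.float16(g)

# Loss scaling: multiply by the scale before the fp16 cast,
# then divide in fp32 afterwards.
scale = 2.0**16
scaled = np.float16(g * scale)            # about 6.55e-4, a normal fp16 value
recovered = np.float32(scaled) / np.float32(scale)

print(direct)      # 0.0
print(recovered)   # close to 1e-8, within fp16 rounding error
```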
Overview
The full AMP recipe, per step, is:
- Cast inputs to fp16 (in the autocast context).
- Run the forward in fp16; keep the weights in fp32 outside the autocast region, or let the autocast context cast on the fly.
- Compute the loss in fp32.
- scaler.scale_loss(loss).backward(): multiply the loss by the current scale factor and run the backward in fp16.
- scaler.unscale_grads(params): divide the gradients by the scale factor and detect infinities or NaNs.
- optimizer.step(): call only if no inf / NaN was found.
- scaler.update(): grow the scale after growth_interval consecutive steps without an inf / NaN, and shrink it as soon as one is found. In netcl, scaler.step(optimizer, params) bundles the last three steps, and update() itself is a no-op (see GradScaler below).
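The recipe above can be sketched end to end in NumPy for a toy linear model with hand-written gradients. This is a hypothetical illustration, not netcl's implementation. amp_step and its arguments are invented for the example, NumPy's float16 stands in for device fp16 kernels, and the scale is held fixed, so step 7 (the scale update) is left out.

```python
import numpy as np


def amp_step(w32, x, y, scale, lr=0.1):
    """One AMP-style step for a toy linear model pred = x @ w (sketch).

    w32 is the fp32 master copy of the weights.
    Returns (new_weights, found_inf, loss).
    """
    # Steps 1-2: cast inputs and a working copy of the weights to fp16,
    # then run the forward in fp16.
    x16 = x.astype(np.float16)
    w16 = w32.astype(np.float16)
    pred16 = x16 @ w16

    # Step 3: mean squared error, computed in fp32.
    err32 = pred16.astype(np.float32) - y.astype(np.float32)
    loss = float(np.mean(err32 ** 2))

    # Step 4: scaled backward in fp16; d(scale * loss)/d(pred) = scale * 2 * err / n.
    # Overflow to inf is expected at large scales and is caught in step 5.
    with np.errstate(over="ignore", invalid="ignore"):
        dpred16 = (np.float32(scale) * 2.0 * err32 / len(y)).astype(np.float16)
        grad16 = x16.T @ dpred16

    # Step 5: unscale in fp32 and look for inf / NaN.
    grad32 = grad16.astype(np.float32) / np.float32(scale)
    found_inf = not bool(np.all(np.isfinite(grad32)))

    # Step 6: update the fp32 master weights only if the gradients are finite.
    if not found_inf:
        w32 = w32 - np.float32(lr) * grad32
    return w32, found_inf, loss


rng = np.random.default_rng(0)
x = rng.standard_normal((64, 3)).astype(np.float32)
w_true = np.array([1.0, -2.0, 0.5], dtype=np.float32)
y = x @ w_true

w = np.zeros(3, dtype=np.float32)
for _ in range(100):
    w, found_inf, loss = amp_step(w, x, y, scale=2.0**10)
```

With scale=2.0**10 the toy model converges to w_true. Starting from zero weights, a scale of 2.0**20 overflows the fp16 backward on the first step. In that case found_inf comes back True and the weights are returned unchanged. The dynamic scale update exists to handle exactly this case.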
netcl's AMP module is in amp.py and is small (fewer than 200
lines). The heavy lifting is in the per-op autocast dispatch: each
op knows whether it has a half-precision kernel and falls back to
fp32 if not.
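The per-op dispatch with fp32 fallback can be modeled with a kernel table. Everything here is hypothetical: the KERNELS table, dispatch, set_autocast and the NumPy kernels stand in for netcl's real ops and OpenCL kernels, which this page does not show.

```python
import numpy as np

# Hypothetical stand-in for netcl's module-global autocast flag.
_AUTOCAST_ENABLED = False


def set_autocast(enabled):
    """Mimics what autocast.__enter__ / __exit__ do to the flag."""
    global _AUTOCAST_ENABLED
    _AUTOCAST_ENABLED = enabled


def _matmul_fp32(a, b):
    return a.astype(np.float32) @ b.astype(np.float32)


def _matmul_fp16(a, b):
    return a.astype(np.float16) @ b.astype(np.float16)


def _softmax_fp32(x):
    x = x.astype(np.float32)
    e = np.exp(x - x.max(axis=-1, keepdims=True))  # subtract max for stability
    return e / e.sum(axis=-1, keepdims=True)


# Each op registers an fp32 kernel and, if it has one, a half-precision kernel.
KERNELS = {
    "matmul": {"fp32": _matmul_fp32, "fp16": _matmul_fp16},
    "softmax": {"fp32": _softmax_fp32},  # no fp16 kernel: always runs in fp32
}


def dispatch(op, *args):
    kernels = KERNELS[op]
    if _AUTOCAST_ENABLED and "fp16" in kernels:
        return kernels["fp16"](*args)
    return kernels["fp32"](*args)  # autocast off, or no half kernel
```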
Where It Lives
- File path: amp.py.
- Module path: netcl.amp.
- Public re-export: from netcl.amp import autocast, GradScaler, supports_fp16.
How It Works
autocast
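```python
class autocast:
    def __init__(self, enabled=True, device_queue=None):
        # Autocast is only enabled if the device reports fp16 support.
        self.enabled = enabled and supports_fp16(device_queue)
        ...

    def __enter__(self):
        global _AUTOCAST_ENABLED
        _AUTOCAST_ENABLED = self.enabled  # read by every op's dispatch

    def __exit__(self, *args):
        global _AUTOCAST_ENABLED
        # Resets to False unconditionally, not to the previous value.
        _AUTOCAST_ENABLED = False
```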
The autocast context sets a module-global flag. Each op in netcl checks the flag in its dispatch; if autocast is on and the op has a half-precision kernel, the kernel is used. Otherwise the fp32 kernel runs.
GradScaler
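```python
from dataclasses import dataclass


@dataclass
class GradScaler:
    init_scale: float = 2.0**16   # initial loss-scale factor
    growth_factor: float = 2.0    # scale multiplier after growth_interval clean steps
    backoff_factor: float = 0.5   # scale multiplier when an inf / NaN is found
    growth_interval: int = 2000   # consecutive clean steps before growing
    enabled: bool = True
```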
scale_loss(loss) returns loss * scale (as a new tensor;
elementwise_binary is used for the multiplication so the result
is a real device op, not a host-side multiply).
unscale_grads(params) walks the parameters, pulls each gradient
back to the host, checks for infinities or NaNs with
np.any(~np.isfinite(g)), and if everything is finite, divides
each gradient by the scale factor (as a device op).
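The inf / NaN check and unscale described above can be sketched on host NumPy arrays. The sketch assumes gradients arrive as a list of arrays; the function name and return convention are illustrative. Unlike netcl, which divides on the device, the sketch divides on the host and returns fp32 copies.

```python
import numpy as np


def unscale_grads(grads, scale):
    """Check host-side gradients for inf / NaN, then unscale them (sketch).

    Returns (found_inf, grads). When any value is non-finite the gradients
    are returned untouched, because the optimizer step will be skipped.
    """
    # The same per-gradient test the page describes: np.any(~np.isfinite(g)).
    found_inf = any(bool(np.any(~np.isfinite(g))) for g in grads)
    if found_inf:
        return True, grads
    # Divide in fp32: dividing in fp16 could push small gradients back
    # below fp16's underflow threshold.
    return False, [g.astype(np.float32) / np.float32(scale) for g in grads]
```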
step(optimizer, params) is the convenience wrapper: it calls
unscale_grads first; if no inf / NaN was found, it calls
optimizer.step() and grows the scale after growth_interval
consecutive successful steps; otherwise it skips the step and shrinks the scale.
The update() method is a no-op kept for API compatibility with
PyTorch's GradScaler; the actual scale update happens inside
step().
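The scale bookkeeping that step() performs can be modeled as a small state machine. This is a hypothetical host-side model: ScaleState and its good_steps counter are invented names. Resetting the counter after each growth, as PyTorch's GradScaler does, is an assumption about netcl.

```python
from dataclasses import dataclass


@dataclass
class ScaleState:
    """Hypothetical model of GradScaler's dynamic scale update."""
    scale: float = 2.0**16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000
    good_steps: int = 0  # consecutive steps without inf / NaN

    def update(self, found_inf: bool) -> bool:
        """Record one step's outcome; return True if the optimizer step should run."""
        if found_inf:
            self.scale *= self.backoff_factor  # shrink immediately
            self.good_steps = 0
            return False                       # skip this optimizer step
        self.good_steps += 1
        if self.good_steps >= self.growth_interval:
            self.scale *= self.growth_factor   # grow after a clean streak
            self.good_steps = 0                # assumption: counter restarts
        return True
```

In netcl this logic runs inside step(), which is why update() can stay a no-op.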
Code Example
A typical AMP training step:
```python
import netcl as nc
import netcl.optim as opt
from netcl.amp import autocast, GradScaler

# model, dataloader and queue (the device command queue) are assumed
# to be defined elsewhere.
scaler = GradScaler()
optimizer = opt.AdamW(model.parameters(), lr=3e-4)

for x, y in dataloader:
    optimizer.zero_grad()
    with autocast(device_queue=queue):
        logits = model(x)  # fp16 forward
        loss = nc.functional.cross_entropy(logits, y)
    scaler.scale_loss(loss).backward()  # scaled backward
    # Unscale, check for inf / NaN, step or skip, and update the scale.
    scaler.step(optimizer, model.parameters())
```
Detecting whether the device supports AMP at all:
```python
from netcl.amp import supports_fp16

if not supports_fp16(queue):
    print("device does not expose cl_khr_fp16; running in fp32")
```
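The sketch below shows the kind of check supports_fp16 might perform, assuming it inspects the device's OpenCL extension string; netcl's actual implementation is not shown on this page. OpenCL reports device extensions as a single space-separated string, so the sketch compares whole tokens. The PyOpenCL lines in the trailing comment are an assumed way to obtain that string and are not executed.

```python
def has_fp16_extension(extensions: str) -> bool:
    """Return True if an OpenCL extension string lists cl_khr_fp16.

    The string is split into tokens so that only an exact
    extension name counts as a match.
    """
    return "cl_khr_fp16" in extensions.split()


# With PyOpenCL (assumed installed), the string comes from the device:
#   import pyopencl as cl
#   ctx = cl.create_some_context()
#   queue = cl.CommandQueue(ctx)
#   has_fp16_extension(queue.device.extensions)
```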
Performance & Trade-offs
- AMP is a correctness wrapper first: its job is to make fp16 training safe, and any speed-up comes from the fp16 kernels it enables. Expect results that track fp32 training closely but are not bit-identical, since fp16 rounds intermediate values. You may see a 1.5x to 2x speed-up on the matmul-heavy forward and a smaller speed-up elsewhere.
- The default init_scale = 2**16 is a good starting point. If you see found_inf=True on every step, lower it. If you never see it and the loss is in the well-represented range, raise it to 2**20 so that smaller gradients stay above the fp16 underflow threshold.
- Under AMP, the BatchNorm running statistics stay in fp32 even when the activations are fp16. Do not change this: the running statistics accumulate small per-step deltas that fp16 cannot represent.
- AMP assumes a single GPU. For multi-replica training, use DistributedDataParallel with AMP; each replica runs its own GradScaler, and the gradients are all-reduced after unscale_grads (so the inf-check is local to each replica).
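The init_scale advice and the BatchNorm note both follow from fp16's limits: the largest finite value is 65504 and the smallest positive one is 2**-24. The sketch below uses a simplified model that treats every gradient element as passing through fp16 once after scaling; gradient_window is a hypothetical helper. Raising the scale from 2**16 to 2**20 rescues smaller gradients but lowers the overflow ceiling from just under 1.0 to about 0.06. The sketch also shows that a 1e-4 update to a value of 1.0 is lost in fp16 but kept in fp32.

```python
import numpy as np

FP16_MAX = float(np.finfo(np.float16).max)  # largest finite fp16 value
FP16_TINY = 2.0**-24                         # smallest positive (subnormal) fp16


def gradient_window(scale):
    """Unscaled gradient magnitudes that survive an fp16 backward at `scale`.

    Simplified model: below the low end a scaled gradient is lost to
    underflow; above the high end it overflows to inf.
    """
    return FP16_TINY / scale, FP16_MAX / scale


for exponent in (16, 20):
    low, high = gradient_window(2.0**exponent)
    print(f"scale=2**{exponent}: {low:.3g} to {high:.3g}")

# BatchNorm running statistics: a small per-step delta vanishes in fp16.
mean16 = np.float16(1.0) + np.float16(1e-4)  # rounds back to 1.0
mean32 = np.float32(1.0) + np.float32(1e-4)  # keeps the delta
```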
See also
- AMP — the API page.
- cl_khr_fp16 — the OpenCL extension AMP depends on.
- fp16 — the format AMP casts to.
- GradScaler — the dynamic scaling wrapper.
- Adam — keep the optimizer state in fp32.
- BatchNorm — keep the running statistics in fp32.