BatchNorm
Status:
BatchNorm2d is the public export (from netcl.nn import BatchNorm2d). BatchNorm1d is in the internal module but is not re-exported from netcl.nn.
Batch Normalization (BatchNorm) normalizes a layer's activations
across the batch dimension to have zero mean and unit variance, then
applies a learnable per-channel scale and shift. It was introduced
by Ioffe and Szegedy (2015) and dramatically stabilises and accelerates
the training of deep networks.
BatchNorm has two modes:
- Training mode — uses the current batch's mean and variance, and updates a running estimate of the population statistics.
- Evaluation mode — uses the running mean and variance accumulated during training. This is what you want at inference time; using batch statistics at inference makes the model's output depend on the batch composition, which is almost never what you want.
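The batch-composition effect described above can be illustrated with a minimal NumPy sketch. It does not use netcl; the `normalize` helper and the running-statistics values are made up for illustration. The same sample is placed in two different batches and normalized both ways.

```python
import numpy as np

def normalize(x, mean, var, eps=1e-5):
    # Plain normalization; gamma and beta are omitted to keep the sketch short.
    return (x - mean) / np.sqrt(var + eps)

sample = np.array([[1.0, 2.0]])                 # one sample, two features
batch_a = np.vstack([sample, [[3.0, 4.0]]])     # sample plus one neighbour
batch_b = np.vstack([sample, [[-5.0, 10.0]]])   # same sample, different neighbour

# Training mode: statistics come from whichever batch the sample is in.
train_a = normalize(batch_a, batch_a.mean(axis=0), batch_a.var(axis=0))[0]
train_b = normalize(batch_b, batch_b.mean(axis=0), batch_b.var(axis=0))[0]

# Evaluation mode: fixed running statistics (illustrative values).
running_mean = np.array([0.5, 1.5])
running_var = np.array([2.0, 3.0])
eval_a = normalize(batch_a, running_mean, running_var)[0]
eval_b = normalize(batch_b, running_mean, running_var)[0]

print("train:", train_a, train_b)   # differ between the two batches
print("eval: ", eval_a, eval_b)     # identical
```

In training mode the sample's output changes with its neighbours; with fixed running statistics it does not.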
Overview
The forward pass in training mode is:
mu_B = mean(x, axis=batch)
var_B = var(x, axis=batch)
x_hat = (x - mu_B) / sqrt(var_B + eps)
y = gamma * x_hat + beta
running_mean = (1 - momentum) * running_mean + momentum * mu_B
running_var = (1 - momentum) * running_var + momentum * var_B
In evaluation mode:
y = gamma * (x - running_mean) / sqrt(running_var + eps) + beta
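As a reference for the formulas above, here is a minimal NumPy sketch for (N, C) input. It is not the netcl implementation: the function names `batchnorm_train` and `batchnorm_eval` are hypothetical, and the sketch assumes the population (biased) variance is used everywhere, which the text above does not specify.

```python
import numpy as np

def batchnorm_train(x, gamma, beta, running_mean, running_var,
                    eps=1e-5, momentum=0.1):
    """Training-mode forward pass for input of shape (N, C)."""
    mu_b = x.mean(axis=0)
    var_b = x.var(axis=0)                      # population variance (assumption)
    x_hat = (x - mu_b) / np.sqrt(var_b + eps)
    y = gamma * x_hat + beta
    # Exponential moving average of the batch statistics.
    new_mean = (1 - momentum) * running_mean + momentum * mu_b
    new_var = (1 - momentum) * running_var + momentum * var_b
    return y, new_mean, new_var

def batchnorm_eval(x, gamma, beta, running_mean, running_var, eps=1e-5):
    """Evaluation-mode forward pass: uses the running statistics only."""
    return gamma * (x - running_mean) / np.sqrt(running_var + eps) + beta

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(32, 4))
gamma, beta = np.ones(4), np.zeros(4)
running_mean, running_var = np.zeros(4), np.ones(4)

y, running_mean, running_var = batchnorm_train(
    x, gamma, beta, running_mean, running_var)
y_eval = batchnorm_eval(x, gamma, beta, running_mean, running_var)
```

After one training step `y` has per-channel mean close to 0 and variance close to 1, while the running statistics have moved only a fraction `momentum` of the way toward the batch statistics.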
netcl implements BatchNorm1d and BatchNorm2d. BatchNorm1d
expects input of shape (N, C) or (N, C, L). BatchNorm2d
expects input of shape (N, C, H, W). Both share the same
underlying kernel.
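In the usual definition of BatchNorm, the statistics for the shapes above are taken per channel, reducing over every axis except the channel axis. The NumPy sketch below assumes netcl follows that convention; the `channel_stats` helper is hypothetical. It also shows that a (N, C, H, W) input can be viewed as a (N*H*W, C) input, which is one plausible reason a single kernel can serve both layers.

```python
import numpy as np

def channel_stats(x):
    """Per-channel mean and variance, reducing over every axis except axis 1."""
    axes = tuple(i for i in range(x.ndim) if i != 1)
    return x.mean(axis=axes), x.var(axis=axes)

rng = np.random.default_rng(0)

# The three input layouts named above all yield statistics of shape (C,).
stats_shapes = {}
for shape in [(8, 16), (8, 16, 10), (8, 16, 5, 5)]:
    mean, var = channel_stats(rng.normal(size=shape))
    stats_shapes[shape] = (mean.shape, var.shape)

# A 4-D input gives the same statistics as its flattened (N*H*W, C) view.
x4 = rng.normal(size=(8, 16, 5, 5))
flat = x4.transpose(0, 2, 3, 1).reshape(-1, 16)
mean4, var4 = channel_stats(x4)
mean_flat, var_flat = channel_stats(flat)
```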
Where It Lives
- File path: nn/batchnorm.py (class BatchNorm2d and class BatchNorm1d).
- Module path: netcl.nn.batchnorm.
- Public re-export: from netcl.nn import BatchNorm2d.
- Sibling: nn.layernorm.LayerNorm, nn.groupnorm.GroupNorm.
How It Works
The BatchNorm2d constructor allocates four parameters:
- weight (gamma): shape (num_features,), learnable.
- bias (beta): shape (num_features,), learnable.
- running_mean: shape (num_features,), not learnable, but updated by the forward pass in training mode.
- running_var: shape (num_features,), not learnable, but updated by the forward pass in training mode.
The forward pass is implemented as a fused OpenCL kernel: a single-launch reduction that computes the per-channel mean and variance, immediately followed by the normalization and the affine transform. The backward pass is similarly fused.
The fused implementation is exposed as the BATCHNORM_FUSED op; a
fallback naive implementation (BATCHNORM_NAIVE) is also kept for
devices where the fused path is not supported.
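The kernel source is not shown here, so the following is only a sketch of one common way to obtain both statistics in a single pass over the data: accumulate the sum and the sum of squares per channel, then derive the variance as E[x^2] - E[x]^2. Whether BATCHNORM_FUSED works this way is an assumption, and the function name `single_pass_batchnorm` is hypothetical. The accumulators are kept in float64 because this formula is prone to cancellation error at low precision.

```python
import numpy as np

def single_pass_batchnorm(x, gamma, beta, eps=1e-5):
    """Normalize (N, C) input using statistics gathered in one pass."""
    n, c = x.shape
    total = np.zeros(c, dtype=np.float64)
    total_sq = np.zeros(c, dtype=np.float64)
    for row in x:                      # one pass: both accumulators together
        total += row
        total_sq += row * row
    mean = total / n
    # Clamp at zero: rounding can push the difference slightly negative.
    var = np.maximum(total_sq / n - mean * mean, 0.0)
    y = gamma * (x - mean) / np.sqrt(var + eps) + beta
    return y, mean, var

rng = np.random.default_rng(1)
x = rng.normal(loc=1.0, scale=2.0, size=(64, 8))
y, mean, var = single_pass_batchnorm(x, np.ones(8), np.zeros(8))
```

The result matches a two-pass computation (`x.mean(axis=0)`, `x.var(axis=0)`) to floating-point tolerance on well-scaled data.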
Code Example
import netcl as nc
import netcl.nn as nn

# ctx and q are assumed to be an existing context and command queue
# created elsewhere.
bn = nn.BatchNorm2d(num_features=64, eps=1e-5, momentum=0.1)
bn.train()   # the default after construction
x = nc.Tensor.zeros((8, 64, 224, 224), dtype="float32",
                    context=ctx, queue=q)
y = bn(x)    # uses batch statistics, updates running_*.
bn.eval()    # switch to evaluation mode
y = bn(x)    # uses running_mean / running_var.
A typical ResNet block uses BatchNorm after every convolution:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, 1, stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        out = nc.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)
        return nc.relu(out)
Performance & Trade-offs
- The momentum parameter is the EMA coefficient for the running statistics, not the optimizer's momentum. It is 0.1 by default in netcl (matches the original paper's "moving average" convention); PyTorch's default is 0.1 as well, while some other frameworks use 0.9 or 0.99, typically because they apply the coefficient to the old running value rather than to the new batch statistic. When you port a model from PyTorch, you usually do not need to convert momentum values.
- Calling bn.train() / bn.eval() is mandatory when you run the model yourself; the Trainer handles this automatically. Forgetting to switch to eval() at inference is the single most common bug in netcl models.
- Under AMP, the running statistics stay in fp32 even when the activations are fp16. Keeping the statistics in reduced precision is catastrophic for normalisation; do not change this.
- BatchNorm's behaviour is sensitive to batch size: a batch of 1
in training mode produces zero variance, and without eps the
normalization would divide by zero. The fused kernel detects this
and falls back to eps-only scaling, but the resulting gradient is
not what you want. Use GroupNorm or LayerNorm for very small batches.
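When porting from a framework that uses the other momentum convention (Keras's BatchNormalization, for example, defaults to momentum=0.99 and applies it to the old running value), the coefficient has to be flipped. The helper names below are hypothetical and the sketch is plain Python, not netcl code.

```python
def to_new_value_momentum(old_value_momentum):
    """Convert a coefficient that weights the OLD running value into one
    that weights the NEW batch statistic (the convention described above)."""
    return 1.0 - old_value_momentum

def ema_new_weighted(running, batch_stat, momentum):
    # Convention used in this document: momentum multiplies the batch statistic.
    return (1 - momentum) * running + momentum * batch_stat

def ema_old_weighted(running, batch_stat, momentum):
    # Opposite convention: momentum multiplies the old running value.
    return momentum * running + (1 - momentum) * batch_stat

# Both conventions give the same update once the coefficient is converted.
converted = to_new_value_momentum(0.9)
update_a = ema_new_weighted(2.0, 4.0, converted)
update_b = ema_old_weighted(2.0, 4.0, 0.9)
```

A value of 0.9 in the old-value convention therefore corresponds to 0.1 here, and 0.99 corresponds to 0.01.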