Architecture: JIT Compiler
netcl.autograd.compiler (entry point:
jit_compile) implements a dynamic
just-in-time compiler for chains of elementwise ops. The idea is
simple: when a Python function consists only of elementwise ops and
runs in a hot loop, fuse the entire chain into a single OpenCL
program that does the work in one pass — one launch, one read per
input, one write per output, no intermediate
Tensors created on the device.
The pipeline is illustrated below. The trace pass records the ops; the compile pass generates OpenCL C source; the cache returns the already-built program on subsequent calls.
Figure caption: @jit_compile wraps a Python function. The first call
traces the op sequence into a TracingContext, generates OpenCL
source, builds it through cl.Program.build(), and caches the
result by (fn, tuple_of_(shape, dtype), sorted_kwargs). Subsequent
calls with the same key hit the cache and skip the build entirely.
@jit_compile — the user-facing entry point
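from netcl.autograd.compiler import jit_compile
import netcl.autograd as ag  # assumed import for the ag alias used on this page

@jit_compile
def fused_mlp_block(x, w, b):
    # Note: matmul is not an elementwise primitive; per the table at the end
    # of this page, a function containing it falls back to the un-fused path.
    h = ag.add(ag.matmul(x, w), b)
    h = ag.relu(h)
    return ag.add(h, 1.0)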
The first call with new shapes traces and compiles. The cache key is
(fn, tuple_of_(shape, dtype), sorted_kwargs), so changing the input
shape forces a recompile. There are two compile paths:
- OpenCL path: when the first Node argument's Tensor is on a GPU queue. Produces a fused forward kernel and a fused backward kernel.
- NumPy path: when the first Node argument's Tensor is on the CPU backend. Produces a pair of fused NumPy functions. Useful for unit tests that run without pyopencl.
If the op sequence contains an op the compiler does not know about,
@jit_compile falls back to the un-fused Python implementation, so
decorating a function is always safe.
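The fallback behaviour can be pictured with a small stand-alone decorator. This is only a sketch of the idea, not netcl's implementation: jit_like, UnsupportedOp, and reject_everything are hypothetical names, and the compile step is reduced to a callable that either returns a fused function or raises.

```python
import functools


class UnsupportedOp(Exception):
    """Raised by the hypothetical compile step when it meets an unknown op."""


def jit_like(compile_fn):
    """Build a decorator that tries compile_fn and falls back to plain Python."""
    def decorator(fn):
        state = {"compiled": None, "use_fallback": False}

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if not state["use_fallback"] and state["compiled"] is None:
                try:
                    state["compiled"] = compile_fn(fn)
                except UnsupportedOp:
                    # Remember the failure so later calls do not retry the compile.
                    state["use_fallback"] = True
            if state["use_fallback"]:
                return fn(*args, **kwargs)
            return state["compiled"](*args, **kwargs)

        return wrapper

    return decorator


def reject_everything(fn):
    """Stand-in compile step that never recognises the traced ops."""
    raise UnsupportedOp("unknown op in trace")


@jit_like(reject_everything)
def scale_and_shift(x):
    return x * 2 + 1


result = scale_and_shift(3)  # runs the original, un-fused Python body
```

Because the wrapper only ever swaps in a compiled function after a successful compile, a failed compile leaves the caller with exactly the behaviour of the undecorated function.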
TraceNode and TracingContext
class TraceNode:
    def __init__(self, op_name, inputs, shape, dtype, attrs=None):
        ...

class TracingContext:
    def __init__(self):
        self.active = False
        self.nodes: list[TraceNode] = []
When compiler.jit_compile enters trace mode, it sets
tracing_context.active = True and runs the user's function with
symbolic Node objects whose .value is a
TraceNode. Inside
apply_op, the tracing_context.active
branch fires, and a new TraceNode is appended to the context
instead of executing the op. The result is a DAG of TraceNodes
that mirrors the forward graph; no kernels have been launched.
When the function returns, jit_compile does a depth-first
topological sort, calls each registered
AutogradPrimitive's forward and
backward to produce C source, and builds the program.
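The trace-then-sort flow described above can be modelled without any OpenCL at all. The sketch below makes several simplifying assumptions: it collapses Node and TraceNode into one class, its apply_op is a hypothetical stand-in that models only the tracing branch, and the op names mul_scalar_right and add_scalar_right are guesses at how the scalar variants are named.

```python
class TraceNode:
    def __init__(self, op_name, inputs, shape, dtype, attrs=None):
        self.op_name = op_name
        self.inputs = list(inputs)  # upstream TraceNodes
        self.shape = shape
        self.dtype = dtype
        self.attrs = attrs or {}


class TracingContext:
    def __init__(self):
        self.active = False
        self.nodes = []


ctx = TracingContext()


def apply_op(op_name, *inputs, attrs=None):
    """Record the op instead of executing it (tracing branch only)."""
    if not ctx.active:
        raise RuntimeError("this sketch only models the tracing branch")
    # Elementwise ops keep the shape and dtype of their first input here.
    node = TraceNode(op_name, inputs, inputs[0].shape, inputs[0].dtype, attrs)
    ctx.nodes.append(node)
    return node


def topo_sort(output):
    """Depth-first post-order: each node appears after all of its inputs."""
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.inputs:
            visit(parent)
        order.append(node)

    visit(output)
    return order


# Trace y = relu(x * 0.5) + 1.0; nothing is computed, only recorded.
ctx.active = True
x = TraceNode("input", [], (4,), "float32")
h = apply_op("mul_scalar_right", x, attrs={"scalar": 0.5})
h = apply_op("relu", h)
y = apply_op("add_scalar_right", h, attrs={"scalar": 1.0})
ctx.active = False

order = [node.op_name for node in topo_sort(y)]
```

The sorted order is what a source generator would walk to emit one local variable per node.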
AutogradPrimitive and register_primitive
from dataclasses import dataclass
from typing import Callable, List

@dataclass(frozen=True)
class AutogradPrimitive:
    name: str
    forward: Callable[[List[str], dict], str]
    backward: Callable[[List[str], str, dict, str], List[str]]
    arity: int | None = None
    fusible: bool = True
An AutogradPrimitive is a tiny record that tells the source generator how to translate one op:
- forward(args, attrs): given a list of input expressions (e.g. ["v0", "v1"]) and a dict of scalar attributes, return a single C expression that computes the output (e.g. "v0 + v1").
- backward(args, grad_var, attrs, out_var): given the inputs, the upstream gradient variable, and the output variable, return one C expression per input, in input order.
- arity: 1, 2, or 3 for unary/binary/ternary. None means variadic; the JIT emits explicit input variables.
- fusible: when False, the op is always a kernel boundary (e.g. matmul would not be fusible if it were ever registered).
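To make the string-in, string-out contract concrete, the sketch below redeclares the record locally and calls both callbacks by hand. The mul primitive shown here is illustrative and may differ from the built-in one; the variable names v0, v1, v2, and g2 are arbitrary placeholders.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass(frozen=True)
class AutogradPrimitive:
    name: str
    forward: Callable[[List[str], dict], str]
    backward: Callable[[List[str], str, dict, str], List[str]]
    arity: Optional[int] = None
    fusible: bool = True


# Illustrative binary primitive: both callbacks return C expressions as text.
mul = AutogradPrimitive(
    name="mul",
    forward=lambda args, attrs: f"MUL({args[0]}, {args[1]})",
    backward=lambda args, grad_var, attrs, out_var: [
        f"MUL({grad_var}, {args[1]})",  # d(a*b)/da = b
        f"MUL({grad_var}, {args[0]})",  # d(a*b)/db = a
    ],
    arity=2,
)

fwd_expr = mul.forward(["v0", "v1"], {})
bwd_exprs = mul.backward(["v0", "v1"], "g2", {}, "v2")
```

Neither callback touches tensor data; they only assemble text that the source generator later places inside the kernel body.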
The built-in primitives cover the common scalar/binary/unary ops:
add, sub, mul, div, maximum, minimum, where, lt, le,
gt, ge, plus the _scalar_right / _scalar_left variants for
operations against a Python constant. The
autograd API page lists the full set.
You can add your own:
from netcl.autograd.compiler import register_primitive

register_primitive(
    name="my_op",
    forward=lambda args, attrs: f"MY_OP({args[0]}, {args[1]})",
    backward=lambda args, grad_var, attrs, out_var: [
        f"MUL({grad_var}, {args[1]})",
        f"MUL({grad_var}, {args[0]})",
    ],
    arity=2,
    fusible=True,
)
After this call, my_op participates in @jit_compile chains just
like add or mul. The
Tutorial: Writing a Custom OpenCL Kernel
walks through a worked example that registers two primitives and
verifies the generated source.
Source-generation pipeline
The C source for a fused kernel is built in four stages:
- Preamble: the shared PRIMITIVE_PREAMBLE from core/kernels/primitives.py is prepended. It defines the inline ADD, MUL, RELU, SIGMOID, etc. helpers the primitives emit.
- Argument list: for each input Tensor, an __global const T* in_<i> is declared; for each output, an __global T* out_<i> is declared. The dtypes are inferred from the first traced call's input tensors.
- Broadcast index lines: when the inputs have different shapes (e.g. (B, 1) and (1, F)), the compiler emits broadcast_index_lines so the kernel computes the correct strided index per input. This is the same shape-broadcasting logic that the non-fused elementwise ops use; see ops/broadcast.py.
- Body: the nodes are walked in topological order, so every local is defined before it is used; for each TraceNode a local float v<n> = …; line is emitted using the primitive's forward function. The final out_0[gid] = vN; line is appended.
The result is a single-kernel C program. The same pipeline produces
a separate backward kernel, except out_<i> is replaced by the
local-gradient expressions returned by the primitive's backward.
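The strided index that the broadcast stage has to compute per input can be modelled in plain Python. This is a sketch of standard row-major broadcasting, not the text that broadcast_index_lines actually emits (which is not shown on this page); broadcast_strides and input_index are hypothetical helper names.

```python
def broadcast_strides(in_shape, out_shape):
    """Row-major element strides for reading in_shape as if it had out_shape.

    Broadcast axes (size 1 in the input, larger in the output) get stride 0,
    so every output coordinate along that axis reads the same input element.
    """
    ndim = len(out_shape)
    padded = (1,) * (ndim - len(in_shape)) + tuple(in_shape)
    strides = [0] * ndim
    running = 1
    for axis in range(ndim - 1, -1, -1):
        if padded[axis] == out_shape[axis]:
            strides[axis] = running
        elif padded[axis] == 1:
            strides[axis] = 0
        else:
            raise ValueError(f"cannot broadcast {in_shape} to {out_shape}")
        running *= padded[axis]
    return strides


def input_index(gid, out_shape, strides):
    """Map a flat output index to the flat index of the broadcast input."""
    index = 0
    for axis in range(len(out_shape) - 1, -1, -1):
        gid, coord = divmod(gid, out_shape[axis])
        index += coord * strides[axis]
    return index


out_shape = (2, 3)  # (B, 1) combined with (1, F) for B=2, F=3
col_strides = broadcast_strides((2, 1), out_shape)
row_strides = broadcast_strides((1, 3), out_shape)
col_reads = [input_index(g, out_shape, col_strides) for g in range(6)]
row_reads = [input_index(g, out_shape, row_strides) for g in range(6)]
```

In this example the (B, 1) input is read at the same offset for every column of a row, while the (1, F) input repeats its F elements for every row.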
Example: generated forward kernel
A small chain like y = relu(x * 0.5) + 1.0 produces source of the
form below.
__kernel void fused_0(
    __global const float* in_0,
    __global float* out_0
) {
    int gid = get_global_id(0);
    float v0 = in_0[gid];
    float v1 = MUL(v0, 0.5f);
    float v2 = RELU(v1);
    out_0[gid] = ADD(v2, 1.0f);
}
The MUL, RELU, and ADD helpers come from the preamble; in_0 is the
only buffer read, out_0 is the only buffer written, and the entire
chain is one launch.
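A kernel of this shape can be assembled by a very small generator. The sketch below is a simplification under stated assumptions: emit_forward_kernel is a hypothetical function, it handles only a single-input linear chain, and it hard-codes float. It follows the out_0[gid] = vN rule from the stage list, so the last op gets its own local instead of being inlined into the store.

```python
def emit_forward_kernel(name, steps):
    """Emit OpenCL C for a linear chain of elementwise ops.

    steps: one C expression template per traced op; "{prev}" is replaced
    by the local variable that holds the previous result.
    """
    lines = [
        f"__kernel void {name}(",
        "    __global const float* in_0,",
        "    __global float* out_0",
        ") {",
        "    int gid = get_global_id(0);",
        "    float v0 = in_0[gid];",
    ]
    for n, template in enumerate(steps, start=1):
        prev = f"v{n - 1}"
        expr = template.format(prev=prev)
        lines.append(f"    float v{n} = {expr};")
    lines.append(f"    out_0[gid] = v{len(steps)};")
    lines.append("}")
    return "\n".join(lines)


# y = relu(x * 0.5) + 1.0 as three elementwise steps.
source = emit_forward_kernel(
    "fused_0",
    ["MUL({prev}, 0.5f)", "RELU({prev})", "ADD({prev}, 1.0f)"],
)
print(source)
```

The generated text still needs the preamble prepended before it can be built, since MUL, RELU, and ADD are defined there.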
Cache
- Cache key = (callable, tuple_of_(shape, dtype), sorted_kwargs).
- Cache value = (kernel_fw, kernel_bw, out_shape, out_size) for the OpenCL path, or (numpy_fw, numpy_bw, out_shape, out_size) for the CPU path.
- The cache holds both the generated source and the built cl.Program / compiled NumPy ufuncs. Hitting the cache is therefore a single dict lookup, not a recompile.
- There is no explicit cache invalidation API. If you need to force a rebuild (e.g. you patched a primitive's forward function), restart the interpreter.
TrainingGraphCompiler and TrainingPattern
For larger, non-elementwise op chains, autograd/training_compiler.py
provides a pattern-based registry. The use case is detection losses
that combine a weighted binary cross-entropy with a weighted
smooth-L1: the two losses share an input, the GPU would otherwise
allocate two separate temporaries, and a hand-written fused kernel
saves one read pass.
from dataclasses import dataclass
from typing import Callable

@dataclass(frozen=True)
class TrainingPattern:
    name: str
    matcher: Callable[..., bool]
    planner: Callable[..., object]

class TrainingGraphCompiler:
    def __init__(self) -> None:
        self._patterns: list[TrainingPattern] = []

    def register(self, pattern: TrainingPattern) -> None: ...
A TrainingPattern has a matcher (does this call site fit the
pattern?) and a planner (build the fused plan object). The
planner's return value is a FusedDetectionLossPlan that holds
three OpenCL kernels:
- partial_kernel: one block per output element, accumulates a weighted BCE + weighted smooth-L1 into a per-block partial.
- reduce_kernel: a single-block tree reduction that sums the partials into the final loss.
- backward_kernel: the same arithmetic, with the derivative applied.
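The two-pass structure of the forward computation (per-block partials, then a tree reduction) can be checked on the CPU with a few lines of Python. Everything numerical here is an assumption: the page does not give the exact loss formulas, so this sketch uses a clamped probability-space BCE and a smooth-L1 with a threshold of 1, and it groups four elements per block purely for illustration.

```python
import math


def bce(p, t, eps=1e-7):
    """Binary cross-entropy on a probability p (assumed form)."""
    p = min(max(p, eps), 1.0 - eps)
    return -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))


def smooth_l1(p, t):
    """Smooth-L1 with threshold 1 (assumed form)."""
    d = abs(p - t)
    return 0.5 * d * d if d < 1.0 else d - 0.5


def partial_pass(pred, target, heat_w, reg_w, block=4):
    """Models partial_kernel: one accumulated partial per block of elements."""
    partials = []
    for start in range(0, len(pred), block):
        acc = 0.0
        for i in range(start, min(start + block, len(pred))):
            acc += heat_w[i] * bce(pred[i], target[i])
            acc += reg_w[i] * smooth_l1(pred[i], target[i])
        partials.append(acc)
    return partials


def tree_reduce(values):
    """Models reduce_kernel: pairwise sums, halving the active range each step."""
    vals = list(values)
    n = len(vals)
    while n > 1:
        half = (n + 1) // 2
        for i in range(n // 2):
            vals[i] += vals[i + half]
        n = half
    return vals[0] if vals else 0.0


pred = [0.9, 0.2, 0.7, 0.4, 0.6]
target = [1.0, 0.0, 1.0, 0.0, 1.0]
heat_w = [1.0] * 5
reg_w = [0.5] * 5

partials = partial_pass(pred, target, heat_w, reg_w)
loss = tree_reduce(partials)
```

Both terms read pred and target once inside the same loop iteration, which is the read pass a fused kernel saves compared with evaluating the two losses separately.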
The public entry point is
fused_weighted_bce_smooth_l1_loss(pred, target, heat_weight,
reg_weight), defined in
autograd/training_compiler.py. It validates that all four inputs
have matching shape and float32 dtype, then either compiles (and
caches) a FusedDetectionLossPlan or returns a cached one.
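The input check described above might look roughly like the following. This is a sketch only: validate_loss_inputs and FakeTensor are hypothetical names, and the exception types are a guess, since the page does not say what the real function raises.

```python
from collections import namedtuple

# Hypothetical stand-in for a tensor with just the attributes being checked.
FakeTensor = namedtuple("FakeTensor", ["shape", "dtype"])


def validate_loss_inputs(pred, target, heat_weight, reg_weight):
    """Raise unless all four inputs share pred's shape and are float32."""
    named = {
        "pred": pred,
        "target": target,
        "heat_weight": heat_weight,
        "reg_weight": reg_weight,
    }
    for name, tensor in named.items():
        if tuple(tensor.shape) != tuple(pred.shape):
            raise ValueError(
                f"{name} has shape {tensor.shape}, expected {pred.shape}"
            )
        if str(tensor.dtype) != "float32":
            raise TypeError(f"{name} has dtype {tensor.dtype}, expected float32")
    return tuple(pred.shape)


ok = FakeTensor((8, 4), "float32")
shape = validate_loss_inputs(ok, ok, ok, ok)
```

Returning the validated shape is a convenience of this sketch; it would be a natural part of a plan-cache key, but the page does not state how cached plans are keyed.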
The same module also defines the TrainingGraphCompiler registry
and a register method so that other training patterns can be added
in user code:
compiler = ag.TrainingGraphCompiler() # the global singleton
compiler.register(TrainingPattern(...)) # user-defined
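How a matcher and planner pair could be dispatched is sketched below with local copies of the two classes. Two parts are assumptions: the body of register (shown only as ... above) is taken to append to the list, and the plan method is a hypothetical helper that is not part of the documented API.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass(frozen=True)
class TrainingPattern:
    name: str
    matcher: Callable[..., bool]
    planner: Callable[..., object]


class TrainingGraphCompiler:
    def __init__(self) -> None:
        self._patterns: List[TrainingPattern] = []

    def register(self, pattern: TrainingPattern) -> None:
        # Assumed behaviour: patterns are kept in registration order.
        self._patterns.append(pattern)

    def plan(self, *args, **kwargs) -> Optional[object]:
        """Hypothetical helper: return the first matching pattern's plan."""
        for pattern in self._patterns:
            if pattern.matcher(*args, **kwargs):
                return pattern.planner(*args, **kwargs)
        return None


compiler = TrainingGraphCompiler()
compiler.register(
    TrainingPattern(
        name="same_length_pair",
        matcher=lambda a, b: len(a) == len(b),
        planner=lambda a, b: ("fused_pair_plan", len(a)),
    )
)

matched = compiler.plan([1, 2, 3], [4, 5, 6])
unmatched = compiler.plan([1], [2, 3])
```

Keeping the matcher cheap matters in this design, because it runs for every registered pattern until one accepts the call site.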
When to use @jit_compile vs. the raw autograd
| Scenario | Recommendation |
|---|---|
| Plain forward, no inner loop | Use nn.Module; the autograd path is already efficient for a few ops. |
| Elementwise chain in a hot loop (activations, the elementwise parts of layernorm or masked softmax) | Decorate with @jit_compile. The first call pays the build cost; subsequent calls are a single kernel launch. |
| Ops that include matmul, conv2d, batch_norm, or any reduction | Do not decorate with @jit_compile; the decorator will fall back to the un-fused path. |
| A detection loss combining BCE + smooth-L1 | Use fused_weighted_bce_smooth_l1_loss directly. |
| You want to author a new op that the JIT should know about | Use register_primitive and follow the Tutorial: Writing a Custom OpenCL Kernel. |
See also
- autograd API (the full symbol list, including apply_op, Tape, and detect_anomaly).
- Architecture: Autograd & Tape (the Tape/Node machinery that the JIT's tracing mode piggybacks on).
- Tutorial: Writing a Custom OpenCL Kernel (walks through register_primitive end-to-end).
- Tutorial: Understanding Autograd (the prerequisite for the JIT pages; covers apply_op and Tape).
- Tensor Backend (the layer that actually builds and runs the generated OpenCL C source).
- Tensor API (the value type that flows through the fused kernel).