Architecture: JIT Compiler

netcl.autograd.compiler (entry point: jit_compile) implements a dynamic just-in-time compiler for chains of elementwise ops. The idea is simple: when a Python function consists only of elementwise ops and runs in a hot loop, fuse the entire chain into a single OpenCL program that does the work in one pass — one launch, one read per input, one write per output, no intermediate Tensors created on the device.

The pipeline is illustrated below. The trace pass records the ops; the compile pass generates OpenCL C source; the cache returns the already-built program on subsequent calls.

Caption: @jit_compile wraps a Python function. The first call traces the op sequence into a TracingContext, generates OpenCL source, builds it through cl.Program.build(), and caches the result under the key (fn, tuple_of_(shape, dtype), sorted_kwargs). Subsequent calls with the same key hit the cache and skip the build entirely.

@jit_compile — the user-facing entry point

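import netcl.autograd as ag  # assumed import path for the `ag` alias used on this page
from netcl.autograd.compiler import jit_compile

@jit_compile
def fused_mlp_block(x, w, b):
    # Note: matmul is not a registered primitive, so per the fallback rule
    # on this page a chain that contains it runs un-fused as written.
    h = ag.add(ag.matmul(x, w), b)
    h = ag.relu(h)
    return ag.add(h, 1.0)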

The first call with new shapes traces and compiles. The cache key is (fn, tuple_of_(shape, dtype), sorted_kwargs), so changing the input shape forces a recompile. There are two compile paths:

  • OpenCL path — when the first Node argument's Tensor is on a GPU queue. Produces a fused forward kernel and a fused backward kernel.
  • NumPy path — when the first Node argument's Tensor is on the CPU backend. Produces a pair of fused NumPy functions. Useful for unit tests that run without pyopencl.

If the op sequence contains an op the compiler does not know about, @jit_compile falls back to the un-fused Python implementation, so decorating a function is always safe.
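
The fallback rule can be illustrated with a toy decorator. This is a minimal sketch, not the netcl implementation: KNOWN_OPS, toy_jit, and the explicit op_names argument are hypothetical stand-ins for the primitive registry and the traced op sequence.

```python
from functools import wraps

KNOWN_OPS = {"add", "mul", "relu"}  # hypothetical registry contents


def toy_jit(op_names):
    """Decorator factory; op_names stands in for the traced op sequence."""
    def decorate(fn):
        # Fusion is only attempted when every op has a known primitive.
        fusible = all(name in KNOWN_OPS for name in op_names)

        @wraps(fn)
        def wrapper(*args):
            if not fusible:
                return ("unfused", fn(*args))  # plain Python path
            return ("fused", fn(*args))        # stand-in for one fused launch
        return wrapper
    return decorate


@toy_jit(["mul", "relu"])
def scaled_relu(x):
    return max(x * 0.5, 0.0)


@toy_jit(["matmul", "add"])  # "matmul" is unknown, so this falls back
def affine(x):
    return x * 2.0 + 1.0


print(scaled_relu(4.0))
print(affine(3.0))
```

Both functions return the same numeric result they would without the decorator; only the execution path differs, which is why decorating is safe.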

TraceNode and TracingContext

class TraceNode:
    def __init__(self, op_name, inputs, shape, dtype, attrs=None):
        ...

class TracingContext:
    def __init__(self):
        self.active = False
        self.nodes: list[TraceNode] = []

When compiler.jit_compile enters trace mode, it sets tracing_context.active = True and runs the user's function with symbolic Node objects whose .value is a TraceNode. Inside apply_op, the tracing_context.active branch fires, and a new TraceNode is appended to the context instead of executing the op. The result is a DAG of TraceNodes that mirrors the forward graph; no kernels have been launched.

When the function returns, jit_compile does a depth-first topological sort, calls each registered AutogradPrimitive's forward and backward to produce C source, and builds the program.
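
The trace-then-sort step can be sketched in plain Python. The classes below are simplified local copies of the ones shown above; the record helper, the "input" node, and op names such as mul_scalar_right are assumptions made for illustration, not the library's actual identifiers.

```python
class TraceNode:
    def __init__(self, op_name, inputs, shape, dtype, attrs=None):
        self.op_name = op_name
        self.inputs = list(inputs)
        self.shape = shape
        self.dtype = dtype
        self.attrs = attrs or {}


class TracingContext:
    def __init__(self):
        self.active = False
        self.nodes = []

    def record(self, op_name, inputs, attrs=None):
        # Hypothetical helper: an elementwise op keeps its first input's
        # shape and dtype, and nothing is executed while tracing.
        node = TraceNode(op_name, inputs, inputs[0].shape, inputs[0].dtype, attrs)
        self.nodes.append(node)
        return node


def topo_sort(output):
    """Depth-first post-order: each node appears after all of its inputs."""
    order, seen = [], set()

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.inputs:
            visit(parent)
        order.append(node)

    visit(output)
    return order


ctx = TracingContext()
ctx.active = True
x = TraceNode("input", [], (4, 8), "float32")
h = ctx.record("mul_scalar_right", [x], {"scalar": 0.5})
h = ctx.record("relu", [h])
y = ctx.record("add_scalar_right", [h], {"scalar": 1.0})
ctx.active = False

op_order = [node.op_name for node in topo_sort(y)]
print(op_order)
```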

AutogradPrimitive and register_primitive

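from dataclasses import dataclass
from typing import Callable, List

@dataclass(frozen=True)
class AutogradPrimitive:
    name: str
    forward: Callable[[List[str], dict], str]
    backward: Callable[[List[str], str, dict, str], List[str]]
    arity: int | None = None   # None means variadic
    fusible: bool = True       # False forces a kernel boundary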

An AutogradPrimitive is a tiny record that tells the source generator how to translate one op:

  • forward(args, attrs) — given a list of input expressions (e.g. ["v0", "v1"]) and a dict of scalar attributes, return a single C expression that computes the output (e.g. "v0 + v1").
  • backward(args, grad_var, attrs, out_var) — given the inputs, the upstream gradient variable, and the output variable, return one C expression per input, in input order.
  • arity — 1, 2, or 3 for unary/binary/ternary ops. None means variadic; the JIT emits explicit input variables.
  • fusible — when False, the op is always a kernel boundary (e.g. matmul would not be fusible if it were ever registered).
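
To make the forward/backward contract concrete, the sketch below builds a primitive from a local copy of the record and calls both callables on variable names. The exact strings the built-in mul emits are an assumption; the point is only that each callable maps C expression strings to C expression strings.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass(frozen=True)
class AutogradPrimitive:  # local copy of the record described above
    name: str
    forward: Callable[[List[str], dict], str]
    backward: Callable[[List[str], str, dict, str], List[str]]
    arity: Optional[int] = None
    fusible: bool = True


mul = AutogradPrimitive(
    name="mul",
    forward=lambda args, attrs: f"MUL({args[0]}, {args[1]})",
    # d(a*b)/da = b and d(a*b)/db = a, each scaled by the upstream gradient.
    backward=lambda args, grad_var, attrs, out_var: [
        f"MUL({grad_var}, {args[1]})",
        f"MUL({grad_var}, {args[0]})",
    ],
    arity=2,
)

fwd_expr = mul.forward(["v0", "v1"], {})
bwd_exprs = mul.backward(["v0", "v1"], "g2", {}, "v2")
print(fwd_expr)
print(bwd_exprs)
```

Note that backward returns one expression per input, in the same order as args, so its length should equal arity for fixed-arity primitives.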

The built-in primitives cover the common scalar/binary/unary ops: add, sub, mul, div, maximum, minimum, where, lt, le, gt, ge, plus the _scalar_right / _scalar_left variants for operations against a Python constant. The autograd API page lists the full set.

You can add your own:

from netcl.autograd.compiler import register_primitive

register_primitive(
    name="my_op",
    forward=lambda args, attrs: f"MY_OP({args[0]}, {args[1]})",
    backward=lambda args, grad_var, attrs, out_var: [
        f"MUL({grad_var}, {args[1]})",
        f"MUL({grad_var}, {args[0]})",
    ],
    arity=2,
    fusible=True,
)

After this call, my_op participates in @jit_compile chains just like add or mul. The Tutorial: Writing a Custom OpenCL Kernel walks through a worked example that registers two primitives and verifies the generated source.

Source-generation pipeline

The C source for a fused kernel is built in four stages:

  1. Preamble — the shared PRIMITIVE_PREAMBLE from core/kernels/primitives.py is prepended. It defines the inline ADD, MUL, RELU, SIGMOID, etc. helpers the primitives emit.
  2. Argument list — for each input Tensor, an __global const T* in_<i> is declared; for each output, an __global T* out_<i> is declared. The dtypes are inferred from the first traced call's input tensors.
  3. Broadcast index lines — when the inputs have different shapes (e.g. (B, 1) and (1, F)), the compiler emits broadcast_index_lines so the kernel computes the correct strided index per input. This is the same shape-broadcasting logic that the non-fused elementwise ops use; see ops/broadcast.py.
  4. Body — the TraceNodes are walked in dependency order, so each v<n> is defined before it is used; for each TraceNode a local float v<n> = …; line is emitted using the primitive's forward function. The final out_0[gid] = vN; line is appended.
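
The stages above can be sketched as a small string generator. This is a simplified illustration under stated assumptions: the FORWARD table and generate_forward are hypothetical, dtypes are fixed to float, there is a single output, and the broadcast stage is skipped (all inputs share one shape).

```python
FORWARD = {  # hypothetical emitters following the forward(args, attrs) contract
    "mul_scalar_right": lambda args, attrs: f"MUL({args[0]}, {attrs['scalar']}f)",
    "relu": lambda args, attrs: f"RELU({args[0]})",
    "add_scalar_right": lambda args, attrs: f"ADD({args[0]}, {attrs['scalar']}f)",
}


def generate_forward(name, n_inputs, steps, preamble=""):
    """steps: list of (op_name, input_var_indices, attrs) in dependency order."""
    # Stage 2: argument list.
    params = [f"    __global const float* in_{i}" for i in range(n_inputs)]
    params.append("    __global float* out_0")
    lines = [
        f"__kernel void {name}(",
        ",\n".join(params),
        ") {",
        "    int gid = get_global_id(0);",
    ]
    # Stage 4: one read per input, then one local per traced op.
    for i in range(n_inputs):
        lines.append(f"    float v{i} = in_{i}[gid];")
    next_var = n_inputs
    for op_name, arg_ids, attrs in steps:
        expr = FORWARD[op_name]([f"v{j}" for j in arg_ids], attrs)
        lines.append(f"    float v{next_var} = {expr};")
        next_var += 1
    lines.append(f"    out_0[gid] = v{next_var - 1};")
    lines.append("}")
    # Stage 1: the preamble is prepended to the kernel text.
    return preamble + "\n".join(lines)


src = generate_forward("fused_0", 1, [
    ("mul_scalar_right", [0], {"scalar": 0.5}),
    ("relu", [1], {}),
    ("add_scalar_right", [2], {"scalar": 1.0}),
])
print(src)
```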

The result is a single-kernel C program. The same pipeline produces a separate backward kernel, except out_<i> is replaced by the local-gradient expressions returned by the primitive's backward.
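
The per-input index computed in the broadcast stage can be modelled in NumPy. The function below is a sketch of standard broadcasting arithmetic, not a copy of ops/broadcast.py; broadcast_index is a hypothetical name. It maps a flat output index to the flat index of an input whose size-1 dimensions are broadcast.

```python
import numpy as np


def broadcast_index(gid, out_shape, in_shape):
    """Map a flat output index to the flat index of a broadcast input."""
    # Left-pad the input shape with 1s so both shapes have the same rank.
    in_shape = (1,) * (len(out_shape) - len(in_shape)) + tuple(in_shape)
    idx, stride = 0, 1
    # Walk dimensions from innermost to outermost.
    for out_dim, in_dim in zip(reversed(out_shape), reversed(in_shape)):
        coord = gid % out_dim
        gid //= out_dim
        if in_dim != 1:  # size-1 dims are broadcast and contribute nothing
            idx += coord * stride
            stride *= in_dim
    return idx


B, F = 3, 4
a = np.arange(B, dtype=np.float32).reshape(B, 1)
b = (10 * np.arange(F, dtype=np.float32)).reshape(1, F)

# Emulate the fused kernel: one "work item" per output element.
fused = np.array(
    [
        a.ravel()[broadcast_index(g, (B, F), a.shape)]
        + b.ravel()[broadcast_index(g, (B, F), b.shape)]
        for g in range(B * F)
    ],
    dtype=np.float32,
)
print(fused.reshape(B, F))
```

Comparing against NumPy's own broadcasting, as in (a + b).ravel(), is a convenient way to check index logic of this kind.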

Example: generated forward kernel

A small chain like y = relu(x * 0.5) + 1.0 produces source of the form below.

__kernel void fused_0(
    __global const float* in_0,
    __global float*       out_0
) {
    int gid = get_global_id(0);
    float v0 = in_0[gid];
    float v1 = MUL(v0, 0.5f);
    float v2 = RELU(v1);
    out_0[gid] = ADD(v2, 1.0f);
}

The MUL, RELU, and ADD helpers come from the preamble; in_0 is the only global read, out_0 is the only global write (v0 through v2 are private locals), and the entire chain is one launch.
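
When checking a fused kernel, an un-fused NumPy reference for the same chain is a useful baseline. The sketch below only mirrors the arithmetic of the kernel above; reference_chain is a hypothetical helper and does not touch OpenCL.

```python
import numpy as np


def reference_chain(x):
    """Un-fused NumPy reference for y = relu(x * 0.5) + 1.0."""
    v1 = x * np.float32(0.5)
    v2 = np.maximum(v1, np.float32(0.0))
    return v2 + np.float32(1.0)


x = np.array([-2.0, -0.5, 0.0, 1.0, 4.0], dtype=np.float32)
expected = reference_chain(x)
print(expected)
```

The output of the fused kernel for the same x can then be compared with np.allclose against expected.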

Cache

  • Cache key = (callable, tuple_of_(shape, dtype), sorted_kwargs).
  • Cache value = (kernel_fw, kernel_bw, out_shape, out_size) for the OpenCL path, or (numpy_fw, numpy_bw, out_shape, out_size) for the CPU path.
  • The cache holds both the generated source and the built cl.Program / compiled NumPy ufuncs. Hitting the cache is therefore a single dict lookup, not a recompile.
  • There is no explicit cache invalidation API. If you need to force a rebuild (e.g. you patched a primitive's forward function), restart the interpreter.

TrainingGraphCompiler and TrainingPattern

For larger, non-elementwise op chains, autograd/training_compiler.py provides a pattern-based registry. The use case is detection losses that combine a weighted binary cross-entropy with a weighted smooth-L1: the two losses share an input, the GPU would otherwise allocate two separate temporaries, and a hand-written fused kernel saves one read pass.

@dataclass(frozen=True)
class TrainingPattern:
    name: str
    matcher: Callable[..., bool]
    planner: Callable[..., object]

class TrainingGraphCompiler:
    def __init__(self) -> None:
        self._patterns: list[TrainingPattern] = []

    def register(self, pattern: TrainingPattern) -> None: ...

A TrainingPattern has a matcher (does this call site fit the pattern?) and a planner (build the fused plan object). For the detection-loss pattern, the planner returns a FusedDetectionLossPlan that holds three OpenCL kernels:

  • partial_kernel — one block per output element, accumulates a weighted BCE + weighted smooth-L1 into a per-block partial.
  • reduce_kernel — a single-block tree reduction that sums the partials into the final loss.
  • backward_kernel — the same arithmetic, with the derivative applied.

The public entry point is fused_weighted_bce_smooth_l1_loss(pred, target, heat_weight, reg_weight), defined in autograd/training_compiler.py. It validates that all four inputs have matching shape and float32 dtype, then either compiles (and caches) a FusedDetectionLossPlan or returns a cached one.
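
The partial-then-reduce structure can be modelled in NumPy. Everything specific here is an assumption made for illustration: pred is treated as a probability, BCE uses the standard log form, smooth-L1 uses a threshold of 1.0, the result is a sum rather than a mean, and each partial covers a fixed-size block of elements. The real kernels may differ on any of these points.

```python
import numpy as np


def partial_sums(pred, target, heat_weight, reg_weight, block=4, eps=1e-7):
    """Stage 1: weighted BCE + weighted smooth-L1, summed per block."""
    p = np.clip(pred, eps, 1.0 - eps)
    bce = -(target * np.log(p) + (1.0 - target) * np.log(1.0 - p))
    d = np.abs(pred - target)
    sl1 = np.where(d < 1.0, 0.5 * d * d, d - 0.5)
    per_elem = (heat_weight * bce + reg_weight * sl1).ravel()
    # Zero-pad so the element count is a multiple of the block size.
    pad = (-per_elem.size) % block
    per_elem = np.pad(per_elem, (0, pad))
    return per_elem.reshape(-1, block).sum(axis=1)


def reduce_partials(partials):
    """Stage 2: sum the per-block partials into the final scalar loss."""
    return partials.sum()


rng = np.random.default_rng(0)
pred = rng.uniform(0.05, 0.95, size=(2, 5)).astype(np.float32)
target = (rng.random((2, 5)) > 0.5).astype(np.float32)
heat_weight = np.ones((2, 5), dtype=np.float32)
reg_weight = np.full((2, 5), 0.5, dtype=np.float32)

partials = partial_sums(pred, target, heat_weight, reg_weight)
loss = reduce_partials(partials)
print(partials.shape, float(loss))
```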

The same module also defines the TrainingGraphCompiler registry and a register method so that other training patterns can be added in user code:

compiler = ag.TrainingGraphCompiler()      # the global singleton
compiler.register(TrainingPattern(...))    # user-defined
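
How a matcher/planner registry dispatches can be sketched as follows. ToyGraphCompiler and its plan method are hypothetical (the source only shows register), and the first-match-wins rule is an assumption.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass(frozen=True)
class TrainingPattern:  # local copy of the record shown above
    name: str
    matcher: Callable[..., bool]
    planner: Callable[..., object]


class ToyGraphCompiler:
    def __init__(self) -> None:
        self._patterns: List[TrainingPattern] = []

    def register(self, pattern: TrainingPattern) -> None:
        self._patterns.append(pattern)

    def plan(self, *args) -> Optional[object]:
        # Hypothetical lookup: the first pattern whose matcher accepts wins.
        for pattern in self._patterns:
            if pattern.matcher(*args):
                return pattern.planner(*args)
        return None  # no pattern fits; the caller keeps the un-fused path


same_shape = TrainingPattern(
    name="same_shape_pair",
    matcher=lambda a, b: a == b,
    planner=lambda a, b: {"pattern": "same_shape_pair", "shape": a},
)

toy_compiler = ToyGraphCompiler()
toy_compiler.register(same_shape)
hit = toy_compiler.plan((4, 8), (4, 8))
miss = toy_compiler.plan((4, 8), (2, 8))
print(hit, miss)
```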

When to use @jit_compile vs. the raw autograd

Scenario | Recommendation
Plain forward, no inner loop | Use nn.Module; the autograd path is already efficient for a few ops.
Elementwise chain in a hot loop (activations, the elementwise parts of layernorm and masked softmax) | Decorate with @jit_compile. The first call pays the build cost; subsequent calls are a single kernel launch.
Chains that include matmul, conv2d, batch_norm, or any reduction | Do not decorate with @jit_compile; the decorator will fall back to the un-fused path.
A detection loss combining BCE + smooth-L1 | Use fused_weighted_bce_smooth_l1_loss directly.
You want to author a new op that the JIT should know about | Use register_primitive and follow the Tutorial: Writing a Custom OpenCL Kernel.

See also