AutogradPrimitive

Status: Public API in netcl.autograd.compiler.AutogradPrimitive

An AutogradPrimitive is the record the JIT Compiler uses to translate one elementwise op into OpenCL C source. It is a tiny frozen dataclass that pairs forward and backward code-emitting functions with a name, an arity, and a fusible flag. The dictionary _PRIMITIVES in autograd/compiler.py is the global registry of all primitives the compiler knows about.

A primitive is the JIT's view of an op: it tells the source generator how to express the op in the fused kernel. The actual runtime implementation of the op (the one the eager mode uses) is a separate Python function in autograd/ops.py; the primitive is only consulted when @jit_compile is generating a fused kernel.

Overview

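from dataclasses import dataclass
from typing import Callable, List

@dataclass(frozen=True)
class AutogradPrimitive:
    name:     str
    forward:  Callable[[List[str], dict], str]
    backward: Callable[[List[str], str, dict, str], List[str]]
    arity:    int | None = None    # "int | None" needs Python 3.10+
    fusible:  bool = True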

The fields are:

  • name — the registered op name (e.g. "add", "relu", "my_op"). This is the same string the runtime passes to apply_op.
  • forward(args, attrs) — given a list of input C expressions (e.g. ["v0", "v1"]) and a dict of scalar attributes, return a single C expression that computes the output (e.g. "v0 + v1"). The expression becomes the right-hand side of the op's local-variable assignment in the fused kernel.
  • backward(args, grad_var, attrs, out_var) — given the inputs, the upstream gradient variable (a C identifier like "g"), the attributes, and the output variable, return one C expression per input, in input order. The expressions are emitted as the per-input gradient assignments in the backward kernel.
  • arity — 1, 2, or 3 for unary, binary, or ternary ops. None means variadic; in that case the JIT emits explicit input variables.
  • fusible — when False, the op is always a kernel boundary (for example, a matmul primitive, if one were ever registered, would need fusible=False). A non-fusible op causes @jit_compile to break the chain and start a new kernel.
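To make the callable contract concrete, here is a minimal sketch that mirrors the dataclass locally (so it runs without netcl installed) and defines a multiplication primitive. The "mul" expressions are illustrative assumptions written with plain C operators, not copied from the built-in registry.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


# Local mirror of the AutogradPrimitive fields listed above, so this
# sketch runs without netcl; the real class is netcl.autograd.compiler.AutogradPrimitive.
@dataclass(frozen=True)
class AutogradPrimitive:
    name: str
    forward: Callable[[List[str], dict], str]
    backward: Callable[[List[str], str, dict, str], List[str]]
    arity: Optional[int] = None
    fusible: bool = True


# Hypothetical "mul" primitive: out = a * b, so d(out)/da = b and d(out)/db = a.
mul = AutogradPrimitive(
    name="mul",
    forward=lambda args, attrs: f"({args[0]} * {args[1]})",
    backward=lambda args, grad_var, attrs, out_var: [
        f"({grad_var} * {args[1]})",  # gradient w.r.t. the first input
        f"({grad_var} * {args[0]})",  # gradient w.r.t. the second input
    ],
    arity=2,
)

fwd = mul.forward(["v0", "v1"], {})
bwd = mul.backward(["v0", "v1"], "g", {}, "out")
print(fwd)  # (v0 * v1)
print(bwd)  # ['(g * v1)', '(g * v0)']
```

The callables only build strings; nothing is compiled until the generated source reaches OpenCL.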

Where It Lives

  • File path: autograd/compiler.py (class AutogradPrimitive).
  • Module path: netcl.autograd.compiler.
  • Public re-export: from netcl.autograd.compiler import AutogradPrimitive.
  • Siblings: register_primitive (the registration function) and get_primitive (the lookup function).

How It Works

When @jit_compile finishes tracing, the source generator walks the trace in topological order for the forward-kernel pass, calling each TraceNode's primitive forward, and in reverse topological order for the backward-kernel pass, calling its backward. The returned C expressions are concatenated with the shared preamble to produce a complete OpenCL C source string, which is then compiled with cl.Program.build() and cached.
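The two passes can be sketched in plain Python. Everything here is hypothetical: Node stands in for TraceNode, the FORWARD/BACKWARD tables stand in for _PRIMITIVES, and the g_<name> gradient variables with "+=" accumulation (which assumes they start at zero) are this sketch's own convention, not necessarily what netcl emits. Preamble concatenation and compilation are left out.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    # Hypothetical stand-in for a TraceNode: op name, input variables, output variable.
    op: str
    inputs: List[str]
    out: str
    attrs: dict = field(default_factory=dict)


# Illustrative emitters keyed by op name (not the real registry).
FORWARD = {
    "add": lambda a, attrs: f"{a[0]} + {a[1]}",
    "mul": lambda a, attrs: f"{a[0]} * {a[1]}",
}
BACKWARD = {
    "add": lambda a, g, attrs, out: [g, g],
    "mul": lambda a, g, attrs, out: [f"{g} * {a[1]}", f"{g} * {a[0]}"],
}


def emit_forward(trace: List[Node]) -> List[str]:
    # Forward pass: topological order, one local-variable assignment per node.
    return [f"float {n.out} = {FORWARD[n.op](n.inputs, n.attrs)};" for n in trace]


def emit_backward(trace: List[Node]) -> List[str]:
    # Backward pass: reverse topological order; each node reads the gradient
    # of its output and accumulates one expression into each input's gradient.
    lines = []
    for n in reversed(trace):
        grads = BACKWARD[n.op](n.inputs, f"g_{n.out}", n.attrs, n.out)
        for inp, expr in zip(n.inputs, grads):
            lines.append(f"g_{inp} += {expr};")
    return lines


# t0 = v0 * v1; t1 = t0 + v2
trace = [Node("mul", ["v0", "v1"], "t0"), Node("add", ["t0", "v2"], "t1")]
for line in emit_forward(trace) + emit_backward(trace):
    print(line)
# float t0 = v0 * v1;
# float t1 = t0 + v2;
# g_t0 += g_t1;
# g_v2 += g_t1;
# g_v0 += g_t0 * v1;
# g_v1 += g_t0 * v0;
```

Accumulating rather than assigning matters when one variable feeds several nodes, since each consumer contributes part of its gradient.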

The built-in primitives cover the common elementwise ops: the binary add, sub, mul, div, maximum, minimum, lt, le, gt, ge, the ternary where, and unary ops such as exp, plus _scalar_right / _scalar_left variants (e.g. add_scalar_right, div_scalar_left) for operations against a Python constant. The full list is visible in _PRIMITIVES after import.
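The left/right distinction matters for non-commutative ops such as div. The sketch below assumes that "_scalar_right" means the constant is on the right (x / c), that "_scalar_left" means it is on the left (c / x), and that the constant arrives in attrs under a hypothetical "scalar" key; none of these details are confirmed by the registry. It also shows one use of out_var: since d(c / x)/dx = -c / x² = -out / x, the backward can reuse the forward output.

```python
# Hypothetical attrs key: "scalar" holds the Python constant c.
def c_literal(attrs: dict) -> str:
    # Format a finite Python number as an OpenCL float literal: 2 -> "2.0f".
    return f"{float(attrs['scalar'])}f"


# div_scalar_right (assumed: x / c). d(x / c)/dx = 1 / c.
def div_scalar_right_forward(args, attrs):
    return f"({args[0]} / {c_literal(attrs)})"


def div_scalar_right_backward(args, grad_var, attrs, out_var):
    return [f"({grad_var} / {c_literal(attrs)})"]


# div_scalar_left (assumed: c / x). d(c / x)/dx = -c / x^2 = -out / x,
# so the gradient reuses the forward output instead of recomputing c / x.
def div_scalar_left_forward(args, attrs):
    return f"({c_literal(attrs)} / {args[0]})"


def div_scalar_left_backward(args, grad_var, attrs, out_var):
    return [f"(-{grad_var} * {out_var} / {args[0]})"]


print(div_scalar_right_forward(["v0"], {"scalar": 2}))              # (v0 / 2.0f)
print(div_scalar_left_backward(["v0"], "g", {"scalar": 2}, "out"))  # ['(-g * out / v0)']
```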

Code Example

Registering a custom primitive:

from netcl.autograd.compiler import register_primitive

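register_primitive(
    name="my_op",
    # MY_OP is not a built-in macro; it must be defined in the kernel preamble.
    forward=lambda args, attrs: f"MY_OP({args[0]}, {args[1]})",
    # One gradient per input, in input order. These are product-rule
    # gradients, so they are correct only if MY_OP(a, b) computes a * b.
    backward=lambda args, grad_var, attrs, out_var: [
        f"MUL({grad_var}, {args[1]})",   # gradient w.r.t. args[0]
        f"MUL({grad_var}, {args[0]})",   # gradient w.r.t. args[1]
    ],
    arity=2,
    fusible=True,
)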
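Because forward and backward are ordinary Python callables, their shape can be checked before registering. The helper below is a hypothetical sketch, not part of netcl: it calls the emitters with dummy variable names and verifies that forward returns one expression and backward returns one expression per input, as the field descriptions require.

```python
from typing import Callable, Optional


def check_primitive(forward: Callable, backward: Callable,
                    arity: Optional[int], attrs: Optional[dict] = None) -> None:
    """Hypothetical pre-registration check on the emitted expressions."""
    attrs = attrs or {}
    n = arity if arity is not None else 2  # probe variadic ops with two inputs
    args = [f"v{i}" for i in range(n)]
    fwd = forward(args, attrs)
    if not isinstance(fwd, str) or not fwd:
        raise TypeError("forward must return a non-empty C expression string")
    grads = backward(args, "g", attrs, "out")
    if len(grads) != n:
        raise ValueError(f"backward returned {len(grads)} expressions for {n} inputs")
    if not all(isinstance(e, str) and e for e in grads):
        raise TypeError("each gradient must be a non-empty C expression string")


# The my_op lambdas from the registration example pass the check.
check_primitive(
    forward=lambda args, attrs: f"MY_OP({args[0]}, {args[1]})",
    backward=lambda args, grad_var, attrs, out_var: [
        f"MUL({grad_var}, {args[1]})",
        f"MUL({grad_var}, {args[0]})",
    ],
    arity=2,
)
```

The check only validates structure. Whether the expressions compile still depends on the preamble.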

Inspecting the registered set:

from netcl.autograd.compiler import _PRIMITIVES
print(sorted(_PRIMITIVES.keys()))
# ['add', 'add_relu', 'add_scalar_left', 'add_scalar_right',
#  'div', 'div_scalar_left', 'div_scalar_right', 'exp', ...]

A primitive that breaks the chain:

from netcl.autograd.compiler import register_primitive

register_primitive(
    name="matmul_marker",
    forward=lambda args, attrs: "0.0f",     # placeholder
    backward=lambda args, grad_var, attrs, out_var: ["0.0f", "0.0f"],
    arity=2,
    fusible=False,                          # not fusible!
)
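Conceptually, the effect of fusible=False on a chain of ops can be sketched as a simple split. This is an illustrative model, not netcl's algorithm. It treats the trace as a linear sequence, assumes a non-fusible op runs in a kernel of its own, and treats op names missing from the fusible map as fusible.

```python
from typing import Dict, List


def split_into_kernels(ops: List[str], fusible: Dict[str, bool]) -> List[List[str]]:
    """Illustrative chain splitting: consecutive fusible ops share a kernel,
    and a non-fusible op closes the current chain and runs alone."""
    kernels: List[List[str]] = []
    current: List[str] = []
    for op in ops:
        if fusible.get(op, True):  # unknown ops default to fusible here
            current.append(op)
        else:
            if current:
                kernels.append(current)
                current = []
            kernels.append([op])  # the boundary op gets its own kernel
    if current:
        kernels.append(current)
    return kernels


print(split_into_kernels(["mul", "add", "matmul_marker", "relu"],
                         {"matmul_marker": False}))
# [['mul', 'add'], ['matmul_marker'], ['relu']]
```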

The Tutorial: Writing a Custom OpenCL Kernel walks through a worked example that registers two primitives and verifies the generated source.

Performance & Trade-offs

  • The forward / backward callables run only at source-generation time, once per TraceNode that uses the primitive in each pass, never per element or per kernel launch. They are not on the hot path; the hot path is the cache lookup.
  • The ADD, MUL, RELU, and other macros defined in core/kernels/primitives.py are the vocabulary your forward and backward expressions can use. Stick to those macros and you get the fused kernel for free; bring your own helper and you must add it to the preamble.
  • fusible=False is the correct choice for any op that needs a separate kernel boundary (e.g. reductions, matmul, anything with a work-group sync). The chain will be split cleanly at that op.
  • Re-registering an existing name silently overwrites the previous entry. This is convenient for monkey-patching in a notebook but should be done with care — other code may have cached the old primitive.
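A custom helper such as MY_OP has to reach the preamble before the source is built. The sketch below is a hypothetical check, not a netcl API: BASE_PREAMBLE is a two-macro stand-in for the real core/kernels/primitives.py preamble, MY_OP is assumed to multiply its arguments, and the list of allowed OpenCL built-in functions is deliberately short. It scans emitted expressions for NAME( calls that are neither defined macros nor known built-ins.

```python
import re
from typing import Iterable, List, Set

# Stand-in for the shared preamble; the real one has many more macros.
BASE_PREAMBLE = "#define ADD(a, b) ((a) + (b))\n#define MUL(a, b) ((a) * (b))\n"
# Hypothetical custom helper; assumes MY_OP(a, b) multiplies its arguments.
EXTRA_PREAMBLE = "#define MY_OP(a, b) ((a) * (b))\n"

# A few real OpenCL C built-in functions that need no #define.
OPENCL_BUILTINS = frozenset({"exp", "log", "sqrt", "fabs", "fmax", "fmin", "select"})

_DEFINE = re.compile(r"^\s*#define\s+([A-Za-z_]\w*)\(", re.MULTILINE)
_CALL = re.compile(r"\b([A-Za-z_]\w*)\s*\(")


def defined_macros(preamble: str) -> Set[str]:
    # Names of function-like macros declared as "#define NAME(...)".
    return set(_DEFINE.findall(preamble))


def missing_macros(exprs: List[str], preamble: str,
                   known_functions: Iterable[str] = OPENCL_BUILTINS) -> Set[str]:
    # Every NAME( used in an expression must be a preamble macro or a built-in.
    known = defined_macros(preamble) | set(known_functions)
    used = {name for e in exprs for name in _CALL.findall(e)}
    return used - known


exprs = ["MY_OP(v0, v1)", "MUL(g, v1)", "MUL(g, v0)"]
print(missing_macros(exprs, BASE_PREAMBLE))                   # {'MY_OP'}
print(missing_macros(exprs, BASE_PREAMBLE + EXTRA_PREAMBLE))  # set()
```

A regex scan like this is only a heuristic; the OpenCL compiler remains the authority on whether the source builds.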

See also