AutogradPrimitive
Status: Public API (netcl.autograd.compiler.AutogradPrimitive)
An AutogradPrimitive is the record the
JIT Compiler uses to translate one
elementwise op into OpenCL C source. It is a tiny frozen dataclass
that pairs a forward and a backward code-emitting function
with a name, an arity, and a fusible flag. The dictionary _PRIMITIVES in
autograd/compiler.py is the global registry of all primitives
the compiler knows about.
A primitive is the JIT's view of an op: it tells the source
generator how to express the op in the fused kernel. The actual
runtime implementation of the op (the one used in eager mode)
is a separate Python function in autograd/ops.py; the primitive
is only consulted when @jit_compile is generating a fused
kernel.
Overview
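```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass(frozen=True)
class AutogradPrimitive:
    name: str
    forward: Callable[[List[str], dict], str]
    backward: Callable[[List[str], str, dict, str], List[str]]
    arity: int | None = None  # None means variadic (union syntax needs Python 3.10+)
    fusible: bool = True
```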
The fields are:
- name: the registered op name (e.g. "add", "relu", "my_op"). This is the same string the runtime passes to apply_op.
- forward(args, attrs): given a list of input C expressions (e.g. ["v0", "v1"]) and a dict of scalar attributes, return a single C expression that computes the output (e.g. "v0 + v1"). The expression is emitted as the right-hand side of a local-variable assignment in the fused kernel.
- backward(args, grad_var, attrs, out_var): given the input expressions, the upstream gradient variable (a C identifier such as "g"), the attributes, and the output variable, return one C expression per input, in input order. Each expression is emitted as that input's gradient assignment in the backward kernel.
- arity: 1, 2, or 3 for unary, binary, or ternary ops. None means variadic; the JIT emits explicit input variables.
- fusible: when False, the op is always a kernel boundary (e.g. matmul would not be fusible if it were ever registered). A non-fusible op causes @jit_compile to break the chain and start a new kernel.
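The callable signatures are easier to follow with a concrete case. Below is a minimal sketch of a `mul` primitive. It is built on a standalone copy of the dataclass so it runs without netcl. The operand names `v0`, `v1`, `g`, `out` and the plain `*` operator in the emitted C are illustrative assumptions. They are not necessarily what the built-in `mul` emits.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass(frozen=True)
class AutogradPrimitive:
    # Standalone copy of the record described above, so this sketch runs without netcl.
    name: str
    forward: Callable[[List[str], dict], str]
    backward: Callable[[List[str], str, dict, str], List[str]]
    arity: Optional[int] = None
    fusible: bool = True


# Hypothetical elementwise multiply: out = a * b, so d(out)/da = b and d(out)/db = a.
mul_prim = AutogradPrimitive(
    name="mul",
    forward=lambda args, attrs: f"({args[0]} * {args[1]})",
    backward=lambda args, grad_var, attrs, out_var: [
        f"({grad_var} * {args[1]})",  # gradient w.r.t. the first input
        f"({grad_var} * {args[0]})",  # gradient w.r.t. the second input
    ],
    arity=2,
)

fwd_expr = mul_prim.forward(["v0", "v1"], {})
bwd_exprs = mul_prim.backward(["v0", "v1"], "g", {}, "out")
print(fwd_expr)   # (v0 * v1)
print(bwd_exprs)  # ['(g * v1)', '(g * v0)']
```

The backward callable returns exactly `arity` expressions, one per input and in input order.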
Where It Lives
- File path: autograd/compiler.py (class AutogradPrimitive).
- Module path: netcl.autograd.compiler.
- Public re-export: from netcl.autograd.compiler import AutogradPrimitive.
- Siblings: register_primitive (the registration function), get_primitive (the lookup function).
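The registry pattern behind these siblings is small. The sketch below is a hypothetical, simplified stand-in for `_PRIMITIVES`, `register_primitive`, and `get_primitive`. The real functions live in autograd/compiler.py, and some of their behavior is assumed here, such as raising `KeyError` for an unknown name. The sketch also reproduces the silent-overwrite behavior described under Performance & Trade-offs. It uses the OpenCL built-in `fmax` instead of a preamble macro to stay self-contained.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass(frozen=True)
class AutogradPrimitive:
    name: str
    forward: Callable[[List[str], dict], str]
    backward: Callable[[List[str], str, dict, str], List[str]]
    arity: Optional[int] = None
    fusible: bool = True


# Hypothetical stand-in for the module-level registry in autograd/compiler.py.
_PRIMITIVES: Dict[str, AutogradPrimitive] = {}


def register_primitive(name, forward, backward, arity=None, fusible=True):
    """Build a primitive and store it under `name`, overwriting any previous entry."""
    prim = AutogradPrimitive(name, forward, backward, arity, fusible)
    _PRIMITIVES[name] = prim
    return prim


def get_primitive(name: str) -> AutogradPrimitive:
    """Look up a primitive by name; this sketch raises KeyError if it is unknown."""
    return _PRIMITIVES[name]


register_primitive(
    "relu",
    forward=lambda args, attrs: f"fmax({args[0]}, 0.0f)",
    backward=lambda args, g, attrs, out: [f"({args[0]} > 0.0f ? {g} : 0.0f)"],
    arity=1,
)
print(get_primitive("relu").forward(["v0"], {}))  # fmax(v0, 0.0f)
```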
How It Works
When @jit_compile finishes tracing, the source generator walks
the trace in topological order. For each TraceNode it calls the
corresponding primitive's forward (in the forward-kernel pass) and
backward (in the backward-kernel pass, which visits the nodes in
reverse topological order). The returned C expressions are
concatenated with the shared preamble to produce a complete OpenCL
C source string, which is then compiled with cl.Program.build() and cached.
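Below is a rough sketch of the forward pass of that source generator. It assumes a trace that is already in topological order. The trace is represented as hypothetical `(op_name, input_vars, attrs, out_var)` tuples rather than real TraceNode objects. The kernel signature and variable names are also made up, because this article does not show the real generator's buffer layout or preamble.

```python
from typing import Callable, Dict, List, Tuple

# Minimal forward emitters keyed by op name (hypothetical stand-ins for _PRIMITIVES[...].forward).
FORWARD: Dict[str, Callable[[List[str], dict], str]] = {
    "add": lambda args, attrs: f"({args[0]} + {args[1]})",
    "mul": lambda args, attrs: f"({args[0]} * {args[1]})",
}

PREAMBLE = "/* shared macro preamble would go here */\n"

# (op_name, input variable names, attrs, output variable name), already topologically sorted.
Trace = List[Tuple[str, List[str], dict, str]]


def generate_forward_kernel(trace: Trace, inputs: List[str], output: str) -> str:
    params = ", ".join(f"__global const float* {name}_buf" for name in inputs)
    lines = [
        PREAMBLE,
        f"__kernel void fused_forward({params}, __global float* out_buf) {{",
        "    const size_t i = get_global_id(0);",
    ]
    # Load each input element into a local variable.
    for name in inputs:
        lines.append(f"    const float {name} = {name}_buf[i];")
    # One local assignment per traced op, in topological order.
    for op, args, attrs, out_var in trace:
        lines.append(f"    const float {out_var} = {FORWARD[op](args, attrs)};")
    lines.append(f"    out_buf[i] = {output};")
    lines.append("}")
    return "\n".join(lines)


src = generate_forward_kernel(
    [("add", ["v0", "v1"], {}, "t0"), ("mul", ["t0", "v0"], {}, "t1")],
    inputs=["v0", "v1"],
    output="t1",
)
print(src)
```

With PyOpenCL, the resulting string would then be built with something like `cl.Program(ctx, src).build()`. That step is omitted here because it needs an OpenCL context and device.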
The built-in primitives include the common elementwise ops, among
them add, sub, mul, div, maximum, minimum, where, lt, le, gt,
and ge, plus the _scalar_right / _scalar_left variants for
operations against a Python constant. The full list is visible
in _PRIMITIVES after import.
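The _scalar_right and _scalar_left variants differ only in which side of the operator the Python constant lands on. That matters for non-commutative ops such as sub and div. The following is a hypothetical sketch for sub. It assumes the constant arrives in `attrs` under a key named `"scalar"`; the real attribute name is not documented here. Only the tensor operand gets a gradient, so each backward returns a single expression.

```python
from typing import Dict, List


def fmt_scalar(value: float) -> str:
    # Emit a single-precision OpenCL literal, e.g. 2 -> "2.0f".
    return f"{float(value)!r}f"


# x - c  (constant on the right): d(out)/dx = 1
def sub_scalar_right_forward(args: List[str], attrs: Dict) -> str:
    return f"({args[0]} - {fmt_scalar(attrs['scalar'])})"


def sub_scalar_right_backward(args, grad_var, attrs, out_var) -> List[str]:
    return [grad_var]


# c - x  (constant on the left): d(out)/dx = -1
def sub_scalar_left_forward(args: List[str], attrs: Dict) -> str:
    return f"({fmt_scalar(attrs['scalar'])} - {args[0]})"


def sub_scalar_left_backward(args, grad_var, attrs, out_var) -> List[str]:
    return [f"(-{grad_var})"]


print(sub_scalar_right_forward(["v0"], {"scalar": 2}))  # (v0 - 2.0f)
print(sub_scalar_left_forward(["v0"], {"scalar": 2}))   # (2.0f - v0)
```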
Code Example
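Registering a custom primitive:

```python
from netcl.autograd.compiler import register_primitive

# MY_OP is not a built-in macro: its definition must be added to the kernel preamble.
register_primitive(
    name="my_op",
    forward=lambda args, attrs: f"MY_OP({args[0]}, {args[1]})",
    # One gradient expression per input, in input order.
    # These gradients assume MY_OP behaves like elementwise multiplication.
    backward=lambda args, grad_var, attrs, out_var: [
        f"MUL({grad_var}, {args[1]})",  # gradient w.r.t. args[0]
        f"MUL({grad_var}, {args[0]})",  # gradient w.r.t. args[1]
    ],
    arity=2,
    fusible=True,
)
```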
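Before registering a primitive, it can be worth checking its backward expressions numerically. The sketch below assumes MY_OP means elementwise multiplication. It evaluates the emitted strings in Python with stand-in definitions of the macros, then compares the results with a central finite difference. This works only while the emitted expressions also happen to be valid Python. It is a sanity check, not a substitute for testing the compiled kernel.

```python
# Python stand-ins for the C macros (assumption: MY_OP is elementwise multiplication).
MACROS = {
    "MY_OP": lambda a, b: a * b,
    "MUL": lambda a, b: a * b,
}


def forward(args, attrs):
    # Same emitter as the my_op example above.
    return f"MY_OP({args[0]}, {args[1]})"


def backward(args, grad_var, attrs, out_var):
    return [f"MUL({grad_var}, {args[1]})", f"MUL({grad_var}, {args[0]})"]


def evaluate(expr, env):
    # eval is tolerable here only because the strings come from our own emitters.
    return eval(expr, {"__builtins__": {}}, {**MACROS, **env})


def gradient_errors(x0, x1, eps=1e-6):
    """Absolute differences between emitted gradients and central finite differences."""
    env = {"v0": x0, "v1": x1, "g": 1.0}
    analytic = [evaluate(e, env) for e in backward(["v0", "v1"], "g", {}, "out")]
    fwd = forward(["v0", "v1"], {})
    numeric = []
    for name in ("v0", "v1"):
        plus = evaluate(fwd, {**env, name: env[name] + eps})
        minus = evaluate(fwd, {**env, name: env[name] - eps})
        numeric.append((plus - minus) / (2 * eps))
    return [abs(a - n) for a, n in zip(analytic, numeric)]


print(gradient_errors(1.5, -2.0))
```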
Inspecting the registered set:

```python
from netcl.autograd.compiler import _PRIMITIVES

print(sorted(_PRIMITIVES.keys()))
# ['add', 'add_relu', 'add_scalar_left', 'add_scalar_right',
#  'div', 'div_scalar_left', 'div_scalar_right', 'exp', ...]
```
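A primitive that breaks the chain:

```python
from netcl.autograd.compiler import register_primitive

register_primitive(
    name="matmul_marker",
    forward=lambda args, attrs: "0.0f",  # placeholder expression
    backward=lambda args, grad_var, attrs, out_var: ["0.0f", "0.0f"],  # one per input
    arity=2,
    fusible=False,  # not fusible: @jit_compile starts a new kernel here
)
```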
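The way a non-fusible op splits a chain can be sketched as a simple partition over a linear sequence of op names. This is a hypothetical simplification. The real @jit_compile works on a trace graph rather than a flat list, and this sketch assumes the non-fusible op gets a kernel of its own.

```python
from typing import Dict, List

# Fusibility flags keyed by op name (hypothetical subset of a registry).
FUSIBLE: Dict[str, bool] = {
    "add": True,
    "mul": True,
    "relu": True,
    "matmul_marker": False,
}


def split_into_kernels(ops: List[str]) -> List[List[str]]:
    """Group consecutive fusible ops; each non-fusible op becomes its own kernel."""
    kernels: List[List[str]] = []
    current: List[str] = []
    for op in ops:
        if FUSIBLE[op]:
            current.append(op)
            continue
        if current:  # close the fused chain before the boundary
            kernels.append(current)
            current = []
        kernels.append([op])  # the boundary op runs alone
    if current:
        kernels.append(current)
    return kernels


print(split_into_kernels(["add", "relu", "matmul_marker", "mul", "add"]))
# [['add', 'relu'], ['matmul_marker'], ['mul', 'add']]
```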
The Tutorial: Writing a Custom OpenCL Kernel walks through a worked example that registers two primitives and verifies the generated source.
Performance & Trade-offs
- The forward/backward callables are called once per primitive per source-generation pass. They are not on the hot path; the hot path is the cache lookup.
- The compiled-in ADD, MUL, RELU, etc. macros in core/kernels/primitives.py define what the expressions returned by forward/backward can use. Stick to those macros and you get the fused kernel for free; if you bring your own helper, you must add it to the preamble.
- fusible=False is the correct choice for any op that needs a separate kernel boundary (e.g. reductions, matmul, or anything with a work-group sync). The chain will be split cleanly at that op.
- Re-registering an existing name silently overwrites the previous entry. This is convenient for monkey-patching in a notebook but should be done with care: other code may have cached the old primitive.
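One way to catch a helper that is missing from the preamble is to scan the emitted expressions for macro-style calls. Those names can then be compared with the names the preamble defines. The sketch below assumes macros are ALL_CAPS identifiers. It hard-codes a hypothetical `KNOWN_MACROS` set; the authoritative list is whatever core/kernels/primitives.py defines.

```python
import re
from typing import Iterable, Set

# Hypothetical subset of the macros defined in core/kernels/primitives.py.
KNOWN_MACROS: Set[str] = {"ADD", "MUL", "RELU"}

# Matches an ALL_CAPS identifier immediately followed by "(".
_MACRO_CALL = re.compile(r"\b([A-Z][A-Z0-9_]*)\s*\(")


def missing_macros(expressions: Iterable[str], known: Set[str] = KNOWN_MACROS) -> Set[str]:
    """Return macro names used in the expressions that the preamble does not define."""
    used: Set[str] = set()
    for expr in expressions:
        used.update(_MACRO_CALL.findall(expr))
    return used - known


exprs = ["MY_OP(v0, v1)", "MUL(g, v1)", "MUL(g, v0)"]
print(missing_macros(exprs))  # {'MY_OP'}
```

Lowercase OpenCL built-ins such as `fmax` are deliberately ignored by this pattern, since the compiler provides them.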
See also
- JIT Compiler (the consumer of AutogradPrimitive).
- TraceNode (the trace node the primitive translates).
- Tensor (the underlying memory the kernel reads and writes).
- Tutorial: Custom OpenCL Kernel (a worked example of register_primitive).