JIT Compiler
Status: Architecture component in
autograd/compiler.py
The netcl JIT compiler is a dynamic, trace-based just-in-time compiler
for chains of elementwise ops. When a Python function is decorated with
@jit_compile, the first call records the op sequence into a
TracingContext, generates OpenCL C source for the entire chain, builds
it once through cl.Program.build(), and caches the result. Subsequent
calls with the same input shapes, dtypes, and keyword arguments hit the
cache and skip the build entirely.
The result is a single OpenCL kernel that does the work in one pass — one launch, one read per input, one write per output, no intermediate device tensors created along the chain. Compared to a naive per-op implementation, this is typically a 2x to 4x speed-up on elementwise-heavy networks (MLP activations, normalization tails, residual adds).
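To make the "one pass, no intermediates" claim concrete, here is a minimal sketch in plain Python. Lists stand in for device tensors, and the function names are illustrative only; nothing here is netcl API.

```python
# Illustrative only: plain Python lists stand in for device tensors.

def relu(v):
    return v if v > 0.0 else 0.0

def per_op_chain(x, b):
    """Naive path: one full pass and one temporary buffer per op."""
    t1 = [xi + bi for xi, bi in zip(x, b)]   # pass 1, intermediate buffer
    t2 = [relu(v) for v in t1]               # pass 2, intermediate buffer
    return [v + 1.0 for v in t2]             # pass 3, final output

def fused_chain(x, b):
    """Fused path: one pass, each input read once, the output written once."""
    return [relu(xi + bi) + 1.0 for xi, bi in zip(x, b)]

x = [-2.0, 0.5, 3.0]
b = [1.0, 1.0, -4.0]

# Both paths compute the same values; only the memory traffic differs.
assert per_op_chain(x, b) == fused_chain(x, b)
```

The fused version touches each element of `x` and `b` once and never materialises `t1` or `t2`, which is the saving the generated kernel is aiming for.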
Overview
The JIT compiler is the dynamic sibling of the static
TrainingGraphCompiler (see autograd/training_compiler.py); it
specialises at trace time and produces fused forward and fused backward
kernels in one go. It is opt-in: you decorate a function with
@jit_compile and the compiler does the rest. If the op sequence
contains an op the compiler does not know about, @jit_compile falls
back to the un-fused Python implementation, so decorating a function is
always safe.
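The fallback rule can be sketched as a decorator that attempts a compile and keeps the original function when an unknown op turns up. This is a simplified illustration, not the netcl implementation: `KNOWN_OPS`, `try_compile`, and `jit_with_fallback` are hypothetical names, and the op list is passed in by hand instead of being recorded by a trace.

```python
import functools

# Hypothetical set of op names the compiler knows how to fuse.
KNOWN_OPS = {"add", "mul", "relu"}

class UnsupportedOp(Exception):
    """Raised by the sketch compiler when it meets an unknown op."""

def try_compile(op_names):
    """Stand-in for the compile step: refuses any op it does not recognise."""
    for name in op_names:
        if name not in KNOWN_OPS:
            raise UnsupportedOp(name)
    return "fused"

def jit_with_fallback(op_names):
    """Decorator factory; op_names replaces the real trace in this sketch."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            try:
                wrapper.mode = try_compile(op_names)
            except UnsupportedOp:
                wrapper.mode = "fallback"
            # A real compiler would launch the fused kernel in "fused" mode.
            # The sketch always runs the Python body so results stay correct.
            return fn(*args, **kwargs)
        wrapper.mode = None
        return wrapper
    return decorator

@jit_with_fallback(["add", "relu"])
def elementwise_only(v):
    return max(v + 1.0, 0.0)

@jit_with_fallback(["add", "matmul"])
def contains_unknown_op(v):
    return v + 1.0

elementwise_only(1.0)
contains_unknown_op(1.0)
print(elementwise_only.mode, contains_unknown_op.mode)
```

Either way the caller gets the same result, which is why decorating is described as safe; only the execution mode differs.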
The cache key is (callable, tuple_of_(shape, dtype), sorted_kwargs),
so changing any input's shape or dtype, or any keyword argument, forces
a recompile. There are two compile paths:
- OpenCL path: used when the first Node argument's tensor is on a GPU queue. Produces a fused forward kernel and a fused backward kernel.
- NumPy path: used when the first Node argument's tensor is on the CPU backend. Produces a pair of fused NumPy functions. Useful for unit tests that run without pyopencl.
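A minimal sketch of how a key of the form (callable, tuple_of_(shape, dtype), sorted_kwargs) and the path choice might be computed. `FakeTensor`, its `backend` field, `make_cache_key`, and `choose_path` are assumptions made for illustration and are not taken from netcl.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeTensor:
    """Stand-in for a tensor: only the fields the sketch needs."""
    shape: tuple
    dtype: str
    backend: str = "cpu"   # hypothetical field: "cpu" or "gpu"

def make_cache_key(fn, args, kwargs):
    """Build (callable, ((shape, dtype), ...), sorted kwargs) as a hashable tuple."""
    signature = tuple((a.shape, a.dtype) for a in args)
    return (fn, signature, tuple(sorted(kwargs.items())))

def choose_path(args):
    """Pick the compile path from the first argument, as the list above describes."""
    return "opencl" if args[0].backend == "gpu" else "numpy"

def block(x, w):
    """Placeholder callable; only its identity matters for the key."""
    return None

x = FakeTensor((32, 64), "float32")
w = FakeTensor((64, 16), "float32")

key_a = make_cache_key(block, (x, w), {"alpha": 0.1, "axis": 1})
key_b = make_cache_key(block, (x, w), {"axis": 1, "alpha": 0.1})   # same kwargs, other order
x_longer = FakeTensor((48, 64), "float32")
key_c = make_cache_key(block, (x_longer, w), {"alpha": 0.1, "axis": 1})

assert key_a == key_b      # kwarg order does not matter once sorted
assert key_a != key_c      # a new shape yields a new key, hence a recompile
```

Sorting the keyword arguments keeps the key stable across call sites that spell the same options in a different order.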
Where It Lives
- File path: autograd/compiler.py.
- Module path: netcl.autograd.compiler.
- Public re-export: top-level netcl.autograd.jit_compile.
- Sibling: autograd.training_compiler for pattern-based fusion of non-elementwise chains.
How It Works
The pipeline is four stages.
1. Trace. When jit_compile enters trace mode, it sets
tracing_context.active = True and runs the user's function with
symbolic Node objects whose .value is a TraceNode. Inside
apply_op, the tracing_context.active branch fires, and a new
TraceNode is appended to the context instead of executing the op.
The result is a DAG of TraceNodes that mirrors the forward graph; no
kernels have been launched.
2. Topological sort. When the function returns, jit_compile does
a depth-first topological sort over the trace.
3. Source generation. For each TraceNode, the compiler calls the
registered AutogradPrimitive.forward and backward to produce the C
expressions. The shared PRIMITIVE_PREAMBLE from
core/kernels/primitives.py is prepended — it defines the inline
ADD, MUL, RELU, SIGMOID, etc. helpers the primitives emit. For
inputs with different shapes, broadcast_index_lines is emitted so the
kernel computes the correct strided index per input.
4. Build and cache. The generated C is built with
cl.Program.build(). The resulting program is cached keyed on
(callable, shapes, dtypes, kwargs). On the next call, the cache
returns the already-built program without re-running the build.
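The first three stages can be sketched without OpenCL. The names `TracingContext`, `TraceNode`, and `apply_op` mirror the description above, but their fields and signatures here are assumptions, as are the `FORWARD` table and the shape of the emitted C lines. Broadcasting, the backward pass, and the build-and-cache stage are left out.

```python
from dataclasses import dataclass, field

@dataclass
class TraceNode:
    op: str                                      # "input", "add", "mul", "relu"
    inputs: list = field(default_factory=list)   # parent TraceNodes
    name: str = ""                               # C variable name

class TracingContext:
    def __init__(self):
        self.active = False
        self.nodes = []

tracing_context = TracingContext()

def apply_op(op, *inputs):
    """Stage 1: while tracing, record the op instead of executing it."""
    assert tracing_context.active, "this sketch only implements trace mode"
    node = TraceNode(op, list(inputs))
    tracing_context.nodes.append(node)
    return node

def trace(fn, *input_names):
    """Run fn on symbolic inputs and return the root of the recorded DAG."""
    tracing_context.active = True
    tracing_context.nodes = []
    try:
        return fn(*[TraceNode("input", name=n) for n in input_names])
    finally:
        tracing_context.active = False

def topo_sort(root):
    """Stage 2: depth-first post-order, so every node follows its inputs."""
    order, seen = [], set()
    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.inputs:
            visit(parent)
        order.append(node)
    visit(root)
    return order

# Hypothetical forward emitters: one C expression per op.
FORWARD = {
    "add": lambda a: f"ADD({a[0]}, {a[1]})",
    "mul": lambda a: f"MUL({a[0]}, {a[1]})",
    "relu": lambda a: f"RELU({a[0]})",
}

def generate_body(root):
    """Stage 3: emit one C statement per node, in topological order."""
    lines, counter = [], 0
    for node in topo_sort(root):
        if node.op == "input":
            lines.append(f"float {node.name} = {node.name}_buf[gid];")
        else:
            node.name = f"t{counter}"
            counter += 1
            expr = FORWARD[node.op]([p.name for p in node.inputs])
            lines.append(f"float {node.name} = {expr};")
    lines.append(f"out[gid] = {root.name};")
    return "\n".join(lines)

root = trace(lambda x, b: apply_op("relu", apply_op("add", x, b)), "x", "b")
print(generate_body(root))
```

In this sketch the printed body would still need a kernel signature and the preamble macros around it before it could be handed to a build step.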
Code Example
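from netcl.autograd.compiler import jit_compile, register_primitive
import netcl.autograd as ag

@jit_compile
def fused_mlp_block(x, w, b):
    # matmul is listed under Performance & Trade-offs as non-fusible in this
    # path, so this function may take the un-fused fallback as written.
    h = ag.add(ag.matmul(x, w), b)
    h = ag.relu(h)
    return ag.add(h, 1.0)

# x, w, b are assumed to be existing netcl tensors.
# First call traces, compiles, runs.
y = fused_mlp_block(x, w, b)
# Second call hits the cache.
y = fused_mlp_block(x, w, b)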
You can register a new primitive and have it participate in
@jit_compile:
register_primitive(
    name="my_op",
    forward=lambda args, attrs: f"MY_OP({args[0]}, {args[1]})",
    backward=lambda args, grad_var, attrs, out_var: [
        f"MUL({grad_var}, {args[1]})",
        f"MUL({grad_var}, {args[0]})",
    ],
    arity=2,
    fusible=True,
)
The Tutorial: Writing a Custom OpenCL Kernel walks through a worked example that registers two primitives and verifies the generated source.
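As a quick check of what the two emitters in the registration above produce, they can be called directly with placeholder variable names. The names `a`, `b`, `g`, and `y` are hypothetical; the names the real source generator passes in are not shown in this article.

```python
# The two emitters from the my_op registration, bound to local names so they
# can be called without importing netcl.
forward = lambda args, attrs: f"MY_OP({args[0]}, {args[1]})"
backward = lambda args, grad_var, attrs, out_var: [
    f"MUL({grad_var}, {args[1]})",
    f"MUL({grad_var}, {args[0]})",
]

# Hypothetical variable names standing in for what a generator would supply.
fwd_expr = forward(["a", "b"], {})
grad_exprs = backward(["a", "b"], "g", {}, "y")

print(fwd_expr)
for expr in grad_exprs:
    print(expr)

# One gradient expression per input, matching arity=2.
assert len(grad_exprs) == 2
```

The backward expressions are those of an elementwise product (the gradient for each input is the upstream gradient times the other input), so `MY_OP` presumably expands to a multiplication. `MY_OP` is not among the preamble helpers named earlier, so a definition for it would have to come from somewhere not covered here.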
Performance & Trade-offs
- Compile cost is amortised: the first call may take 10 ms to 100 ms (the clBuildProgram cost), but subsequent calls in a hot loop are essentially free.
- Shape instability kills the cache: a model whose input shape changes every step (variable-length sequence) will recompile every step. Pad to a fixed length or fall back to the un-fused implementation.
- Non-elementwise ops break fusion: matmul, conv2d, bmm are not fusible in this path. Use TrainingGraphCompiler for those.
- First-step latency: a 100 ms compile on the very first step is often the cause of "the first epoch is slow" complaints. Warm up the model with a fake batch before timing the run.
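The padding advice can be illustrated with a small sketch. `pad_to_length` and `MAX_LEN` are hypothetical helpers, and plain lists stand in for tensors; the point is only that padded inputs collapse to one shape and therefore one cache entry.

```python
def pad_to_length(seq, length, pad_value=0.0):
    """Right-pad seq to a fixed length so every call sees the same shape."""
    if len(seq) > length:
        raise ValueError(f"sequence of length {len(seq)} exceeds {length}")
    return list(seq) + [pad_value] * (length - len(seq))

MAX_LEN = 8  # hypothetical fixed length chosen for the model

batches = [[1.0, 2.0], [3.0] * 5, [4.0] * 8]

# Without padding, each distinct length is a distinct shape, hence a recompile.
raw_shapes = {(len(b),) for b in batches}

# With padding, all batches share one shape and one compiled kernel.
padded_shapes = {(len(pad_to_length(b, MAX_LEN)),) for b in batches}

print(len(raw_shapes), len(padded_shapes))
```

The padded positions still flow through the elementwise chain, so downstream code would typically need to mask or ignore them.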
See also
- JIT Compiler — architecture overview with a larger diagram.
- AutogradPrimitive — the record describing one op to the source generator.
- Tape — the forward graph the JIT traces.
- Tensor — the tensors traced into the chain.
- Tutorial: Custom OpenCL Kernel — how to register your own primitives.
- CompiledGraph — the runtime's view of a compiled chain.