
JIT Compiler

Status: Architecture component in autograd/compiler.py

The netcl JIT compiler is a dynamic, trace-based just-in-time compiler for chains of elementwise ops. When a Python function is decorated with @jit_compile, the first call records the op sequence into a TracingContext, generates OpenCL C source for the entire chain, builds it once through cl.Program.build(), and caches the result. Subsequent calls with the same input shapes, dtypes, and keyword arguments hit the cache and skip the build entirely.

The result is a single OpenCL kernel that does the work in one pass — one launch, one read per input, one write per output, no intermediate device tensors created along the chain. Compared to a naive per-op implementation, this is typically a 2x to 4x speed-up on elementwise-heavy networks (MLP activations, normalization tails, residual adds).
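A rough count of global-memory traffic suggests where the saving comes from. This is a simplified model, not a measurement: it assumes float32 elements, that every separate kernel reads its tensor inputs from and writes its output to device memory, and that nothing is reused through caches. The helper names and the example chain relu(a + b) + 1.0 are illustrative only.

```python
# Simplified traffic model (hypothetical helpers): bytes moved through global
# memory for an elementwise chain, one kernel per op vs. one fused kernel.
# Assumes float32 elements and no cache reuse between kernels.
BYTES_PER_ELEM = 4


def naive_traffic(n, ops):
    """ops: list of (num_tensor_inputs, num_outputs), one entry per op."""
    total = 0
    for num_in, num_out in ops:
        # Each separate kernel reads all its tensor inputs and writes its output.
        total += (num_in + num_out) * n * BYTES_PER_ELEM
    return total


def fused_traffic(n, num_chain_inputs, num_chain_outputs):
    # A fused kernel reads each chain input once and writes each final output
    # once; intermediates stay in registers.
    return (num_chain_inputs + num_chain_outputs) * n * BYTES_PER_ELEM


n = 1_000_000
# relu(a + b) + 1.0 -> add(a, b), relu, add(scalar); the scalar is not a tensor read.
chain = [(2, 1), (1, 1), (1, 1)]
naive = naive_traffic(n, chain)   # 7 tensor passes of n elements
fused = fused_traffic(n, 2, 1)    # 3 tensor passes of n elements
print(naive, fused, round(naive / fused, 2))  # 28000000 12000000 2.33
```

Under these assumptions the fused kernel moves about 2.3x less data for this three-op chain; real speed-ups also depend on launch overhead and on how memory-bound each op actually is.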

Overview

The JIT compiler is the dynamic sibling of the static TrainingGraphCompiler (see autograd/training_compiler.py); it specialises at trace time and produces fused forward and fused backward kernels in one go. It is opt-in: you decorate a function with @jit_compile and the compiler does the rest. If the op sequence contains an op the compiler does not know about, @jit_compile falls back to the un-fused Python implementation, so decorating a function is always safe.
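The fallback rule can be pictured as a decorator that traces the function with a recording op, checks every recorded op name against the set of registered primitives, and otherwise runs the un-fused version. This is a stand-in sketch: KNOWN_OPS, EAGER, jit_sketch, and the "custom" op are hypothetical names, the real compiler records TraceNodes rather than strings, and caching is left out, so this version re-traces on every call. Both branches compute the value eagerly here; only the chosen path is recorded.

```python
import functools

# Hypothetical set of primitives the compiler knows how to fuse.
KNOWN_OPS = {"add", "mul", "relu"}

# Eager (un-fused) implementations on plain floats; "custom" is an unregistered op.
EAGER = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
    "relu": lambda a: max(a, 0.0),
    "custom": lambda a, b: a * b,
}


def eager_op(name, *args):
    return EAGER[name](*args)


def jit_sketch(fn):
    """Trace fn with a recording `op`; fall back to eager if any op is unknown."""

    @functools.wraps(fn)
    def wrapper(*args):
        recorded = []

        def record_op(name, *op_args):
            recorded.append(name)
            return name  # symbolic placeholder: nothing is computed while tracing

        fn(record_op, *args)
        # A real compiler would build one fused kernel on the "fused" branch;
        # this sketch only records which branch was taken.
        wrapper.path = "fused" if set(recorded) <= KNOWN_OPS else "fallback"
        return fn(eager_op, *args)

    wrapper.path = None
    return wrapper


@jit_sketch
def tail(op, x, w, b):
    return op("relu", op("add", op("mul", x, w), b))


@jit_sketch
def with_custom(op, x, w, b):
    return op("relu", op("add", op("custom", x, w), b))


print(tail(2.0, 3.0, -1.0), tail.path)                # 5.0 fused
print(with_custom(2.0, 3.0, -1.0), with_custom.path)  # 5.0 fallback
```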

The cache key is (callable, tuple_of_(shape, dtype), sorted_kwargs), so changing any input's shape or dtype, or any keyword argument, forces a recompile. There are two compile paths:

  • OpenCL path — when the first Node argument's tensor is on a GPU queue. Produces a fused forward kernel and a fused backward kernel.
  • NumPy path — when the first Node argument's tensor is on the CPU backend. Produces a pair of fused NumPy functions. Useful for unit tests that run without pyopencl.
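A minimal sketch of how such a key could be built and how a path could be chosen. FakeTensor, make_cache_key, and choose_path are hypothetical stand-ins rather than netcl names, the device attribute is an assumption, and the sketch simplifies "first Node argument" to "first argument".

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a netcl tensor: only what the key needs."""
    shape: tuple
    dtype: str
    device: str  # assumed attribute: "gpu" or "cpu"


def make_cache_key(fn, args, kwargs):
    # Mirrors the documented key: (callable, tuple of (shape, dtype), sorted kwargs).
    shapes_dtypes = tuple((a.shape, a.dtype) for a in args)
    return (fn, shapes_dtypes, tuple(sorted(kwargs.items())))


def choose_path(args):
    # The first argument's placement decides the backend for the whole chain.
    return "opencl" if args[0].device == "gpu" else "numpy"


def f(x, y):
    pass


a = FakeTensor((32, 128), "float32", "gpu")
b = FakeTensor((128,), "float32", "gpu")
k1 = make_cache_key(f, (a, b), {"alpha": 0.5})
k2 = make_cache_key(f, (a, b), {"alpha": 0.5})
k3 = make_cache_key(f, (FakeTensor((64, 128), "float32", "gpu"), b), {"alpha": 0.5})
print(k1 == k2, k1 == k3, choose_path((a, b)))  # True False opencl
```

Because the key is a plain hashable tuple, it can index an ordinary dict; a batch with 64 rows instead of 32 produces a different key and therefore a separate compile.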

Where It Lives

  • File path: autograd/compiler.py.
  • Module path: netcl.autograd.compiler.
  • Public re-export: netcl.autograd.jit_compile, at the package level.
  • Sibling: autograd.training_compiler for pattern-based fusion of non-elementwise chains.

How It Works

The pipeline has four stages.

1. Trace. When jit_compile enters trace mode, it sets tracing_context.active = True and runs the user's function with symbolic Node objects whose .value is a TraceNode. Inside apply_op, the tracing_context.active branch fires, and a new TraceNode is appended to the context instead of executing the op. The result is a DAG of TraceNodes that mirrors the forward graph; no kernels have been launched.

2. Topological sort. When the function returns, jit_compile does a depth-first topological sort over the trace, so that every TraceNode is emitted after the nodes it reads from.

3. Source generation. For each TraceNode, the compiler calls the registered AutogradPrimitive's forward and backward emitters to produce the C expressions. The shared PRIMITIVE_PREAMBLE from core/kernels/primitives.py is prepended; it defines the inline helpers (ADD, MUL, RELU, SIGMOID, and others) that the emitted expressions call. When inputs have different shapes, broadcast_index_lines is emitted so that the kernel computes the correct strided index for each input.

4. Build and cache. The generated C is built with cl.Program.build(). The resulting program is cached, keyed on (callable, shapes, dtypes, kwargs). On a later call with the same key, the cache returns the already-built program without re-running the build.

Code Example

from netcl.autograd.compiler import jit_compile, register_primitive
import netcl.autograd as ag

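# matmul is not fusible on this path, so it stays outside the traced
# function; only the elementwise tail (bias add, ReLU, +1) is fused.
@jit_compile
def fused_mlp_tail(h, b):
    h = ag.add(h, b)      # bias add
    h = ag.relu(h)
    return ag.add(h, 1.0)

# x, w, b are existing netcl autograd Nodes (creation not shown).
h = ag.matmul(x, w)
# First call traces, compiles, and runs.
y = fused_mlp_tail(h, b)
# Second call with the same shapes and dtypes hits the cache.
y = fused_mlp_tail(h, b)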

You can register a new primitive and have it participate in @jit_compile:

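from netcl.autograd.compiler import register_primitive

# MY_OP is a placeholder, assumed to be a C helper you define yourself; the
# gradients below assume MY_OP(a, b) behaves like a * b.
register_primitive(
    name="my_op",
    # args holds the operands' C expressions; return one C expression.
    forward=lambda args, attrs: f"MY_OP({args[0]}, {args[1]})",
    # Return one gradient expression per operand, in argument order.
    backward=lambda args, grad_var, attrs, out_var: [
        f"MUL({grad_var}, {args[1]})",  # d(a * b)/da = b
        f"MUL({grad_var}, {args[0]})",  # d(a * b)/db = a
    ],
    arity=2,
    fusible=True,
)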

The Tutorial: Writing a Custom OpenCL Kernel walks through a worked example that registers two primitives and verifies the generated source.
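To see what the emitters return, here is a stand-in registry that accepts the same keyword arguments as the register_primitive call and invokes the emitters the way a code generator plausibly would. REGISTRY, register_primitive_sketch, emit_forward, emit_backward, and the variable names a_v, b_v, g_out, and out_v are all hypothetical; netcl's actual calling convention may differ.

```python
# Hypothetical stand-in registry; netcl's register_primitive keeps its own records.
REGISTRY = {}


def register_primitive_sketch(name, forward, backward, arity, fusible):
    """Accepts the same keyword arguments as the register_primitive call."""
    REGISTRY[name] = {"forward": forward, "backward": backward,
                      "arity": arity, "fusible": fusible}


register_primitive_sketch(
    name="my_op",
    forward=lambda args, attrs: f"MY_OP({args[0]}, {args[1]})",
    backward=lambda args, grad_var, attrs, out_var: [
        f"MUL({grad_var}, {args[1]})",
        f"MUL({grad_var}, {args[0]})",
    ],
    arity=2,
    fusible=True,
)


def emit_forward(name, args, attrs=None):
    prim = REGISTRY[name]
    if len(args) != prim["arity"]:
        raise ValueError(f"{name} expects {prim['arity']} operands, got {len(args)}")
    return prim["forward"](args, attrs or {})


def emit_backward(name, args, grad_var, out_var, attrs=None):
    # Returns one gradient expression per operand, in operand order.
    return REGISTRY[name]["backward"](args, grad_var, attrs or {}, out_var)


fwd = emit_forward("my_op", ["a_v", "b_v"])
grads = emit_backward("my_op", ["a_v", "b_v"], grad_var="g_out", out_var="out_v")
print(fwd)    # MY_OP(a_v, b_v)
print(grads)  # ['MUL(g_out, b_v)', 'MUL(g_out, a_v)']
```

The emitters only manipulate strings, which is why a primitive can be checked without a GPU: the C expressions can be asserted on before any kernel is built.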

Performance & Trade-offs

  • Compile cost is amortised: the first call may take 10 ms to 100 ms (the clBuildProgram cost), but subsequent calls with the same cache key pay only for the cache lookup and the kernel launch.
  • Shape instability kills the cache: a model whose input shape changes every step (variable-length sequence) will recompile every step. Pad to a fixed length or fall back to the un-fused implementation.
  • Non-elementwise ops break fusion: matmul, conv2d, and bmm are not fusible on this path. Use TrainingGraphCompiler for those.
  • First-step latency: a 100 ms compile on the very first step is often the cause of "the first epoch is slow" complaints. Warm up the model with a dummy batch of the same shape and dtype as the real data before timing the run.
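Padding is the simplest way to keep the cache key stable. The sketch below uses plain Python lists as stand-ins for tensors: it pads variable-length sequences to one fixed length and counts how many distinct shape keys, and hence compiles, a run would produce with and without padding. MAX_LEN and pad_to are hypothetical names, and the padding value is assumed to be something the model can ignore, for example through a mask.

```python
MAX_LEN = 16  # hypothetical fixed sequence length


def pad_to(seq, length, pad_value=0.0):
    """Right-pad a sequence so every batch has the same shape."""
    if len(seq) > length:
        raise ValueError(f"sequence of length {len(seq)} exceeds {length}")
    return list(seq) + [pad_value] * (length - len(seq))


batches = [[1.0] * n for n in (5, 9, 12, 7, 16, 9)]

# Each distinct shape would be a separate cache entry, i.e. a separate compile.
unpadded_shapes = {(len(b),) for b in batches}
padded_shapes = {(len(pad_to(b, MAX_LEN)),) for b in batches}
print(len(unpadded_shapes), len(padded_shapes))  # 5 1

# A warm-up batch with the same padded shape populates the cache before timing.
warmup_batch = pad_to([], MAX_LEN)
```

Running one call on warmup_batch before the timed loop moves the single compile out of the measured region.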

See also