Tutorial: Writing a Custom OpenCL Kernel

Sometimes you need an op that netcl does not ship: a novel activation, a domain-specific reduction, or a fused variant of two existing ops. In this tutorial we drop down to the OpenCL layer. We read the device capabilities via DeviceManager, write the OpenCL C source with the PRIMITIVE_PREAMBLE macro set, wrap it in a KernelSpec, pick launch sizes with the WorkGroupTuner heuristic, and integrate the kernel with autograd via apply_op. We then take the optional extra step of registering the op with the JIT Compiler so the elementwise JIT fuses it into chains of other ops.

By the end of the tutorial you will be able to add a brand-new, fully-differentiable op to netcl in about 60 lines of Python.

Prerequisites

You do not need to have read the JIT Compiler page in detail; the relevant pieces are explained inline.

What You'll Build

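A custom tanh activation, out = tanh(x). netcl already ships ag.tanh; we reimplement it here because its derivative is simple and easy to check by hand. The custom op is implemented as: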

  1. An OpenCL C kernel generated by KernelSpec, built once with pyopencl.Program and reused for every launch.
  2. A small Python wrapper my_tanh(x: Tensor) -> Tensor that enqueues the kernel and returns the result.
  3. An autograd node tanh_node(a: Node) -> Node that registers the forward + backward closure via apply_op.
  4. (Optional) An AutogradPrimitive registration with the JIT Compiler so tanh participates in elementwise fusion.

When to Write a Custom Kernel

Decision callout. You do not need a custom kernel for an op that netcl already ships. The built-in ag.tanh, ag.silu, ag.mish, ag.gelu, ag.sigmoid, and the whole ops catalog are already registered with the JIT Compiler and the autograd machinery. Reach for a custom kernel only when:

  1. netcl does not ship the op (a domain-specific activation, a novel normalization, a custom loss component).
  2. You need a kernel that fuses several built-ins in a way the JIT Compiler cannot — for example, a reduction that reads a per-channel parameter that the JIT cannot trace symbolically.
  3. You are porting an op from a paper or a research repo and want a faithful one-kernel reproduction before considering a higher-level rewrite.

If you are unsure, try expressing the op as a composition of stock elementwise_binary / elementwise_unary calls first. The JIT Compiler will fuse them, and you save yourself the autograd plumbing.

Step-by-Step

1. Probe the Device

DeviceManager is the entry point for everything device-related. The manager singleton is constructed at import time and is what most user code reaches for. We also want the device's OpenCL capability profile so we can branch on cl_khr_fp16 if the kernel needs half-precision.

from netcl.core.device import manager
from netcl.core.capabilities import device_profile

dev = manager.default("auto")                # DeviceHandle (platform, device, ctx, queue)
prof = device_profile(dev.queue.device)      # frozen DeviceProfile
print(dev.platform_name, dev.device_name)
print("fp16:", prof.has_fp16, "subgroups:", prof.has_subgroups)
print("local_mem_kb:", prof.local_mem_size // 1024)

The returned DeviceHandle carries the OpenCL cl.Context and cl.CommandQueue that the rest of the code will use. The DeviceProfile is the conservative, fork-safe summary of what the device can do.
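If the kernel should run in half precision when the device supports it, the branch on the profile can stay entirely on the host. The sketch below assumes only a has_fp16 attribute like the one printed above; ProfileStub and choose_storage_type are hypothetical names used to keep the example self-contained. The pragma line is the standard OpenCL C way to enable cl_khr_fp16 arithmetic.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ProfileStub:
    """Hypothetical stand-in for DeviceProfile; only the field this sketch reads."""
    has_fp16: bool


# Standard OpenCL C pragma that enables half-precision arithmetic.
FP16_PRAGMA = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"


def choose_storage_type(prof, want_half: bool) -> Tuple[str, str]:
    """Return (C scalar type, extra preamble) for the kernel source.

    Use half only when the caller asks for it AND the device reports
    cl_khr_fp16; otherwise fall back to float with no pragma.
    """
    if want_half and prof.has_fp16:
        return "half", FP16_PRAGMA
    return "float", ""


# A device without fp16 falls back to float even when half is requested.
ctype, pragma = choose_storage_type(ProfileStub(has_fp16=False), want_half=True)
params = [f"__global const {ctype}* x", f"__global {ctype}* y"]
```

The returned type name can be spliced into the KernelSpec parameter strings, and, assuming the preamble is a plain string, the pragma can be prepended to it.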

2. Write the OpenCL C Source

KernelSpec is the declarative codegen layer. You give it the kernel name, the parameter list, an optional preamble, and a C body; KernelSpec.to_source() returns a complete __kernel function.

from netcl.core.kernels.primitives import KernelSpec, PRIMITIVE_PREAMBLE

src = KernelSpec(
    name="tanh_kernel",
    params=["__global const float* x", "__global float* y", "const int n"],
    preamble=PRIMITIVE_PREAMBLE,
    body="""
        int gid = get_global_id(0);
        if (gid >= n) return;   /* skip padding work-items past the end */
        y[gid] = tanh(x[gid]);  /* OpenCL C's tanh is overloaded for float */
    """,
).to_source()
print(src)

PRIMITIVE_PREAMBLE is a small set of C macros (LOAD, STORE, ADD, SUB, MUL, DIV, RELU) that the JIT Compiler emits inside fused kernels. It costs nothing to include in a standalone kernel and keeps the code self-consistent with the rest of the generated bodies.

The output is a complete OpenCL C string ready to hand to pyopencl.Program.

3. Compile and Cache the Kernel

We compile the source once with pyopencl.Program(ctx, src).build(), extract the kernel object, and set the kernel arguments on every launch.

import numpy as np
import pyopencl as cl
from netcl.core.kernels.primitives import WorkGroupTuner

tuner = WorkGroupTuner()                          # heuristic local/global sizes
local_size = tuner.suggest_local_size_1d()        # e.g. 256
print("local_size:", local_size)

# Build once and reuse the kernel object for every launch.
# dev.queue.context is the cl.Context that owns the DeviceHandle's queue.
prog = cl.Program(dev.queue.context, src).build()
kernel = prog.tanh_kernel

# Per-launch: bind the input/output buffers and the element count, then enqueue.
def launch_tanh(x_buf, y_buf, n):
    gsz = ((n + local_size - 1) // local_size) * local_size   # round up to a multiple of local_size
    kernel.set_arg(0, x_buf)
    kernel.set_arg(1, y_buf)
    kernel.set_arg(2, np.int32(n))                           # matches "const int n" in the kernel
    cl.enqueue_nd_range_kernel(dev.queue, kernel, (gsz,), (local_size,))

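WorkGroupTuner.suggest_local_size_1d() returns a sensible default for a one-D elementwise kernel (256 is a good number for most GPUs). Rounding n up to the next multiple of the local size (gsz) is the canonical launch pattern, because OpenCL 1.x requires the global size to be divisible by the local size. It launches at least n work-items, not exactly n, so the kernel must bounds-check gid and return early for the padding work-items past the end of the buffer.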
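The launch arithmetic in launch_tanh can be checked on the host without a device. The sketch below mirrors the gsz formula; round_up_global_size and padding_work_items are illustrative helper names, not netcl functions.

```python
def round_up_global_size(n: int, local_size: int) -> int:
    """Smallest multiple of local_size that is >= n (the gsz formula in launch_tanh)."""
    if n <= 0 or local_size <= 0:
        raise ValueError("n and local_size must be positive")
    return ((n + local_size - 1) // local_size) * local_size


def padding_work_items(n: int, local_size: int) -> int:
    """Work-items launched past the last element; the kernel must skip them."""
    return round_up_global_size(n, local_size) - n


# Example: 1000 elements with a local size of 256.
gsz = round_up_global_size(1000, 256)
extra = padding_work_items(1000, 256)
print(gsz, extra)  # 1024 24
```

With n = 1000, four work-groups of 256 are launched and 24 work-items fall past the end of the buffer, which is why the kernel needs a bounds check on gid.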

4. Wrap as a netcl Tensor Function

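The wrapper takes a Tensor, allocates an output with the same shape and dtype, enqueues the kernel, and returns the output. Because the kernel indexes plain float buffers linearly, the wrapper assumes a contiguous float32 input. It is the "Python-shaped" entry point the autograd wrapper will call.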

from netcl.core.tensor import Tensor

def my_tanh(x: Tensor) -> Tensor:
    """Elementwise tanh with a custom OpenCL kernel."""
    y = Tensor.empty_like(x)
    launch_tanh(x.buffer, y.buffer, x.size)
    return y

The function is a plain Python callable — no Tape, no apply_op. It can be used inside or outside a Tape; the autograd wrapper in the next section is what makes it differentiable.

5. Integrate with Autograd & Tape via apply_op

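The autograd integration has two pieces: a forward closure that calls my_tanh, and a grad_fn closure that produces the gradient w.r.t. each parent. The derivative of tanh is d/dx tanh(x) = 1 - tanh(x)^2, and by the chain rule the backward multiplies the upstream gradient by it. The backward therefore needs the forward output y = tanh(x); here grad_fn recomputes it from the input, which is cheap for tanh, instead of keeping it alive.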
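For reference, the derivative follows from the quotient rule and the identity cosh^2 x - sinh^2 x = 1. The second equation is the chain-rule product that grad_fn computes, with L the scalar loss and dL/dy the upstream gradient grad_out.

```latex
\frac{d}{dx}\tanh x
  = \frac{d}{dx}\,\frac{\sinh x}{\cosh x}
  = \frac{\cosh x \cdot \cosh x - \sinh x \cdot \sinh x}{\cosh^2 x}
  = \frac{1}{\cosh^2 x}
  = 1 - \tanh^2 x,
\qquad
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\,\bigl(1 - y^2\bigr),
\quad y = \tanh x .
```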

import netcl.autograd as ag
from netcl.autograd.engine import apply_op
from netcl.ops.elementwise import elementwise_binary

def tanh_node(a: ag.Node, tape=None) -> ag.Node:
    def forward(x_t):
        return my_tanh(x_t)              # uses the kernel from §3

    def grad_fn(grad_out):
        # d/dx tanh(x) = 1 - tanh(x)^2
        y_t = forward(a.value)          # recompute forward (cheap for tanh)
        sq  = elementwise_binary(y_t, y_t, expression="MUL(v0, v1)")          # y^2
        one_minus_sq = elementwise_binary(sq, sq, expression="SUB(1.0f, v0)")  # 1 - y^2 (v1 unused)
        return [elementwise_binary(grad_out, one_minus_sq,
                                   expression="MUL(v0, v1)")]

    return apply_op(forward, grad_fn, a, tape=tape, op_name="tanh")

What apply_op does. The apply_op entry point is the bridge between a Python forward and a Python backward. Its real signature (per autograd/engine.py) is apply_op(fn, grad_fn, *args, tape=None, op_name=None, attrs=None). When called, it checks three cases in order and takes the first that applies:

  1. JIT trace bypass. If the JIT Compiler is active (apply_op is called while a function decorated with @jit_compile is being traced), it returns a Node whose value is a TraceNode placeholder; no kernel is launched.
  2. Grad off. If is_grad_enabled() is False, it just calls fn(...) and returns the raw Tensor. No Node is built.
  3. Normal path. Runs the forward, builds a Node with the grad_fn, the parents, and the op_name, then calls tape.record(node) if a Tape is in scope.

The grad_fn you pass receives a single argument — the upstream gradient — and must return one Tensor per parent in the order the parents were passed to apply_op.
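The three paths can be modeled in a few lines of plain Python. This is a simplified sketch, not netcl's engine: ToyNode, ToyTape, toy_apply_op and the explicit tracing/grad_enabled keyword flags are hypothetical, standing in for the real Node, Tape, and global JIT/grad state.

```python
import math
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional


@dataclass(eq=False)
class ToyNode:
    """Hypothetical stand-in for an autograd Node (not netcl's class)."""
    value: Any
    parents: List["ToyNode"] = field(default_factory=list)
    grad_fn: Optional[Callable[[Any], List[Any]]] = None
    op_name: Optional[str] = None


class ToyTape:
    """Hypothetical tape that only remembers recorded nodes in order."""
    def __init__(self) -> None:
        self.nodes: List[ToyNode] = []

    def record(self, node: ToyNode) -> None:
        self.nodes.append(node)


TRACE_PLACEHOLDER = object()  # stands in for a TraceNode value


def toy_apply_op(fn, grad_fn, *args, tape=None, op_name=None,
                 tracing=False, grad_enabled=True):
    """Simplified model of the three apply_op paths, checked in order."""
    # 1) JIT trace bypass: return a placeholder Node and run nothing.
    if tracing:
        return ToyNode(value=TRACE_PLACEHOLDER, parents=list(args), op_name=op_name)
    # Unwrap Node arguments to their raw values for the forward.
    raw = [a.value if isinstance(a, ToyNode) else a for a in args]
    # 2) Grad off: compute and return the raw result, no Node.
    if not grad_enabled:
        return fn(*raw)
    # 3) Normal path: compute, wrap in a Node, record if a tape is in scope.
    node = ToyNode(value=fn(*raw),
                   parents=[a for a in args if isinstance(a, ToyNode)],
                   grad_fn=grad_fn, op_name=op_name)
    if tape is not None:
        tape.record(node)
    return node


# Usage: the normal path returns a Node and records it on the tape.
x = ToyNode(0.5)
tape = ToyTape()
y = toy_apply_op(math.tanh, lambda g: [g * (1.0 - math.tanh(0.5) ** 2)],
                 x, tape=tape, op_name="tanh")
```

Note how grad_fn returns a list with one entry per parent, matching the contract described above.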

6. Smoke Test

The minimal sanity check is a forward + backward on a small input, with a finite-difference check to confirm the gradient is correct.

import numpy as np
from netcl.core.device import manager
from netcl.core.tensor import Tensor
import netcl.autograd as ag

dev = manager.default("auto")
q   = dev.queue

# 1) Forward + backward sanity check.
x_t = Tensor.from_host(q, np.linspace(-2, 2, 8, dtype=np.float32))
x_n = ag.tensor(x_t, requires_grad=True)
y_n = tanh_node(x_n)
ag.Tape().backward(ag.sum(y_n))
print("y:", y_n.value.to_host())
print("dy/dx:", x_n.value.grad.to_host())     # analytic, 1 - tanh(x)^2

# 2) Finite-difference numerical check, element by element, at the same points.
def f(x_arr):
    x_t = Tensor.from_host(q, x_arr.astype(np.float32))
    x_n = ag.tensor(x_t, requires_grad=True)
    y_n = tanh_node(x_n)
    return float(ag.sum(y_n).value.to_host())

eps = 1e-3                                       # fp32: a much smaller eps lets rounding error dominate
x0 = np.linspace(-2, 2, 8, dtype=np.float32)     # same points as the analytic gradient above
analytic = x_n.value.grad.to_host()              # x_n here is the Node from step 1
for i in range(8):
    xp = x0.copy(); xp[i] += eps
    xm = x0.copy(); xm[i] -= eps
    num = (f(xp) - f(xm)) / (2 * eps)
    print(f"i={i}: analytic={analytic[i]:.6f}  numeric={num:.6f}")

The finite-difference check is the same pattern detect_anomaly uses internally; running it by hand on a single element is the fastest way to catch a sign error in the backward.
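Because the expected gradient has a closed form, the same check can also be run in float64 on the host as a reference, which helps separate a wrong backward formula from float32 rounding noise in the device version. The sketch below is plain Python with math.tanh standing in for the kernel; the helper names are illustrative.

```python
import math
from typing import Callable, List, Sequence


def central_difference_grad(f: Callable[[List[float]], float],
                            x: Sequence[float], eps: float = 1e-5) -> List[float]:
    """Numerical gradient of a scalar function f at x, one coordinate at a time."""
    grad = []
    for i in range(len(x)):
        xp = list(x)
        xp[i] += eps
        xm = list(x)
        xm[i] -= eps
        grad.append((f(xp) - f(xm)) / (2.0 * eps))
    return grad


def sum_tanh(x: List[float]) -> float:
    """Same scalar loss as the smoke test: sum(tanh(x))."""
    return sum(math.tanh(v) for v in x)


def analytic_grad(x: Sequence[float]) -> List[float]:
    """d/dx_i sum(tanh(x)) = 1 - tanh(x_i)^2."""
    return [1.0 - math.tanh(v) ** 2 for v in x]


def max_abs_error(a: Sequence[float], b: Sequence[float]) -> float:
    return max(abs(p - q) for p, q in zip(a, b))


xs = [-2.0 + 4.0 * i / 7 for i in range(8)]   # same points as np.linspace(-2, 2, 8)
numeric = central_difference_grad(sum_tanh, xs)
good = max_abs_error(numeric, analytic_grad(xs))                 # tiny: formula agrees
flipped = max_abs_error(numeric, [-g for g in analytic_grad(xs)])  # large: sign error caught
print(good, flipped)
```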

7. Optional: Register with the JIT Compiler

Registering the op makes the JIT Compiler recognize op_name="tanh" and fuse tanh into chains of other elementwise ops. The cost is two small functions and one register_primitive call.

from netcl.autograd.compiler import register_primitive

def tanh_fwd(args, attrs):
    # args == ["v0"]; return a single C expression for the output.
    return f"tanh({args[0]})"

def tanh_bwd(args, grad_var, attrs, out_var):
    # out_var is the C name of the local copy of the forward output.
    return [f"({grad_var} * (1.0f - {out_var} * {out_var}))"]

register_primitive("tanh", tanh_fwd, tanh_bwd, arity=1, fusible=True)

The two callables are tiny code generators:

  • forward(args, attrs) -> str — given the C names of the input variables and the scalar attributes, return one C expression for the output.
  • backward(args, grad_var, attrs, out_var) -> List[str] — given the inputs, the upstream-gradient C name, and the local C name of the forward output, return one C expression per parent.
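Because the two generators are plain functions over strings, they can be exercised without a device. The sketch below copies them from above and adds a toy fuser that nests forward expressions the way a fused elementwise kernel body would; fuse_forward and scale_fwd are hypothetical illustrations, not netcl's JIT implementation.

```python
def tanh_fwd(args, attrs):
    # Same generator as above: one C expression for the output.
    return f"tanh({args[0]})"


def tanh_bwd(args, grad_var, attrs, out_var):
    # Same generator as above: one C expression per parent.
    return [f"({grad_var} * (1.0f - {out_var} * {out_var}))"]


def scale_fwd(args, attrs):
    # Hypothetical second primitive for illustration: y = x * k.
    return f"({args[0]} * {float(attrs['k'])}f)"


def fuse_forward(chain, in_var="v0"):
    """Toy fuser: feed each primitive's expression into the next one.

    chain is a list of (forward_generator, attrs) pairs applied left to right;
    the result is a single C expression, i.e. one kernel body instead of one
    kernel launch per op.
    """
    expr = in_var
    for fwd, attrs in chain:
        expr = fwd([expr], attrs)
    return expr


fused = fuse_forward([(scale_fwd, {"k": 2.0}), (tanh_fwd, {})])
print(fused)                               # tanh((v0 * 2.0f))
print(tanh_bwd(["v0"], "g", {}, "y")[0])   # (g * (1.0f - y * y))
```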

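After register_primitive("tanh", ...), you can decorate a function that calls tanh_node with @jit_compile, and a chain of fusible elementwise ops around it fuses into a single forward kernel and a single backward kernel. If you skip this step, functions decorated with @jit_compile that hit tanh silently fall back to the un-fused implementation: safe, but slower. Because the built-in ag.tanh is already registered with the JIT Compiler, a genuinely new op should use an op_name that no built-in uses, and pass the same string to both apply_op and register_primitive; reusing a built-in name may collide with the existing registration.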

Performance Checklist

Performance callout. A few rules of thumb that turn a "correct" custom kernel into a "fast" one:

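  1. Vectorize loads and stores with float4. For elementwise kernels over fp32, read and write float4 instead of float: either through a __global float4* pointer, which requires the buffer to be 16-byte aligned, or with vload4/vstore4, which need only float alignment. Either way, handle the n % 4 leftover elements separately. The vectorized variant in the Tensor Backend cuts the runtime of bandwidth-bound kernels roughly in half on most GPUs.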
  2. Tune the workgroup size with WorkGroupTuner. 256 is a good default for one-D elementwise; 64–128 is usually better for two-D reductions. Profile before committing.
  3. Use -cl-fast-relaxed-math where safe. It lets the compiler replace functions such as tanh with faster, less accurate implementations, re-associate floating-point arithmetic, and assume that no NaN or infinity values occur. Safe for activations and for normalization denominators; not safe for reductions that need strict associativity (e.g. some distance metrics). In pyopencl, pass it at build time: cl.Program(ctx, src).build(options=["-cl-fast-relaxed-math"]).
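  4. Consider __attribute__((vec_type_hint(float4))). In the OpenCL C specification this attribute is only a hint describing the kernel's computational width, which the compiler may use when deciding how to auto-vectorize across work-items; implementations are free to ignore it. Treat it as an optional tweak for explicitly vectorized kernels and measure its effect rather than relying on it.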
  5. Prefer register_primitive(..., fusible=True) for elementwise chains. The JIT Compiler can fold a chain of elementwise primitives into one kernel; without fusible=True it will not, and you lose the biggest perf win of the Tensor Backend.
  6. Reuse buffers through the BufferPool. Allocate the output with Tensor.empty_like(x) (or a BufferHandle lease) so the device buffer is reused across calls. The Tensor Backend page explains the pool sizing.
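  7. Coalesce memory access. Make sure the body indexes linearly with get_global_id(0), so that adjacent work-items touch adjacent addresses and the hardware can combine their loads into a few wide memory transactions. Strided access patterns (e.g. x[get_global_id(0) * stride]) break this and are one of the leading causes of slow kernels on AMD and NVIDIA GPUs.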
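A float4 version of the tanh kernel, kept as a plain OpenCL C string, might look like the sketch below. It is an illustration under assumptions rather than the Tensor Backend's kernel: the kernel name tanh4_kernel and the helpers vec4_work_items and vec4_global_size are hypothetical. It uses the OpenCL built-ins vload4/vstore4, which need only float alignment, and routes the n % 4 leftover elements to extra work-items.

```python
# Illustrative float4 variant of the tanh kernel (hypothetical name tanh4_kernel).
VEC4_TANH_SOURCE = r"""
__kernel void tanh4_kernel(__global const float* x,
                           __global float* y,
                           const int n)
{
    int gid  = get_global_id(0);
    int nvec = n / 4;                      /* full float4 chunks */
    if (gid < nvec) {
        float4 v = vload4(gid, x);         /* reads x[4*gid .. 4*gid+3] */
        vstore4(tanh(v), gid, y);
    } else {
        int i = 4 * nvec + (gid - nvec);   /* scalar tail element */
        if (i < n) {                       /* also skips padding work-items */
            y[i] = tanh(x[i]);
        }
    }
}
"""


def vec4_work_items(n: int) -> int:
    """Work-items needed: one per full float4 chunk plus one per leftover element."""
    return n // 4 + n % 4


def vec4_global_size(n: int, local_size: int) -> int:
    """Round vec4_work_items(n) up to a multiple of local_size."""
    needed = vec4_work_items(n)
    return ((needed + local_size - 1) // local_size) * local_size
```

The source can be built with pyopencl.Program like the scalar version and launched with vec4_global_size(n, local_size) work-items; for large n that is about a quarter of the scalar launch, with each work-item moving 16 bytes per load.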

Troubleshooting

  • pyopencl raises an error from clBuildProgram with a cryptic message. Keep a reference to the cl.Program before calling .build() so you can read the full build log with prog.get_build_info(dev.queue.device, cl.program_build_info.LOG). A common cause is calling a function that is not an OpenCL C built-in, for example the C99 name expf: OpenCL C's exp is already overloaded for float, so use exp (or native_exp / half_exp when reduced precision is acceptable).
  • The kernel builds but produces the wrong result. Print src and walk through the generated code. The second most common cause is the launch shape: the global size must be a multiple of the local size, the rounding must happen before the enqueue, and the kernel must bounds-check gid against n so the padding work-items neither read nor write past the end of the buffer.
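  • detect_anomaly reports a NaN in the gradient. The usual cause is a backward formula that is numerically unstable for some inputs, for example one that divides by a forward value that can be 0, or evaluates exp of a large argument that overflows to inf. Write the derivative in terms of the forward output where possible: 1 - y*y is safe for tanh, and for sigmoid use the standard sigmoid(x) * (1 - sigmoid(x)) instead of exp(-x) / (1 + exp(-x))^2.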
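  • The custom op is correct but @jit_compile is silent about it. The decorator silently falls back to the un-fused implementation when an op is not registered. Call register_primitive for your op_name (the same string you pass to apply_op), then confirm the cache hit with a print inside the forward closure: if it prints on every call of the decorated function, the op is still running through the un-fused fallback.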
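The sigmoid advice above can be checked on the host. The sketch below is plain Python with illustrative function names: a stable sigmoid that never evaluates exp of a large positive number, its derivative written in terms of the forward output, and the naive exp form for comparison. In Python the naive form raises OverflowError for large negative x; in OpenCL C the same exp returns inf, and inf / inf produces the NaN that detect_anomaly reports.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))   # -x <= 0, so exp cannot overflow
    e = math.exp(x)                         # x < 0, so e is in (0, 1)
    return e / (1.0 + e)


def sigmoid_grad_from_output(y: float) -> float:
    """d/dx sigmoid(x) written in terms of the forward output y = sigmoid(x)."""
    return y * (1.0 - y)


def naive_sigmoid_grad(x: float) -> float:
    """exp(-x) / (1 + exp(-x))**2: overflows for large negative x."""
    e = math.exp(-x)
    return e / (1.0 + e) ** 2


print(sigmoid_grad_from_output(sigmoid(-1000.0)))   # 0.0, no overflow
```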

See also