netcl.distributed — Collectives & Data-Parallel

The distributed API is a host-based minimum-viable-product for multi-device training on a single node. It targets a single workstation with N OpenCL devices (typically 1-4 GPUs on a developer box), and is not a cluster runtime — there is no RDMA, no MPI, no NCCL. The transport is numpy round-trips through host memory plus, when all participants share a cl.Context, an optional device-side pairwise sum.

Note — Top-level re-exports. netcl/distributed/__init__.py re-exports all_reduce, broadcast, scatter, gather, DeviceManager, shard_batch, replicate_params, sync_grads, broadcast_params, prepare_replicas and data_parallel_step, so each of these names is reachable as dist.<name>(...) after from netcl import distributed as dist. all_reduce_p2p is not in that list; import it from netcl.distributed.collectives.

Public API

Symbol | Path | Purpose
--- | --- | ---
all_reduce(tensors, op="sum", overlap=False) | distributed/collectives.py | In-place sum or mean (op="sum" or op="mean") across a list of Tensor replicas
broadcast(tensor_list, root=0) | distributed/collectives.py | Copy the root-th Tensor onto every other one
scatter(tensor, chunks) | distributed/collectives.py | Split a Tensor along axis 0 into chunks pieces
gather(tensors) | distributed/collectives.py | Concatenate a list of Tensors along axis 0
all_reduce_p2p(tensors, op="sum") | distributed/collectives.py | Device-side pairwise sum; falls back to host all_reduce
shard_batch(x, num_shards) | distributed/data_parallel.py | np.array_split along axis 0
replicate_params(params, queues) | distributed/data_parallel.py | Copy parameter Tensors onto each queue
sync_grads(param_replicas) | distributed/data_parallel.py | Mean-reduce every .grad across replicas
broadcast_params(src_params, dst_param_groups, root=0) | distributed/data_parallel.py | Refresh every replica's parameters from the source group
prepare_replicas(params, queues) | distributed/trainer.py | Convenience alias for replicate_params
data_parallel_step(forward_fn, param_replicas, optimizers, batch, *, pre_shard_hook=None, post_sync_hook=None, post_step_hook=None) | distributed/trainer.py | One full forward / backward / sync / step iteration per device
DeviceManager | distributed/device_manager.py | Multi-queue manager; re-imports core.DeviceManager

Collectives

all_reduce, broadcast, scatter and gather live in distributed/collectives.py and run entirely through host memory. When all participants share a cl.Context, all_reduce_p2p, from the same module, can be used instead for a faster device-side sum.

from netcl.distributed import all_reduce, broadcast

replicas = [...]                          # one Tensor per device, all the same shape
all_reduce(replicas, op="mean")           # in-place mean across the list
broadcast(replicas, root=0)               # copy replicas[0] onto the others (takes a list)

all_reduce

all_reduce(tensors, op="sum", overlap=False)

op is either "sum" or "mean". The implementation pulls every Tensor to the host (t.to_host()), reduces the copies with np.sum or np.mean, and writes the result back into every Tensor. With overlap=True the copies are issued from parallel threading.Thread workers, which pays off on multi-GPU boxes where the host-to-device (H2D) copy is the bottleneck. When participants live in different cl.Contexts (e.g. an NVIDIA card and an Intel iGPU), a warning is printed and the host fallback is used.
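A minimal host-side sketch of that reduce-and-write-back loop. It assumes a stand-in HostTensor class (hypothetical, not netcl's Tensor) whose to_host() returns a NumPy copy and whose write() method (also hypothetical) plays the role of the H2D copy; with overlap=True the write-backs run on one threading.Thread per replica.

```python
import threading

import numpy as np


class HostTensor:
    """Hypothetical stand-in for a device Tensor; not netcl's real class."""

    def __init__(self, data):
        self._data = np.asarray(data, dtype=np.float32)

    def to_host(self):
        # D2H copy: hand back a private NumPy copy of the "device" contents.
        return self._data.copy()

    def write(self, array):
        # H2D copy: overwrite the "device" contents in place (hypothetical method).
        self._data[...] = array


def host_all_reduce(tensors, op="sum", overlap=False):
    if op not in ("sum", "mean"):
        raise ValueError(f"unsupported op: {op!r}")
    # Pull every replica to host and stack them along a new leading axis.
    stacked = np.stack([t.to_host() for t in tensors])
    reduced = np.sum(stacked, axis=0) if op == "sum" else np.mean(stacked, axis=0)
    if overlap:
        # Issue the write-backs concurrently, one thread per replica.
        threads = [threading.Thread(target=t.write, args=(reduced,)) for t in tensors]
        for th in threads:
            th.start()
        for th in threads:
            th.join()
    else:
        for t in tensors:
            t.write(reduced)
    return tensors


replicas = [HostTensor([1.0, 2.0]), HostTensor([3.0, 6.0])]
host_all_reduce(replicas, op="mean", overlap=True)
print([r.to_host().tolist() for r in replicas])  # [[2.0, 4.0], [2.0, 4.0]]
```

The reduction itself is computed once on the host; only the per-replica copies are parallel, which is why overlap only helps when those copies dominate.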

all_reduce_p2p

all_reduce_p2p(tensors, op="sum")

Falls back to all_reduce if pyopencl is missing or if the participants span more than one cl.Context. Otherwise it JIT-compiles a small ar_sum kernel and reduces the Tensors pairwise, in place on the device. Use it when host round-trips dominate the cost of the reduction.
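This page does not spell out the pairing schedule the ar_sum kernel uses. The sketch below assumes a common tree-style scheme and runs it on host NumPy arrays purely to illustrate the pattern: in each round, buffer i accumulates buffer i + stride, so N buffers are summed in ceil(log2 N) rounds, after which the total in buffer 0 is copied onto every participant. The function name is hypothetical and this is not netcl's kernel.

```python
import numpy as np


def pairwise_sum_inplace(buffers):
    """Tree-style pairwise sum (assumed schedule, host-side illustration)."""
    n = len(buffers)
    stride = 1
    rounds = 0
    while stride < n:
        # On a device, each addition in this round could be one ar_sum launch.
        for i in range(0, n - stride, 2 * stride):
            buffers[i] += buffers[i + stride]
        stride *= 2
        rounds += 1
    # buffers[0] now holds the total; copy it onto every other participant.
    for i in range(1, n):
        buffers[i][...] = buffers[0]
    return rounds


bufs = [np.full(3, float(k)) for k in range(5)]
rounds = pairwise_sum_inplace(bufs)
print(rounds, bufs[4].tolist())  # 3 [10.0, 10.0, 10.0]
```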

broadcast

broadcast(tensor_list, root=0)

Note the calling convention: a list of Tensors, not a single Tensor. The function returns the same list with every entry (except root) replaced by the contents of tensor_list[root].
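The list-in, list-out convention can be mimicked on host arrays. This sketch uses plain NumPy arrays as stand-ins for Tensors (an assumption made so it runs without a device) and a hypothetical host_broadcast function that overwrites every non-root entry in place with the root's contents, then returns the same list object.

```python
import numpy as np


def host_broadcast(tensor_list, root=0):
    """Host-side illustration of broadcast(tensor_list, root=0) on NumPy stand-ins."""
    src = tensor_list[root].copy()  # snapshot the root's contents once
    for i, dst in enumerate(tensor_list):
        if i != root:
            dst[...] = src  # in-place overwrite keeps each entry's identity
    return tensor_list


replicas = [np.zeros(2), np.array([7.0, 8.0]), np.ones(2)]
out = host_broadcast(replicas, root=1)
print(out is replicas, [r.tolist() for r in replicas])
# True [[7.0, 8.0], [7.0, 8.0], [7.0, 8.0]]
```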

scatter

scatter(tensor, chunks)

Returns a list of NumPy arrays (not Tensors!), one per chunk, obtained from np.array_split(tensor.to_host(), chunks, axis=0). Intended for the data-parallel shard path; pair it with shard_batch on the loader side.
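Because scatter is described as a thin wrapper over np.array_split, chunk sizes follow NumPy's rule: when the length of axis 0 is not divisible by chunks, the first len % chunks pieces get one extra row. A quick check with plain NumPy, where a host array stands in for tensor.to_host():

```python
import numpy as np

host = np.arange(20).reshape(10, 2)       # stands in for tensor.to_host()
pieces = np.array_split(host, 3, axis=0)  # what scatter(tensor, 3) would return
print([p.shape for p in pieces])          # [(4, 2), (3, 2), (3, 2)]
```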

gather

gather(tensors)

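The counterpart of scatter: it concatenates a list of Tensors along axis 0 and returns a single NumPy array. Note that gather(scatter(t, n)) is not a direct round-trip, because scatter yields NumPy arrays while gather expects Tensors.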
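On the host side, gather amounts to np.concatenate(..., axis=0) over the to_host() copies. The sketch below uses a hypothetical host_gather that takes NumPy arrays directly, so it can be checked without a device, and confirms that concatenating the output of np.array_split restores the original array.

```python
import numpy as np


def host_gather(arrays):
    """Hypothetical host-side gather: concatenate along axis 0."""
    return np.concatenate(arrays, axis=0)


original = np.arange(10).reshape(5, 2)
pieces = np.array_split(original, 2, axis=0)  # row counts 3 and 2
restored = host_gather(pieces)
print(np.array_equal(restored, original))  # True
```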

Data-Parallel Setup

The end-to-end recipe is:

from netcl.distributed import (
    DeviceManager, prepare_replicas, data_parallel_step,
)
from netcl.optim import Adam
# model, loader and forward_pass are supplied by your own training script.

dm = DeviceManager()                                    # one queue per device
queues = dm.get_queues()
replicas = prepare_replicas(model.parameters(), queues)  # copy params onto each queue
opt_replicas = [Adam(r, lr=1e-3) for r in replicas]     # one optimizer per replica

for x, y in loader:
    loss = data_parallel_step(
        forward_fn=forward_pass,    # (queue, xb, yb, params) -> (loss_node, tape)
        param_replicas=replicas,
        optimizers=opt_replicas,
        batch=(x, y),
    )

data_parallel_step

data_parallel_step(
    forward_fn, param_replicas, optimizers, batch, *,
    pre_shard_hook=None, post_sync_hook=None, post_step_hook=None,
)

This is the single function that drives one full data-parallel iteration. Internally it runs four steps:

  1. shard_batch(batch, n_devices): split (x, y) along axis 0 into N sub-batches.
  2. For each device i: call forward_fn(queue_i, shards_x[i], shards_y[i], param_replicas[i]) to get a (loss_node, tape), then tape.backward(loss_node).
  3. sync_grads(param_replicas): in-place mean of every param.grad across replicas.
  4. For each device: optimizer.step() and optimizer.zero_grad(). Any post_step_hook is called per-device here.

data_parallel_step does not call broadcast_params; see the manual pattern below if you need it.

The pre_shard_hook, post_sync_hook, and post_step_hook keywords let you plug in extra logic (logging, NaN guards, learning-rate schedule ticks, etc.) at well-defined points in the loop.
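A host-only NumPy sketch of the four steps, using a toy linear-regression model so it runs without a device. Each replica is just a weight vector, the optimizer is plain SGD, and the name toy_data_parallel_step and all hook signatures are hypothetical; netcl's data_parallel_step works on Tensors, tapes and Optimizer objects instead.

```python
import numpy as np


def toy_data_parallel_step(params, batch, lr=0.1, *, pre_shard_hook=None,
                           post_sync_hook=None, post_step_hook=None):
    """One data-parallel iteration for y ~ x @ w (hypothetical host-only sketch)."""
    x, y = batch
    if pre_shard_hook is not None:
        pre_shard_hook(batch)
    n = len(params)
    # 1. Shard x and y along axis 0, one sub-batch per replica.
    xs, ys = np.array_split(x, n, axis=0), np.array_split(y, n, axis=0)
    # 2. Per-replica forward + backward for loss = mean((x @ w - y) ** 2).
    grads, losses = [], []
    for w, xb, yb in zip(params, xs, ys):
        residual = xb @ w - yb
        losses.append(float(np.mean(residual ** 2)))
        grads.append(2.0 * xb.T @ residual / len(yb))
    # 3. Sync: mean-reduce the gradients across replicas.
    mean_grad = np.mean(grads, axis=0)
    if post_sync_hook is not None:
        post_sync_hook(mean_grad)
    # 4. Per-replica optimizer step (plain SGD), applied in place.
    for i, w in enumerate(params):
        w -= lr * mean_grad
        if post_step_hook is not None:
            post_step_hook(i, w)
    return float(np.mean(losses))


rng = np.random.default_rng(0)
x = rng.normal(size=(64, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = x @ true_w  # noise-free targets
replicas = [np.zeros(3) for _ in range(4)]
step_calls = []
for _ in range(200):
    loss = toy_data_parallel_step(replicas, (x, y), lr=0.1,
                                  post_step_hook=lambda i, w: step_calls.append(i))
print(np.allclose(replicas[0], true_w, atol=1e-3))  # True
```

Because every replica starts from the same weights and applies the same mean gradient, the replicas stay identical without any broadcast. Mean-reducing per-shard mean gradients equals the full-batch gradient only when the shards have equal size (64 rows over 4 replicas here); with uneven np.array_split shards, rows in the smaller shards are weighted slightly more.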

Manual broadcast_params Pattern

If your model uses running statistics (e.g. BatchNorm) or any buffer that is not a parameter, you typically want the head replica's view of the world to win after each step. Call broadcast_params explicitly:

for x, y in loader:
    loss = data_parallel_step(...)
    broadcast_params(src_params=None, dst_param_groups=replicas, root=0)

src_params is an optional override; when it is None, broadcast_params uses dst_param_groups[root] as the source. The broadcast is applied parameter by parameter, so it tolerates groups of different sizes (for example, when device 0 has more parameters than the others because of a conditional head).
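A host-side sketch of that parameter-by-parameter copy, using NumPy arrays as stand-in parameters. How netcl matches parameters across groups of different sizes is not documented here; this sketch assumes they are matched by position and that parameters beyond the shorter group's length are left untouched (zip stops at the shorter sequence). The function name host_broadcast_params is hypothetical.

```python
import numpy as np


def host_broadcast_params(src_params, dst_param_groups, root=0):
    """Hypothetical host illustration of broadcast_params' copy loop."""
    if src_params is None:
        src_params = dst_param_groups[root]  # default source: the root group
    for group in dst_param_groups:
        if group is src_params:
            continue  # nothing to copy onto the source itself
        # Match parameters by position; zip stops at the shorter of the two groups.
        for src, dst in zip(src_params, group):
            dst[...] = src
    return dst_param_groups


groups = [
    [np.array([1.0, 2.0]), np.array([3.0]), np.array([9.0])],  # root: extra head param
    [np.zeros(2), np.zeros(1)],
    [np.zeros(2), np.zeros(1)],
]
host_broadcast_params(None, groups, root=0)
print([p.tolist() for p in groups[2]])  # [[1.0, 2.0], [3.0]]
```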

DeviceManager

from netcl.distributed import DeviceManager

dm = DeviceManager()              # one context+queue per OpenCL device
dm.num_devices()                  # -> 4 on a four-GPU box
dm.get_queues()                   # -> [cl.CommandQueue, …]

The class is a thin convenience over netcl.core.device.manager.default. By default it manages the devices of the discovered default context, but you can pass an explicit device list to pin it to a subset (for example, "the two discrete GPUs" on a box that also has an integrated GPU).

Note — Same class as core.DeviceManager. distributed.DeviceManager is the same class as core.DeviceManager, re-imported for ergonomics. Use whichever import path feels more natural in the surrounding code.

Lifecycle: Replicas, Subprocesses, Shutdown

The data-parallel helpers above assume in-process replicas — multiple OpenCL contexts living in the same Python interpreter. For very large models or for processes that need a hard memory isolation barrier, the recommended pattern is subprocess-based replicas with IPC for gradient exchange. The pattern looks like this:

  1. Spawn. Use multiprocessing (or any other subprocess primitive) to launch N child processes, each owning one cl.Context and one parameter replica.
  2. Forward / backward. Each child pulls its shard of the batch, runs forward + backward, and writes the local gradients to a multiprocessing.shared_memory segment.
  3. Sync. The driver process (or one of the children acting as the "head") reads the per-replica gradients, mean-reduces them with all_reduce over a list of buffer views into the SHM segments, and writes the result back.
  4. Step. Each child reads the reduced gradients, calls its Optimizer's step(), and writes the updated parameters back into SHM.
  5. Graceful shutdown. A SIGINT handler in each child calls queue.finish() to drain pending kernels before exiting, and the SHM segments are unlinked in the driver's atexit hook.
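A minimal sketch of steps 2 to 4 around a multiprocessing.shared_memory segment. To keep it runnable in one process, the children are simulated in a loop; in a real deployment each row would be written by a separate process that attaches to the segment by name. The segment layout (one row per replica plus one row for the reduced result) and the use of NumPy's mean in place of netcl's all_reduce are assumptions of this sketch.

```python
from multiprocessing import shared_memory

import numpy as np

N_REPLICAS, N_PARAMS = 3, 4

# Driver: one float64 row per replica's gradients plus a final row for the result.
shm = shared_memory.SharedMemory(create=True, size=(N_REPLICAS + 1) * N_PARAMS * 8)
try:
    table = np.ndarray((N_REPLICAS + 1, N_PARAMS), dtype=np.float64, buffer=shm.buf)

    # Step 2 (simulated children): each replica writes its local gradients.
    for rank in range(N_REPLICAS):
        table[rank] = np.full(N_PARAMS, float(rank + 1))

    # Step 3 (driver): mean-reduce the per-replica rows into the result row.
    table[N_REPLICAS] = table[:N_REPLICAS].mean(axis=0)

    # Step 4 (simulated children): read the reduced gradients and take a step.
    reduced = table[N_REPLICAS].copy()
    params = [np.zeros(N_PARAMS) for _ in range(N_REPLICAS)]
    for p in params:
        p -= 0.5 * reduced  # plain SGD step with lr=0.5

    del table  # drop the NumPy view so the segment's buffer can be released
finally:
    shm.close()
    shm.unlink()  # the driver owns the segment; this could also run from atexit

print(reduced.tolist())  # [2.0, 2.0, 2.0, 2.0]
```

The explicit del before close() matters: SharedMemory.close() refuses to release the mapping while a NumPy array still exports its buffer.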

The runtime/scheduler.Stream abstraction is the right tool for overlapping H2D copies with compute inside a single process; the subprocess-based pattern is what you reach for when you need to escape the per-process OpenCL device quota of a vendor driver.

See also