netcl.distributed — Collectives & Data-Parallel
The distributed API is a host-based minimum viable product for multi-device
training on a single node. It targets one workstation with N OpenCL devices
(typically 1-4 GPUs on a developer box) and is not a cluster runtime: there is no
RDMA, no MPI, and no NCCL. The transport is NumPy round-trips through host memory
plus, when all participants share a cl.Context, an optional device-side pairwise sum.
Note: Top-level re-exports.
netcl/distributed/__init__.py re-exports all_reduce, broadcast, scatter, gather, DeviceManager, shard_batch, replicate_params, sync_grads, broadcast_params, prepare_replicas and data_parallel_step. Every name in that list is therefore reachable via from netcl import distributed as dist and dist.<name>(...).
Public API
| Symbol | Path | Purpose |
|---|---|---|
| all_reduce(tensors, op="sum", overlap=False) | distributed/collectives.py | In-place sum or mean (op="sum" or op="mean") across a list of Tensor replicas |
| broadcast(tensor_list, root=0) | distributed/collectives.py | Copy the root-th Tensor onto every other one |
| scatter(tensor, chunks) | distributed/collectives.py | Split a Tensor along axis 0 into chunks pieces, returned as NumPy arrays |
| gather(tensors) | distributed/collectives.py | Concatenate a list of Tensors along axis 0 into one NumPy array |
| all_reduce_p2p(tensors, op="sum") | distributed/collectives.py | Device-side pairwise sum; falls back to host all_reduce |
| shard_batch(x, num_shards) | distributed/data_parallel.py | np.array_split along axis 0 |
| replicate_params(params, queues) | distributed/data_parallel.py | Copy parameter Tensors onto each queue |
| sync_grads(param_replicas) | distributed/data_parallel.py | Mean-reduce every .grad across replicas |
| broadcast_params(src_params, dst_param_groups, root=0) | distributed/data_parallel.py | Refresh every replica's parameters from the source group |
| prepare_replicas(params, queues) | distributed/trainer.py | Convenience alias for replicate_params |
| data_parallel_step(forward_fn, param_replicas, optimizers, batch, *, pre_shard_hook=None, post_sync_hook=None, post_step_hook=None) | distributed/trainer.py | One full forward / backward / sync / step iteration across all devices |
| DeviceManager | distributed/device_manager.py | Multi-queue manager; re-imports core.DeviceManager |
Collectives
The four primitives (all_reduce, broadcast, scatter, gather) live in
distributed/collectives.py and are implemented entirely in host memory. When all
participants share a cl.Context, all_reduce_p2p can be used instead for a faster
device-side sum.
from netcl.distributed import all_reduce, broadcast

replicas = ...                   # list of Tensors, one per device, same shape
all_reduce(replicas, op="mean")  # in-place mean across the list
broadcast(replicas, root=0)      # copy replicas[0] onto every other entry (takes a list)
all_reduce
all_reduce(tensors, op="sum", overlap=False)
op is one of "sum" or "mean". The implementation pulls every
Tensor to host (t.to_host()), reduces them with np.sum / np.mean, and
copies the result back. If overlap=True the copies are issued on parallel
threading.Threads, which is a measurable win on multi-GPU boxes where the H2D copy is
the bottleneck. When participants live in different cl.Contexts (e.g. an NVIDIA card
and an Intel iGPU), a warning is printed and the host fallback is used.
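The host path can be pictured with a small NumPy model. This is a sketch under stated assumptions, not netcl's code: HostTensor is a hypothetical stand-in offering only to_host() and a hypothetical write_from_host() for the H2D copy, and overlap=True is modelled by issuing only the write-back copies on threads.

```python
import threading

import numpy as np


class HostTensor:
    """Hypothetical stand-in for a netcl Tensor (not netcl's class)."""

    def __init__(self, data):
        self._buf = np.array(data, dtype=np.float32)  # plays the role of the device buffer

    def to_host(self):
        return self._buf.copy()  # D2H copy

    def write_from_host(self, arr):
        self._buf[...] = arr  # H2D copy (hypothetical method name)


def host_all_reduce(tensors, op="sum", overlap=False):
    if op not in ("sum", "mean"):
        raise ValueError(f"unsupported op: {op!r}")
    stacked = np.stack([t.to_host() for t in tensors])  # pull every replica to host
    result = stacked.sum(axis=0) if op == "sum" else stacked.mean(axis=0)
    if overlap:
        # Issue the write-back copies on parallel threads; each thread writes its own buffer.
        threads = [threading.Thread(target=t.write_from_host, args=(result,)) for t in tensors]
        for th in threads:
            th.start()
        for th in threads:
            th.join()
    else:
        for t in tensors:
            t.write_from_host(result)
    return tensors
```

The threads only read the shared result and each writes a distinct buffer, so no lock is needed in this model.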
all_reduce_p2p
all_reduce_p2p(tensors, op="sum")
Falls back to all_reduce if pyopencl is missing or if the participants span more than
one cl.Context. Otherwise it JIT-compiles a tiny ar_sum kernel and does the reduction
in place on the device, pairwise. Useful when the workload is bandwidth-bound by host
round-trips.
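A pairwise reduction is commonly arranged as a binary tree: in each round, buffer i absorbs buffer i + stride, so N buffers are summed in ceil(log2(N)) rounds of elementwise adds. The sketch below assumes ar_sum is a plain a += b kernel (the kernel text is a guess, not netcl's source) and that the result is copied back out from buffer 0; each launch is emulated on the host with NumPy so the schedule can be checked without an OpenCL device.

```python
import numpy as np

# Guessed source of the ar_sum kernel: elementwise a += b on two device buffers.
AR_SUM_SRC = """
__kernel void ar_sum(__global float *a, __global const float *b) {
    const size_t i = get_global_id(0);
    a[i] += b[i];
}
"""


def ar_sum_emulated(a, b):
    """Host emulation of one ar_sum launch: a += b in place."""
    a += b


def pairwise_schedule(n):
    """Yield (dst, src) pairs of a binary-tree reduction over n buffers, round by round."""
    stride = 1
    while stride < n:
        for dst in range(0, n, 2 * stride):
            if dst + stride < n:
                yield dst, dst + stride
        stride *= 2


def p2p_all_reduce_sketch(buffers, op="sum"):
    for dst, src in pairwise_schedule(len(buffers)):
        ar_sum_emulated(buffers[dst], buffers[src])
    total = buffers[0]  # buffer 0 now holds the full sum
    if op == "mean":
        total /= len(buffers)
    for b in buffers[1:]:
        b[...] = total  # copy the result back out to every participant
    return buffers
```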
broadcast
broadcast(tensor_list, root=0)
Note the calling convention: a list of Tensors, not a single
Tensor. The function returns the same list with every entry (except
root) replaced by the contents of tensor_list[root].
scatter
scatter(tensor, chunks)
Returns a list of NumPy arrays (not Tensors!), one per chunk, obtained from
np.array_split(tensor.to_host(), chunks, axis=0). Intended for the data-parallel
shard path; pair it with shard_batch on the loader side.
gather
gather(tensors)
The inverse of scatter. Concatenates a list of
Tensors along axis 0 and returns a NumPy array.
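Because scatter is documented as np.array_split over to_host() and gather as a concatenation along axis 0, their round trip can be checked with plain NumPy. In this sketch NumPy arrays stand in for Tensors (so the to_host() calls are dropped), and the function names are hypothetical. Note that array_split tolerates a batch that does not divide evenly: the first shards get one extra row.

```python
import numpy as np


def scatter_sketch(batch, chunks):
    # Same split that scatter() documents: np.array_split along axis 0.
    return np.array_split(batch, chunks, axis=0)


def gather_sketch(shards):
    # Same join that gather() documents: concatenate along axis 0.
    return np.concatenate(shards, axis=0)


batch = np.arange(20).reshape(10, 2)  # 10 samples, 2 features
shards = scatter_sketch(batch, 3)     # shard sizes: [4, 3, 3]
restored = gather_sketch(shards)
```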
Data-Parallel Setup
The end-to-end recipe is:
from netcl.distributed import (
    DeviceManager, prepare_replicas, data_parallel_step,
)
from netcl.optim import Adam

# model, loader and forward_pass come from your own code.
dm = DeviceManager()                                     # one queue per device
queues = dm.get_queues()
replicas = prepare_replicas(model.parameters(), queues)  # copy params onto each queue
opt_replicas = [Adam(r, lr=1e-3) for r in replicas]      # one optimizer per replica

for x, y in loader:
    loss = data_parallel_step(
        forward_fn=forward_pass,  # (queue, xb, yb, params) -> (loss_node, tape)
        param_replicas=replicas,
        optimizers=opt_replicas,
        batch=(x, y),
    )
data_parallel_step
data_parallel_step(
    forward_fn, param_replicas, optimizers, batch, *,
    pre_shard_hook=None, post_sync_hook=None, post_step_hook=None,
)
This is the single function that drives one full data-parallel iteration. Internally it runs the following steps:
- shard_batch(batch, n_devices): split (x, y) along axis 0 into N sub-batches.
- For each device i: call forward_fn(queue_i, shards_x[i], shards_y[i], param_replicas[i]) to get a (loss_node, tape), then tape.backward(loss_node).
- sync_grads(param_replicas): in-place mean of every param.grad across replicas.
- For each device: optimizer.step() and optimizer.zero_grad(). Any post_step_hook is called per device here.

data_parallel_step does not call broadcast_params; see the manual pattern below if you need it.
The pre_shard_hook, post_sync_hook, and post_step_hook keywords let you plug in
extra logic (logging, NaN guards, learning-rate schedule ticks, etc.) at well-defined
points in the loop.
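To make the ordering of the steps and hooks concrete, here is a single-process NumPy model of one iteration. Everything in it is hypothetical: Param, SGD and grad_fn replace netcl's Tensors, optimizers and the forward_fn / tape.backward pair; the arguments passed to the three hooks are guesses, since this page does not document them; and returning the mean loss is a choice of the sketch.

```python
import numpy as np


class Param:
    """Hypothetical parameter: a value array plus a .grad slot."""

    def __init__(self, value):
        self.value = np.array(value, dtype=np.float64)
        self.grad = None


class SGD:
    """Hypothetical optimizer exposing step() and zero_grad()."""

    def __init__(self, params, lr):
        self.params, self.lr = params, lr

    def step(self):
        for p in self.params:
            p.value -= self.lr * p.grad

    def zero_grad(self):
        for p in self.params:
            p.grad = None


def sync_grads_sketch(param_replicas):
    # Mean-reduce the i-th parameter's .grad across all replicas, in place.
    for group in zip(*param_replicas):
        mean = np.mean([p.grad for p in group], axis=0)
        for p in group:
            p.grad = mean.copy()


def data_parallel_step_sketch(grad_fn, param_replicas, optimizers, batch, *,
                              pre_shard_hook=None, post_sync_hook=None,
                              post_step_hook=None):
    if pre_shard_hook:
        pre_shard_hook(batch)
    x, y = batch
    n = len(param_replicas)
    # 1. Shard (x, y) along axis 0.
    xs, ys = np.array_split(x, n, axis=0), np.array_split(y, n, axis=0)
    # 2. Per device: forward + backward (grad_fn stands in for forward_fn + tape.backward).
    losses = []
    for params, xb, yb in zip(param_replicas, xs, ys):
        loss, grads = grad_fn(xb, yb, params)
        for p, g in zip(params, grads):
            p.grad = g
        losses.append(loss)
    # 3. Average gradients across replicas.
    sync_grads_sketch(param_replicas)
    if post_sync_hook:
        post_sync_hook(param_replicas)
    # 4. Per device: step, zero_grad, then the per-device hook.
    for i, opt in enumerate(optimizers):
        opt.step()
        opt.zero_grad()
        if post_step_hook:
            post_step_hook(i)
    # No broadcast_params here, matching the documented behaviour.
    return float(np.mean(losses))
```

Because every replica starts from the same values and applies the same averaged gradient, the replicas stay identical after the step without any broadcast.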
Manual broadcast_params Pattern
If your model uses running statistics (e.g. BatchNorm) or any buffer that is
not a parameter, you typically want the root replica's values to overwrite the
other replicas after each step. Call broadcast_params explicitly:
from netcl.distributed import broadcast_params

for x, y in loader:
    loss = data_parallel_step(...)
    broadcast_params(src_params=None, dst_param_groups=replicas, root=0)
src_params is an optional override; when None, broadcast_params
uses dst_param_groups[root] as the source. The function applies the broadcast
parameter-by-parameter, so it works with mismatched group shapes (for example, when
device 0 has more parameters than the others because of a conditional head).
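The fallback-to-root and parameter-by-parameter behaviour can be sketched as follows. NumPy arrays stand in for parameter Tensors, the function name is hypothetical, and pairing parameters by position with zip (so the extra entries of a longer source group are simply left unused) is an assumption about how mismatched groups are handled.

```python
import numpy as np


def broadcast_params_sketch(src_params, dst_param_groups, root=0):
    # With no override, the root group is the source.
    src = dst_param_groups[root] if src_params is None else src_params
    for group in dst_param_groups:
        if group is src:
            continue  # nothing to copy onto the source itself
        # Parameter-by-parameter copy; zip stops at the shorter group (assumed).
        for s, d in zip(src, group):
            d[...] = s
    return dst_param_groups


# Device 0 carries an extra (conditional-head) parameter.
head = [np.ones(2), np.full(3, 2.0), np.zeros(1)]
other = [np.zeros(2), np.zeros(3)]
broadcast_params_sketch(None, [head, other], root=0)
```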
DeviceManager
from netcl.distributed import DeviceManager
dm = DeviceManager() # one context+queue per OpenCL device
dm.num_devices() # -> 4
dm.get_queues() # -> [cl.CommandQueue, …]
The class is a thin convenience over
netcl.core.device.manager.default. It defaults to the discovered default
context's devices, but you can pass an explicit list to pin it to a subset (for
example, "the two discrete GPUs" on a box that also has an integrated GPU).
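One way to build the "two discrete GPUs" subset is to enumerate devices with pyopencl and filter them. This is a heuristic sketch, not part of netcl: it assumes integrated GPUs report host_unified_memory as true and discrete cards as false, which holds on common drivers but is not guaranteed. How the resulting list is handed to DeviceManager is not specified on this page, so check its constructor rather than assuming a keyword.

```python
# CL_DEVICE_TYPE_GPU from the OpenCL specification (same value as pyopencl.device_type.GPU).
CL_DEVICE_TYPE_GPU = 1 << 2


def pick_discrete_gpus(devices):
    """Keep GPUs that do not report unified host memory (heuristic for discrete cards)."""
    return [
        d for d in devices
        if (d.type & CL_DEVICE_TYPE_GPU) and not d.host_unified_memory
    ]


def discover_discrete_gpus():
    """Enumerate every OpenCL device on every platform, then filter."""
    import pyopencl as cl  # imported lazily so the filter above is usable without OpenCL

    devices = [d for p in cl.get_platforms() for d in p.get_devices()]
    return pick_discrete_gpus(devices)
```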
Note: Same class as core.DeviceManager.
distributed.DeviceManager is the same class as core.DeviceManager, re-imported for ergonomics. Use whichever import path reads more naturally in the surrounding code.
Lifecycle: Replicas, Subprocesses, Shutdown
The data-parallel helpers above assume in-process replicas: multiple OpenCL contexts living in the same Python interpreter. For very large models, or for processes that need a hard memory-isolation barrier, the recommended pattern is subprocess-based replicas with IPC for gradient exchange:
- Spawn. Use multiprocessing (or any other subprocess primitive) to launch N child processes, each owning one cl.Context and one parameter replica.
- Forward / backward. Each child pulls its shard of the batch, runs forward + backward, and writes its local gradients to a multiprocessing.shared_memory segment.
- Sync. The driver process (or one of the children acting as the "head") reads the per-replica gradients, mean-reduces them with all_reduce over a list of buffer views into the SHM segments, and writes the result back.
- Step. Each child reads the reduced gradients, calls its Optimizer's step(), and writes the updated parameters back into SHM.
- Graceful shutdown. A SIGINT handler in each child calls queue.finish() to drain pending kernels before exiting, and the driver unlinks the SHM segments in its atexit hook.
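The Sync step can be prototyped with the standard library's multiprocessing.shared_memory. The sketch below runs in one process for clarity: the driver creates one segment per replica, the writes a child would perform are done directly through a NumPy view, and np.mean stands in for the all_reduce call described above. Flat float32 gradient vectors and the helper names are assumptions.

```python
from multiprocessing import shared_memory

import numpy as np


def create_grad_segments(n_replicas, n_elems, dtype=np.float32):
    """Driver side: one SHM segment per replica, sized for a flat gradient vector."""
    nbytes = n_elems * np.dtype(dtype).itemsize
    return [shared_memory.SharedMemory(create=True, size=nbytes) for _ in range(n_replicas)]


def grad_view(shm, n_elems, dtype=np.float32):
    """NumPy view onto a segment; a child would write its gradients through one of these."""
    return np.ndarray((n_elems,), dtype=dtype, buffer=shm.buf)


def mean_reduce_in_place(segments, n_elems, dtype=np.float32):
    """Sync step: average the per-replica gradients and write the mean back into every segment."""
    views = [grad_view(s, n_elems, dtype) for s in segments]
    mean = np.mean(views, axis=0)  # a fresh array, not a view into SHM
    for v in views:
        v[:] = mean
    del views  # release the buffer exports so the segments can be closed later
    return mean


def release(segments):
    """Driver shutdown (e.g. registered with atexit): close and unlink every segment."""
    for s in segments:
        s.close()
        s.unlink()
```

In a real child, shared_memory.SharedMemory(name=seg_name) attaches to a segment created by the driver; any NumPy views must be dropped before close(), otherwise close() raises BufferError.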
The runtime/scheduler.Stream abstraction is the right tool for
overlapping H2D copies with compute inside a single process; the subprocess-based
pattern is what you reach for when you need to escape the per-process OpenCL device
quota of a vendor driver.
See also
- Data-Parallel Training: the worked tutorial that uses data_parallel_step end-to-end.
- Distributed Architecture: the dataflow diagram, the IPC contract, and the failure modes (stragglers, crashed replicas).
- DeviceManager: the lower-level device manager that distributed.DeviceManager re-exports.
- Optimizer: the Adam / SGD / AdamW optimizers used inside data_parallel_step.
- DataLoader: the loader that feeds the data-parallel loop.
- AMP: wraps the forward pass in autocast for half precision.
- JIT Compiler: fuses the forward path inside each replica.
- Tensor: the value type that sync_grads and broadcast_params operate on.