Distributed Compute
Kubetorch provides built-in support for distributed computing, where work is executed in parallel across multiple pods. This guide covers configuring distributed services, the execution model, and handling dynamic cluster changes.
When to Use Distributed
Use .distribute() when you want:
- Parallel execution: All pods execute the same function simultaneously
- Coordinated communication: Pods communicate via collective operations (allreduce, broadcast)
This is ideal for distributed training (PyTorch DDP, JAX), data processing with coordination, and any SPMD (Single Program Multiple Data) workload.
Note: Distributed and autoscaling are mutually exclusive. Use .distribute() for parallel
execution or .autoscale() for load-balanced round-robin services.
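For a sense of the SPMD pattern, here is a minimal sketch in which every worker runs the same function and participates in a single allreduce. The image, resource sizes, and CPU-only gloo backend are illustrative choices, not requirements:

```python
import kubetorch as kt

compute = kt.Compute(
    cpus="2",
    image=kt.Image("pytorch/pytorch:2.0.0"),
).distribute("pytorch", workers=4)

def sum_ranks():
    import torch
    import torch.distributed as dist

    # Kubetorch pre-configures RANK/WORLD_SIZE/MASTER_ADDR, so init reads the env
    dist.init_process_group(backend="gloo")  # gloo for this CPU-only sketch
    rank = torch.tensor([dist.get_rank()], dtype=torch.float32)

    # Collective operation: every worker ends up with the sum of all ranks
    dist.all_reduce(rank, op=dist.ReduceOp.SUM)
    return rank.item()  # 0 + 1 + 2 + 3 = 6.0 on every worker

sum_fn = kt.fn(sum_ranks).to(compute)
print(sum_fn())  # [6.0, 6.0, 6.0, 6.0], one result per worker
```

Every worker computes the same sum, and the call returns one result per worker, as described in the execution model below.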
Distribution Types
Kubetorch supports several distribution frameworks:
| Type | Pattern | Use Case |
|---|---|---|
| spmd | SPMD (all workers execute) | Generic distributed, custom frameworks |
| pytorch | SPMD with PyTorch env | PyTorch DDP training |
| jax | SPMD with JAX env | JAX distributed training |
| tensorflow | SPMD with TF env | TensorFlow distributed training |
| ray | Single controller | Ray-based workloads |
| monarch | Single controller | Meta's Monarch framework |
SPMD Frameworks (pytorch, jax, tensorflow, spmd)
In SPMD mode, when you call a function:
- DNS Discovery: The coordinator queries the headless K8s service to discover all pod IPs
- Quorum Wait: The coordinator waits until quorum_workers pods are ready (DNS-based)
- Broadcast: The coordinator sends the call to all workers via HTTP (using a RemoteWorkerPool)
- Parallel Execution: Each worker executes the function with its assigned rank
- Aggregation: Results from all workers are collected and returned as a list
```
Client                            Cluster
  │
  ├── call(args) ────────────────▶ Coordinator (rank 0)
  │                                  │ (DNS lookup for pod IPs)
  │                                  │ (wait for quorum)
  │                                  ├──HTTP──▶ Worker 1 (rank 1)
  │                                  ├──HTTP──▶ Worker 2 (rank 2)
  │                                  ├──HTTP──▶ Worker 3 (rank 3)
  │                                  │
  ◀── [result0, result1,             (all workers execute in parallel)
       result2, result3]
```
Tree Topology for Large Clusters: For clusters with 100+ workers, Kubetorch automatically uses a tree topology instead of flat broadcast. Each node calls a subset of children (default fanout: 50), which then call their children. This prevents the coordinator from being overwhelmed.
Single Controller Frameworks (ray, monarch)
In single controller mode, only the head/controller node receives and executes calls. Workers are available for the controller to dispatch work to via the framework's APIs.
```
Client                            Cluster
  │
  ├── call(args) ────────────────▶ Head Node (controller)
  │                                  │ (dispatches via Ray/Monarch)
  │                                  ├──▶ Worker 1
  │                                  ├──▶ Worker 2
  │                                  │
  ◀── result ───────────────────── (head returns result)
```
Basic Configuration
PyTorch Distributed
```python
import kubetorch as kt

compute = kt.Compute(
    cpus="4",
    gpus="1",
    memory="16Gi",
    image=kt.Image("pytorch/pytorch:2.0.0"),
).distribute("pytorch", workers=4)

def train(epochs):
    import torch
    import torch.distributed as dist

    # Environment is pre-configured by Kubetorch
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Your distributed training code
    model = MyModel().cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)
    # ... training loop

    return {"rank": rank, "loss": final_loss}

train_fn = kt.fn(train).to(compute)
results = train_fn(epochs=10)  # Returns list of results from all 4 workers
```
JAX Distributed
```python
compute = kt.Compute(
    cpus="4",
    gpus="1",
    image=kt.Image("python:3.10").pip_install(["jax[cuda]"]),
).distribute("jax", workers=4)

def train(data):
    import jax

    # JAX automatically discovers the distributed setup
    devices = jax.devices()
    # ... JAX training code
    return result

train_fn = kt.fn(train).to(compute)
```
Ray Distributed
```python
compute = kt.Compute(
    cpus="4",
    memory="8Gi",
    image=kt.Image("rayproject/ray:2.9.0"),
).distribute("ray", workers=4)

def ray_workload(data):
    import ray

    @ray.remote
    def process(item):
        return item * 2

    # Ray cluster is pre-initialized
    futures = [process.remote(item) for item in data]
    return ray.get(futures)

workload_fn = kt.fn(ray_workload).to(compute)
```
Environment Variables
Kubetorch sets these environment variables on all workers:
| Variable | Description | Example |
|---|---|---|
| WORLD_SIZE | Total number of processes across all pods | 16 (4 pods × 4 procs) |
| RANK | Global rank of this process | 0 to WORLD_SIZE-1 |
| LOCAL_RANK | Rank within this pod | 0 to num_proc-1 |
| NODE_RANK | Rank of this pod | 0 to workers-1 |
| POD_IPS | Comma-separated list of all pod IPs | 10.0.1.1,10.0.1.2,10.0.1.3 |
| MASTER_ADDR | IP of the coordinator pod (PyTorch) | 10.0.1.1 |
| MASTER_PORT | Port for coordination (PyTorch) | 29500 |
Using Environment Variables
```python
import os

def distributed_fn():
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    pod_ips = os.environ["POD_IPS"].split(",")

    print(f"I am rank {rank}/{world_size}, local_rank {local_rank}")
    print(f"Cluster IPs: {pod_ips}")
```
Multi-Process Per Pod
For multi-GPU pods, run multiple processes per pod with num_proc:
```python
compute = kt.Compute(
    cpus="32",
    gpus="8",  # 8 GPUs per pod
    memory="128Gi",
).distribute("pytorch", workers=4, num_proc=8)  # 4 pods × 8 procs = 32 total
```
Each process gets its own LOCAL_RANK (0-7) and RANK (0-31).
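As a sanity check, the global rank can be reconstructed from the per-pod values, assuming ranks are laid out contiguously per pod (the layout implied by the ranges above):

```python
import os

def check_rank_layout(num_proc=8):
    # Assumes contiguous per-pod layout: RANK = NODE_RANK * num_proc + LOCAL_RANK
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    node_rank = int(os.environ["NODE_RANK"])

    assert rank == node_rank * num_proc + local_rank
    print(f"pod {node_rank}, process {local_rank} -> global rank {rank}")
```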
Quorum and Timeouts
Control how long to wait for workers to be ready:
```python
compute = kt.Compute(cpus="4", gpus="1").distribute(
    "pytorch",
    workers=8,
    quorum_timeout=600,  # Wait up to 10 min for all workers
    quorum_workers=8,    # Require all 8 workers before starting
)
```
| Parameter | Description | Default |
|---|---|---|
| quorum_timeout | Seconds to wait for workers | launch_timeout |
| quorum_workers | Minimum workers required to start | workers |
Tip: Increase quorum_timeout when using node autoscaling, as new nodes take time to provision.
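For example, a sketch that starts as soon as a partial quorum is ready while new nodes are still provisioning. The values are illustrative, and your training code must tolerate starting below the requested worker count:

```python
compute = kt.Compute(cpus="4", gpus="1").distribute(
    "pytorch",
    workers=8,
    quorum_timeout=1200,  # generous wait while the node autoscaler provisions capacity
    quorum_workers=6,     # begin once 6 of the 8 requested workers are up
)
```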
Dynamic World Size
Kubetorch supports resizing distributed clusters on the fly.
Resizing a Cluster
Change workers or num_proc and call .to() again:
```python
# Initial deployment with 4 workers
compute = kt.Compute(gpus="1").distribute("pytorch", workers=4)
train_fn = kt.fn(train).to(compute)

# Training iteration
results = train_fn(epochs=10)

# Scale up to 8 workers
compute = kt.Compute(gpus="1").distribute("pytorch", workers=8)
train_fn = kt.fn(train).to(compute)  # Cluster resizes

# Continue training with more workers
results = train_fn(epochs=10)
```
The new pods join, the coordinator updates POD_IPS, and the next call executes with the new world size.
Handling Membership Changes
When pods are added or removed (manually or due to failures), Kubetorch raises
kt.WorkerMembershipChanged. Your code can catch this to handle the change gracefully:
```python
import kubetorch as kt

def train_with_elastic(epochs):
    import torch.distributed as dist

    for epoch in range(epochs):
        try:
            # Training step
            train_epoch()
        except kt.WorkerMembershipChanged as e:
            print(f"Cluster changed: added={e.added_ips}, removed={e.removed_ips}")
            print(f"New cluster: {e.current_ips}")

            # Reinitialize process group with the new world
            dist.destroy_process_group()
            dist.init_process_group(backend="nccl")

            # Optionally reload checkpoint
            load_checkpoint()

train_fn = kt.fn(train_with_elastic).to(compute)
```
The exception includes:
- added_ips: Set of IPs that joined
- removed_ips: Set of IPs that left
- previous_ips: Set of IPs before the change
- current_ips: Set of IPs after the change
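A small sketch of using these attributes to distinguish scale-up from scale-down before rebuilding the process group (the handler itself is hypothetical):

```python
def handle_membership_change(e):
    # Hypothetical helper: inspect how the worker set changed
    if len(e.current_ips) > len(e.previous_ips):
        print(f"Scaled up by {len(e.added_ips)}: {sorted(e.added_ips)}")
    else:
        print(f"Scaled down by {len(e.removed_ips)}: {sorted(e.removed_ips)}")
    # In either case, rebuild the process group before the next collective call
```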
Membership Monitoring
By default, SPMD frameworks monitor for membership changes. You can disable this:
```python
compute = kt.Compute(gpus="1").distribute(
    "pytorch",
    workers=4,
    monitor_members=False,  # Don't raise exceptions on changes
)
```
Ray and Monarch manage their own membership, so monitor_members is disabled by default for them.
Examples
Dynamic World Size Training
See the Dynamic World Size PyTorch example for a complete walkthrough of elastic training with Kubetorch.
Basic Distributed Training
```python
import kubetorch as kt

def distributed_train(config):
    import os

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ["LOCAL_RANK"])

    # Set device
    torch.cuda.set_device(local_rank)

    # Create model and wrap with DDP
    model = MyModel().cuda()
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank]
    )

    # Create distributed sampler
    dataset = MyDataset()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank
    )
    loader = torch.utils.data.DataLoader(dataset, sampler=sampler)

    # Training loop
    for epoch in range(config["epochs"]):
        sampler.set_epoch(epoch)
        for batch in loader:
            # ... training step (computes loss)
            pass

    # Only rank 0 saves
    if rank == 0:
        torch.save(model.state_dict(), "model.pt")

    return {"rank": rank, "final_loss": loss.item()}

compute = kt.Compute(
    cpus="8",
    gpus="1",
    memory="32Gi",
    image=kt.Image("pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime"),
).distribute("pytorch", workers=4)

train_fn = kt.fn(distributed_train).to(compute)
results = train_fn({"epochs": 10, "lr": 0.001})

# results is a list of dicts from all 4 workers
print(f"Rank 0 loss: {results[0]['final_loss']}")
```
Distribution Options Reference
```python
compute = kt.Compute(cpus="4", gpus="1").distribute(
    distribution_type="pytorch",  # "spmd", "pytorch", "jax", "tensorflow", "ray", "monarch"
    workers=4,             # Number of pods
    num_proc=1,            # Processes per pod (for multi-GPU)
    quorum_timeout=300,    # Seconds to wait for workers
    quorum_workers=4,      # Minimum workers to start
    monitor_members=True,  # Raise on membership changes (SPMD only)
)
```