Distributed Compute

Kubetorch provides built-in support for distributed computing, where work is executed in parallel across multiple pods. This guide covers configuring distributed services, the execution model, and handling dynamic cluster changes.

When to Use Distributed

Use .distribute() when you want:

  • Parallel execution: All pods execute the same function simultaneously
  • Coordinated communication: Pods communicate via collective operations (allreduce, broadcast)

This is ideal for distributed training (PyTorch DDP, JAX), data processing with coordination, and any SPMD (Single Program Multiple Data) workload.

Note: Distributed and autoscaling are mutually exclusive. Use .distribute() for parallel execution or .autoscale() for load-balanced round-robin services.
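For orientation, both modes are configured on the same kt.Compute object. A minimal sketch; the bare .autoscale() call is shown without arguments here, and its options are covered in the autoscaling guide rather than assumed:

import kubetorch as kt

# Distributed: one coordinated service where every pod executes each call in parallel
ddp_compute = kt.Compute(gpus="1").distribute("pytorch", workers=4)

# Autoscaled: independent replicas behind a load balancer, each call served by one replica
# (options omitted; see the autoscaling guide)
svc_compute = kt.Compute(gpus="1").autoscale()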

Distribution Types

Kubetorch supports several distribution frameworks:

Type       | Pattern                    | Use Case
spmd       | SPMD (all workers execute) | Generic distributed, custom frameworks
pytorch    | SPMD with PyTorch env      | PyTorch DDP training
jax        | SPMD with JAX env          | JAX distributed training
tensorflow | SPMD with TF env           | TensorFlow distributed training
ray        | Single controller          | Ray-based workloads
monarch    | Single controller          | Meta's Monarch framework

SPMD Frameworks (pytorch, jax, tensorflow, spmd)

In SPMD mode, when you call a function:

  1. DNS Discovery: The coordinator queries the headless K8s service to discover all pod IPs
  2. Quorum Wait: Coordinator waits until quorum_workers pods are ready (DNS-based)
  3. Broadcast: Coordinator sends the call to all workers via HTTP (using a RemoteWorkerPool)
  4. Parallel Execution: Each worker executes the function with its assigned rank
  5. Aggregation: Results from all workers are collected and returned as a list
Client                        Cluster
  │                             │
  │──── call(args) ───────────▶│ Coordinator (rank 0)
  │                             │   (DNS lookup for pod IPs)
  │                             │   (wait for quorum)
  │                             │
  │                             ├──HTTP──▶ Worker 1 (rank 1)
  │                             ├──HTTP──▶ Worker 2 (rank 2)
  │                             └──HTTP──▶ Worker 3 (rank 3)
  │                             │   (all workers execute in parallel)
  │◀─── [result0, result1,     │
  │      result2, result3] ────│
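Concretely, a single client call fans out to every worker and comes back as a list with one entry per rank. A minimal sketch (echo_rank and its doubling logic are illustrative):

import os
import kubetorch as kt

def echo_rank(x):
    # Every worker runs this; RANK is set by Kubetorch (see Environment Variables below)
    return {"rank": int(os.environ["RANK"]), "value": x * 2}

compute = kt.Compute(cpus="1").distribute("spmd", workers=4)
echo_fn = kt.fn(echo_rank).to(compute)

results = echo_fn(21)   # one call from the client
print(len(results))     # 4 -- one entry per worker
print(results[0])       # result from rank 0, e.g. {"rank": 0, "value": 42}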

Tree Topology for Large Clusters: For clusters with 100+ workers, Kubetorch automatically uses a tree topology instead of flat broadcast. Each node calls a subset of children (default fanout: 50), which then call their children. This prevents the coordinator from being overwhelmed.
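The fanout is managed internally and is not one of the options shown in this guide; the short calculation below just illustrates why two levels of the default tree comfortably cover very large clusters:

# Rough reach of the default broadcast tree (fanout 50):
# the coordinator contacts at most `fanout` children directly, and each child
# contacts up to `fanout` workers of its own.
fanout = 50
direct = fanout                   # pods the coordinator calls itself
two_levels = fanout + fanout**2   # 2550 pods reachable within two hops
print(direct, two_levels)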

Single Controller Frameworks (ray, monarch)

In single controller mode, only the head/controller node receives and executes calls. Workers are available for the controller to dispatch work to via the framework's APIs.

Client                        Cluster
  │                             │
  │──── call(args) ───────────▶│ Head Node (controller)
  │                             │   (dispatches via Ray/Monarch)
  │                             ├──▶ Worker 1
  │                             └──▶ Worker 2
  │◀─── result ────────────────│ (head returns result)

Basic Configuration

PyTorch Distributed

import kubetorch as kt

compute = kt.Compute(
    cpus="4",
    gpus="1",
    memory="16Gi",
    image=kt.Image("pytorch/pytorch:2.0.0"),
).distribute("pytorch", workers=4)

def train(epochs):
    import torch
    import torch.distributed as dist

    # Environment is pre-configured by Kubetorch
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Your distributed training code
    model = MyModel().cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)
    # ... training loop

    return {"rank": rank, "loss": final_loss}

train_fn = kt.fn(train).to(compute)
results = train_fn(epochs=10)  # Returns list of results from all 4 workers

JAX Distributed

compute = kt.Compute(
    cpus="4",
    gpus="1",
    image=kt.Image("python:3.10").pip_install(["jax[cuda]"]),
).distribute("jax", workers=4)

def train(data):
    import jax

    # JAX automatically discovers distributed setup
    devices = jax.devices()
    # ... JAX training code
    return result

train_fn = kt.fn(train).to(compute)

Ray Distributed

compute = kt.Compute(
    cpus="4",
    memory="8Gi",
    image=kt.Image("rayproject/ray:2.9.0"),
).distribute("ray", workers=4)

def ray_workload(data):
    import ray

    @ray.remote
    def process(item):
        return item * 2

    # Ray cluster is pre-initialized
    futures = [process.remote(item) for item in data]
    return ray.get(futures)

workload_fn = kt.fn(ray_workload).to(compute)

Environment Variables

Kubetorch sets these environment variables on all workers:

Variable    | Description                                | Example
WORLD_SIZE  | Total number of processes across all pods  | 16 (4 pods × 4 procs)
RANK        | Global rank of this process                | 0 to WORLD_SIZE-1
LOCAL_RANK  | Rank within this pod                       | 0 to num_proc-1
NODE_RANK   | Rank of this pod                           | 0 to workers-1
POD_IPS     | Comma-separated list of all pod IPs        | 10.0.1.1,10.0.1.2,10.0.1.3
MASTER_ADDR | IP of the coordinator pod (PyTorch)        | 10.0.1.1
MASTER_PORT | Port for coordination (PyTorch)            | 29500

Using Environment Variables

import os

def distributed_fn():
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    pod_ips = os.environ["POD_IPS"].split(",")

    print(f"I am rank {rank}/{world_size}, local_rank {local_rank}")
    print(f"Cluster IPs: {pod_ips}")

Multi-Process Per Pod

For multi-GPU pods, run multiple processes per pod with num_proc:

compute = kt.Compute(
    cpus="32",
    gpus="8",        # 8 GPUs per pod
    memory="128Gi",
).distribute("pytorch", workers=4, num_proc=8)  # 4 pods × 8 procs = 32 total

Each process gets its own LOCAL_RANK (0-7) and RANK (0-31).
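Each process typically binds to a single GPU using its LOCAL_RANK. A minimal sketch, reusing the environment variables above (train_per_process is an illustrative name):

import os

def train_per_process():
    import torch
    import torch.distributed as dist

    local_rank = int(os.environ["LOCAL_RANK"])  # 0-7 within this pod
    torch.cuda.set_device(local_rank)           # bind this process to one GPU

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()                      # 0-31 across all pods
    # ... DDP setup as in the PyTorch example above
    return {"rank": rank, "local_rank": local_rank}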

Quorum and Timeouts

Control how long to wait for workers to be ready:

compute = kt.Compute(cpus="4", gpus="1").distribute(
    "pytorch",
    workers=8,
    quorum_timeout=600,  # Wait up to 10 min for all workers
    quorum_workers=8,    # Require all 8 workers before starting
)
Parameter      | Description                        | Default
quorum_timeout | Seconds to wait for workers        | launch_timeout
quorum_workers | Minimum workers required to start  | workers

Tip: Increase quorum_timeout when using node autoscaling, as new nodes take time to provision.
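If your training code can start with a partial world, quorum_workers can be set below workers so execution begins as soon as the minimum is ready. A sketch under that assumption (the specific values are illustrative):

compute = kt.Compute(cpus="4", gpus="1").distribute(
    "pytorch",
    workers=8,
    quorum_workers=6,    # start once 6 of the 8 requested workers are ready
    quorum_timeout=900,  # allow extra time for autoscaled nodes to provision
)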

Dynamic World Size

Kubetorch supports resizing distributed clusters on the fly.

Resizing a Cluster

Change workers or num_proc and call .to() again:

# Initial deployment with 4 workers
compute = kt.Compute(gpus="1").distribute("pytorch", workers=4)
train_fn = kt.fn(train).to(compute)

# Training iteration
results = train_fn(epochs=10)

# Scale up to 8 workers
compute = kt.Compute(gpus="1").distribute("pytorch", workers=8)
train_fn = kt.fn(train).to(compute)  # Cluster resizes

# Continue training with more workers
results = train_fn(epochs=10)

The new pods join, the coordinator updates POD_IPS, and the next call executes with the new world size.
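Because the next call runs with the updated world size, the remote function can read WORLD_SIZE at the start of each invocation and adapt its settings. A minimal sketch, assuming you want to hold a global batch size of 1024 fixed (both the batch size and the adjustment rule are illustrative):

import os

def train(epochs):
    import torch.distributed as dist

    world_size = int(os.environ["WORLD_SIZE"])  # reflects the current number of workers
    per_worker_batch = 1024 // world_size       # keep the global batch size constant

    dist.init_process_group(backend="nccl")
    # ... training loop using per_worker_batch
    dist.destroy_process_group()
    return {"world_size": world_size, "per_worker_batch": per_worker_batch}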

Handling Membership Changes

When pods are added or removed (manually or due to failures), Kubetorch raises kt.WorkerMembershipChanged. Your code can catch this to handle the change gracefully:

import kubetorch as kt

def train_with_elastic(epochs):
    import torch.distributed as dist

    for epoch in range(epochs):
        try:
            # Training step
            train_epoch()
        except kt.WorkerMembershipChanged as e:
            print(f"Cluster changed: added={e.added_ips}, removed={e.removed_ips}")
            print(f"New cluster: {e.current_ips}")

            # Reinitialize process group with new world
            dist.destroy_process_group()
            dist.init_process_group(backend="nccl")

            # Optionally reload checkpoint
            load_checkpoint()

train_fn = kt.fn(train_with_elastic).to(compute)

The exception includes:

  • added_ips: Set of IPs that joined
  • removed_ips: Set of IPs that left
  • previous_ips: Set of IPs before the change
  • current_ips: Set of IPs after the change
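For example, a handler can use these sets to log the delta and decide whether enough workers remain to keep training (handle_membership_change and the min_workers threshold are illustrative):

def handle_membership_change(e, min_workers=4):
    # e is the kt.WorkerMembershipChanged exception caught above
    print(f"Joined: {sorted(e.added_ips)}")
    print(f"Left:   {sorted(e.removed_ips)}")
    print(f"World went from {len(e.previous_ips)} to {len(e.current_ips)} pods")

    if len(e.current_ips) < min_workers:
        raise RuntimeError("Too few workers remaining to continue training")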

Membership Monitoring

By default, SPMD frameworks monitor for membership changes. You can disable this:

compute = kt.Compute(gpus="1").distribute(
    "pytorch",
    workers=4,
    monitor_members=False,  # Don't raise exceptions on changes
)

Ray and Monarch manage their own membership, so monitor_members is disabled by default for them.

Examples

Dynamic World Size Training

See the Dynamic World Size PyTorch example for a complete walkthrough of elastic training with Kubetorch.

Basic Distributed Training

import kubetorch as kt

def distributed_train(config):
    import os
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ["LOCAL_RANK"])

    # Set device
    torch.cuda.set_device(local_rank)

    # Create model and wrap with DDP
    model = MyModel().cuda()
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank]
    )

    # Create distributed sampler
    dataset = MyDataset()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank
    )
    loader = torch.utils.data.DataLoader(dataset, sampler=sampler)

    # Training loop
    for epoch in range(config["epochs"]):
        sampler.set_epoch(epoch)
        for batch in loader:
            # ... training step (placeholder)
            loss = training_step(model, batch)

    # Only rank 0 saves
    if rank == 0:
        torch.save(model.state_dict(), "model.pt")

    return {"rank": rank, "final_loss": loss.item()}

compute = kt.Compute(
    cpus="8",
    gpus="1",
    memory="32Gi",
    image=kt.Image("pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime"),
).distribute("pytorch", workers=4)

train_fn = kt.fn(distributed_train).to(compute)
results = train_fn({"epochs": 10, "lr": 0.001})

# results is a list of dicts from all 4 workers
print(f"Rank 0 loss: {results[0]['final_loss']}")

Distribution Options Reference

compute = kt.Compute(cpus="4", gpus="1").distribute(
    distribution_type="pytorch",  # "spmd", "pytorch", "jax", "tensorflow", "ray", "monarch"
    workers=4,                    # Number of pods
    num_proc=1,                   # Processes per pod (for multi-GPU)
    quorum_timeout=300,           # Seconds to wait for workers
    quorum_workers=4,             # Minimum workers to start
    monitor_members=True,         # Raise on membership changes (SPMD only)
)