Distributed Communication Patterns
Interactive guide to PyTorch distributed operations with animated visualizations
Overview
Distributed computing patterns enable multiple processes or devices to communicate and coordinate work. In PyTorch distributed training, these patterns are essential for scaling neural networks across multiple GPUs or nodes.
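All of the snippets on this page assume that torch, torch.distributed (imported as dist), and a process group have already been set up. A minimal setup sketch, assuming the script is launched with torchrun (which sets RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT in the environment), might look like this:
import torch
import torch.distributed as dist

# Assumes launch via torchrun, which populates RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT
dist.init_process_group(backend="gloo")  # use "nccl" for GPU tensors

rank = dist.get_rank()
world_size = dist.get_world_size()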
Send / Recv
Direct communication between two specific processes. One sends data, another receives it.
Broadcast
One process sends the same data to all other processes in the group.
Reduce
All processes contribute data, one process receives the combined result.
All-Reduce
All processes contribute and all receive the combined result. Most common in training.
Gather
One process collects data from all processes in the group.
All-Gather
Every process collects data from all other processes.
Scatter
One process distributes different pieces of data to different processes.
Reduce-Scatter
Combines reduce and scatter - reduces data then distributes different parts to processes.
Send / Recv
Direct communication between two specific processes. The code each process runs depends on its rank.
# Rank 0 (sender)
data = torch.tensor([42])
dist.send(data, dst=1)
# Rank 1 (receiver)
data = torch.empty(1, dtype=torch.int64)  # must match the sender's dtype and shape
dist.recv(data, src=0)
Broadcast
One sender distributes the same data to everyone in the process group, for example to all 8 GPUs in a node.
# All processes run this
data = torch.tensor([42]) if rank == 0 else torch.empty(1, dtype=torch.int64)  # dtype and shape must match on every rank
dist.broadcast(data, src=0)
All-Reduce
Everyone contributes data, everyone gets the combined result. Example: three processes hold 1, 2, and 3; after a SUM all-reduce, every process sees 6.
# All processes run this
data = torch.tensor([local_value]) # 1, 2, or 3
dist.all_reduce(data, op=dist.ReduceOp.SUM)
# Now all processes have data = 6
Reduce
Everyone contributes, only one process (usually rank 0) gets the result.
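A minimal sketch of the same example with dist.reduce, assuming three processes and local_value set to 1, 2, or 3 by rank, as in the all-reduce snippet above:
# All processes run this
data = torch.tensor([local_value])  # 1, 2, or 3 depending on rank
dist.reduce(data, dst=0, op=dist.ReduceOp.SUM)
# Only rank 0 is guaranteed to hold the result (6); the contents on other ranks are unspecified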
Gather
One process collects data from all processes. Destination sees [1, 2, 3].
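A minimal sketch, assuming the process group from the setup above and three processes holding 1, 2, and 3 (gather_list is only allocated on the destination rank):
# All processes run this
data = torch.tensor([rank + 1])  # rank 0 -> 1, rank 1 -> 2, rank 2 -> 3
gather_list = [torch.empty_like(data) for _ in range(world_size)] if rank == 0 else None
dist.gather(data, gather_list=gather_list, dst=0)
# On rank 0, gather_list now holds [tensor([1]), tensor([2]), tensor([3])]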
All-Gather
Everyone collects data from everyone. All processes see [1, 2, 3].
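A minimal sketch under the same assumptions; every rank allocates the output list:
# All processes run this
data = torch.tensor([rank + 1])
output = [torch.empty_like(data) for _ in range(world_size)]
dist.all_gather(output, data)
# Every process now holds output = [tensor([1]), tensor([2]), tensor([3])]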
Scatter
One process distributes different pieces of data to different processes. Each process receives the chunk that corresponds to its rank.
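A minimal sketch under the same assumptions; scatter_list is only built on the source rank:
# All processes run this
data = torch.empty(1, dtype=torch.int64)
scatter_list = [torch.tensor([i + 1]) for i in range(world_size)] if rank == 0 else None
dist.scatter(data, scatter_list=scatter_list, src=0)
# Rank 0 receives 1, rank 1 receives 2, rank 2 receives 3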
Reduce-Scatter
Combines reduce and scatter: the inputs are reduced into one tensor, and each process receives one chunk of the result.
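A minimal sketch, assuming three processes, a backend that supports reduce_scatter (for example NCCL, in which case the tensors live on each rank's GPU), and every rank holding the full input [1, 2, 3]:
# All processes run this
device = torch.device(f"cuda:{rank}")  # assumes one GPU per process
input_list = [torch.tensor([i + 1], device=device) for i in range(world_size)]  # [1, 2, 3] on every rank
output = torch.empty(1, dtype=torch.int64, device=device)
dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)
# Reduced tensor is [3, 6, 9]; rank 0 holds 3, rank 1 holds 6, rank 2 holds 9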
Communication Topologies
Understanding these data-flow patterns helps you choose the right operation for your use case.
1-to-1 (Point-to-Point)
Operations: send, recv
Direct communication between two specific processes.
1-to-Many (Broadcast)
Operations: broadcast, scatter
One process distributes data to all others.
Many-to-1 (Reduction)
Operations: reduce, gather
Multiple processes send data to one collector.
All-to-All (Full Exchange)
Operations: all_reduce, all_gather, all_to_all
Every process communicates with every other process.
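Of these operations, all_to_all has no example above. A minimal sketch, assuming a backend that supports it (for example NCCL) and one input chunk per destination rank:
# All processes run this
device = torch.device(f"cuda:{rank}")  # assumes one GPU per process with the NCCL backend
inputs = [torch.tensor([rank * world_size + i], device=device) for i in range(world_size)]
outputs = [torch.empty_like(t) for t in inputs]
dist.all_to_all(outputs, inputs)
# Rank r receives the r-th chunk from every rank
# With 3 ranks: rank 0 sees [0, 3, 6], rank 1 sees [1, 4, 7], rank 2 sees [2, 5, 8]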
Note on Multi-dimensional Tensors: In practice, tensors are multi-dimensional (e.g., gradients with shape [1024, 512]). We use simple scalar values in these examples for clarity, but the same patterns apply to tensors of any shape. PyTorch handles the complexity of transferring multi-dimensional data automatically.
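For example, the same all_reduce call from above works unchanged on a gradient-shaped tensor:
# A [1024, 512] gradient block instead of a scalar
grad = torch.randn(1024, 512)
dist.all_reduce(grad, op=dist.ReduceOp.SUM)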
Practical Use Cases in Distributed Training
Data Parallel Training
# Each GPU processes a different batch
# After backward pass, average gradients across all GPUs
for param in model.parameters():
    dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
    param.grad.data /= world_size
Parameter Initialization
# Ensure all processes start with same model weights
for param in model.parameters():
    dist.broadcast(param.data, src=0)
Distributed Validation
# Collect validation results from all processes
local_loss = torch.tensor([validation_loss])
dist.all_reduce(local_loss, op=dist.ReduceOp.SUM)
avg_loss = local_loss.item() / world_size
Model Parallel Training
# Send activations between model parallel ranks
if rank == 0:
    # First part of model
    hidden = model_part1(input)
    dist.send(hidden, dst=1)
elif rank == 1:
    # Second part of model
    hidden = torch.empty(hidden_size)
    dist.recv(hidden, src=0)
    output = model_part2(hidden)