
Distributed Communication Patterns

Interactive guide to PyTorch distributed operations with animated visualizations

Overview

Distributed computing patterns enable multiple processes or devices to communicate and coordinate work. In PyTorch distributed training, these patterns are essential for scaling neural networks across multiple GPUs or nodes.
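All of the snippets in this guide assume that a process group has already been initialized and that rank and world_size are available. A minimal setup sketch, assuming the script is launched with torchrun (which sets the required RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT environment variables):

import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # use "nccl" for GPU tensors
rank = dist.get_rank()                   # this process's id: 0 .. world_size - 1
world_size = dist.get_world_size()       # total number of processes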

Send / Recv (point-to-point): Direct communication between two specific processes. One sends data, another receives it.

Broadcast (collective): One process sends the same data to all other processes in the group.

Reduce (collective): All processes contribute data; one process receives the combined result.

All-Reduce (collective): All processes contribute and all receive the combined result. The most common pattern in training.

Gather (collective): One process collects data from all processes in the group.

All-Gather (collective): Every process collects data from all other processes.

Scatter (collective): One process distributes different pieces of data to different processes.

Reduce-Scatter (collective): Combines reduce and scatter; the data is reduced, then different parts of the result are distributed to the processes.

Send / Recv

Direct communication between two specific processes. The code each process runs depends on its rank.

Rank 0 has the data (42); Rank 1 is waiting to receive it.
# Rank 0 (sender)
data = torch.tensor([42])
dist.send(data, dst=1)

# Rank 1 (receiver)
data = torch.empty(1, dtype=torch.int64)  # buffer must match the sent tensor's dtype
dist.recv(data, src=0)

Broadcast

One sender distributes the same data to every process in the group, for example to all 8 GPUs in a single node.

Rank 0 will broadcast the value 42 to all other processes.
# All processes run this
data = torch.tensor([42]) if rank == 0 else torch.empty(1, dtype=torch.int64)
dist.broadcast(data, src=0)

All-Reduce

Everyone contributes data, everyone gets the combined result. Example: [1,2,3] -> sum = 6, all processes see 6.

BEFORE: Rank 0: [1]   Rank 1: [2]   Rank 2: [3]
AFTER:  Rank 0: [6]   Rank 1: [6]   Rank 2: [6]
# All processes run this
data = torch.tensor([local_value])  # 1, 2, or 3
dist.all_reduce(data, op=dist.ReduceOp.SUM)
# Now all processes have data = 6

Reduce

Everyone contributes, only one process (usually rank 0) gets the result.

Rank 0: [1]   Rank 1: [2]   Rank 2: [3]
Rank 0 will receive the sum of all values (6).
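A sketch of the call, mirroring the all-reduce example above (local_value again stands in for this rank's scalar):

# All processes run this; only the destination rank keeps the result
data = torch.tensor([local_value])  # 1, 2, or 3
dist.reduce(data, dst=0, op=dist.ReduceOp.SUM)
# Rank 0 now has data = 6; the buffer on other ranks should not be relied on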

Gather

One process collects data from all processes. Destination sees [1, 2, 3].

BEFORE: Rank 0: [1]   Rank 1: [2]   Rank 2: [3]
AFTER:  Rank 0: [1, 2, 3]   Rank 1: [2]   Rank 2: [3]
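A sketch of the call; note that only the destination rank passes a gather_list of receive buffers:

# All processes run this
data = torch.tensor([local_value])  # 1, 2, or 3
if rank == 0:
    buffers = [torch.empty(1, dtype=torch.int64) for _ in range(world_size)]
    dist.gather(data, gather_list=buffers, dst=0)
    # buffers is now [tensor([1]), tensor([2]), tensor([3])]
else:
    dist.gather(data, dst=0)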

All-Gather

Everyone collects data from everyone. All processes see [1, 2, 3].

BEFORE: Rank 0: [1]   Rank 1: [2]   Rank 2: [3]
AFTER:  Rank 0: [1, 2, 3]   Rank 1: [1, 2, 3]   Rank 2: [1, 2, 3]
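A sketch of the call; every rank allocates the full list of receive buffers:

# All processes run this
data = torch.tensor([local_value])  # 1, 2, or 3
gathered = [torch.empty(1, dtype=torch.int64) for _ in range(world_size)]
dist.all_gather(gathered, data)
# Every rank now has gathered = [tensor([1]), tensor([2]), tensor([3])]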

Scatter

One process distributes different pieces to different processes. Each process receives the chunk corresponding to its own rank.

BEFORE: Rank 0: [1, 2, 3]   Rank 1: []   Rank 2: []
AFTER:  Rank 0: [1]   Rank 1: [2]   Rank 2: [3]
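A sketch of the call; only the source rank supplies the scatter_list:

# All processes run this
recv = torch.empty(1, dtype=torch.int64)
if rank == 0:
    chunks = [torch.tensor([v]) for v in (1, 2, 3)]  # one chunk per rank
    dist.scatter(recv, scatter_list=chunks, src=0)
else:
    dist.scatter(recv, src=0)
# recv holds 1 on rank 0, 2 on rank 1, 3 on rank 2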

Reduce-Scatter

Combines reduce and scatter: the inputs are reduced element-wise, then each process receives one chunk of the reduced result.

BEFORE: Rank 0: [1, 2]   Rank 1: [3, 4]
AFTER (sum): Rank 0: [4]   Rank 1: [6]
Each process contributes a tensor and gets back one element of the reduced result.
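A sketch of the two-rank example above, assuming SUM reduction:

# All processes run this (world_size = 2, two elements per rank)
local = torch.tensor([1.0, 2.0]) if rank == 0 else torch.tensor([3.0, 4.0])
out = torch.empty(1)
dist.reduce_scatter(out, list(local.chunk(world_size)), op=dist.ReduceOp.SUM)
# Rank 0 ends with out = [4.] (1+3), rank 1 with out = [6.] (2+4)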

Communication Topologies

Understanding the data flow patterns helps choose the right operation for your use case.

1-to-1 (Point-to-Point)

A -> B

Operations: send, recv

Direct communication between two specific processes.

1-to-Many (Broadcast)

A -> B, C, D

Operations: broadcast, scatter

One process distributes data to all others.

Many-to-1 (Reduction)

A, B, C -> D

Operations: reduce, gather

Multiple processes send data to one collector.

All-to-All (Full Exchange)

A <-> B <-> C (every pair exchanges data)

Operations: all_reduce, all_gather, all_to_all

Every process communicates with every other process.
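all_to_all is not visualized above; a minimal sketch of how it exchanges chunks, using the same process group as the earlier examples (the values are made up for illustration):

# Each rank prepares one chunk per destination rank
send = [torch.tensor([rank * world_size + j]) for j in range(world_size)]
recv = [torch.empty(1, dtype=torch.int64) for _ in range(world_size)]
dist.all_to_all(recv, send)
# recv[j] now holds the chunk that rank j addressed to this rank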

Note on Multi-dimensional Tensors: In practice, tensors are multi-dimensional (e.g., gradients with shape [1024, 512]). We use simple scalar values in these examples for clarity, but the same patterns apply to tensors of any shape. PyTorch handles the complexity of transferring multi-dimensional data automatically.

Practical Use Cases in Distributed Training

Data Parallel Training

# Each GPU processes a different batch
# After backward pass, average gradients across all GPUs
for param in model.parameters():
    dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
    param.grad.data /= world_size

Parameter Initialization

# Ensure all processes start with same model weights
for param in model.parameters():
    dist.broadcast(param.data, src=0)

Distributed Validation

# Collect validation results from all processes
local_loss = torch.tensor([validation_loss])
dist.all_reduce(local_loss, op=dist.ReduceOp.SUM)
avg_loss = local_loss.item() / world_size

Model Parallel Training

# Send activations between model parallel ranks
if rank == 0:
    # First part of model
    hidden = model_part1(input)
    dist.send(hidden, dst=1)
elif rank == 1:
    # Second part of model  
    hidden = torch.empty(hidden_size)  # buffer must match the shape and dtype of the tensor sent by rank 0
    dist.recv(hidden, src=0)
    output = model_part2(hidden)