- Defines the distributed primitives used in Tensor Parallelism and Sequence Parallelism such that the forward and backward passes are appropriately defined.
- `reduce_scatter` and `all_reduce` are equivalent in terms of communication, because `all_reduce` is implemented as a ring all-reduce, i.e. a reduce-scatter followed by an all-gather; we just need to bucketize by the sequence dimension for things to work out correctly.
PyTorch specifics
torch.autograd.Function
- To define a differentiable operator, you need to define a class inheriting from `torch.autograd.Function`. How it works is explained in the `torch.autograd` documentation.
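A minimal toy example of the pattern (not from nanotron): forward and backward are static methods, and anything needed in the backward is stashed with `ctx.save_for_backward`.

```python
import torch

class Square(torch.autograd.Function):
    """Toy example: y = x * x, so dL/dx = 2 * x * dL/dy."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # stash what the backward pass needs
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # return one gradient per input of forward
        return 2 * x * grad_output

# usage: call through .apply, never by instantiating the class
x = torch.randn(4, requires_grad=True)
Square.apply(x).sum().backward()
```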
@assert_cuda_max_connections_set_to_1
- Function decorator that sets `CUDA_DEVICE_MAX_CONNECTIONS=1`, which ensures that only one connection can be made to a single GPU from any host thread (a minimal sketch of such a decorator follows this list).
- In `ColumnLinearAsync`, in the backward pass with `tp_mode=REDUCE_SCATTER`, they asynchronously re-gather the sequence-sharded input while other gradients are being computed (see the Code section below).
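A rough sketch consistent with the decorator's name and the note above; the actual nanotron implementation may behave differently (e.g. only asserting instead of setting the variable).

```python
import functools
import os

def assert_cuda_max_connections_set_to_1(func):
    """Sketch: make sure CUDA_DEVICE_MAX_CONNECTIONS=1 around the wrapped call so
    host threads feed the GPU through a single hardware work queue (needed for the
    communication/computation ordering trick described later)."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Note: the variable is read when the CUDA context is created, so it must
        # effectively be set before any CUDA work happens.
        os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")
        assert os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] == "1", (
            "CUDA_DEVICE_MAX_CONNECTIONS must be 1 for async TP overlap"
        )
        return func(*args, **kwargs)
    return wrapper
```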
Tensor Parallelism only
Differentiable Identity
- Before a `ColumnLinear` operation (beginning of MLP)
- Forward: `f(x) = x`, i.e. all `ColumnLinear` shards receive the same input
- Backward: `b(x) = all_reduce_sum(x)`, because the same input was replicated over all processes
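A minimal sketch of such an operator, assuming a `group` process-group handle is passed in (names are illustrative, not necessarily nanotron's):

```python
import torch
import torch.distributed as dist

class DifferentiableIdentity(torch.autograd.Function):
    """Identity in forward, all_reduce(sum) of the gradient in backward."""

    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        # every TP rank consumed the same replicated input, so their gradients sum up
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad_output, None  # no gradient for `group`
```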
Differentiable_All_Reduce_Sum
- After a `RowLinear` operation (end of MLP)
- Forward: `f(x) = all_reduce_sum(x)`, i.e. we need to accumulate the partial results from all ranks to get the correct output
- Backward: `b(x) = x`, i.e. the gradient must be replicated over all processes, as they all participated
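The mirror operator, again as an illustrative sketch:

```python
import torch
import torch.distributed as dist

class DifferentiableAllReduceSum(torch.autograd.Function):
    """all_reduce(sum) in forward, identity in backward."""

    @staticmethod
    def forward(ctx, tensor, group):
        tensor = tensor.contiguous()
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # every rank contributed to the sum, so each one gets the full gradient
        return grad_output, None
```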
Tensor + Sequence Parallelism
Differentiable_All_Gather
- Before a `ColumnLinear` operation (beginning of MLP)
- Forward: `f(x) = all_gather(x)`, i.e. the input was previously split along the sequence dimension, so we need to gather it back before feeding it to `ColumnLinear`
- Backward: `b(x) = reduce_scatter_sum(x)`, because the same input was replicated over all processes; we need to reduce, and then scatter by splitting along the sequence dimension
How to implement it
- nanotron code assumes the tensor is sharded along its first dimension
- get the current shape: `sharded_batch_size, *rest_size = tensor.shape`
- create an empty tensor of the unsharded size, i.e. `unsharded_tensor = torch.empty((sharded_batch_size * group.size(), *rest_size))`
- call `dist.all_gather_into_tensor(unsharded_tensor, tensor, group)`
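Putting the steps above into a sketched `autograd.Function` (assuming the sharding is along the first dimension, as nanotron does; names are illustrative):

```python
import torch
import torch.distributed as dist

class DifferentiableAllGather(torch.autograd.Function):
    """all_gather along the first (sequence) dimension in forward,
    reduce_scatter(sum) in backward."""

    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        sharded_batch_size, *rest_size = tensor.shape
        unsharded_tensor = torch.empty(
            (sharded_batch_size * group.size(), *rest_size),
            dtype=tensor.dtype, device=tensor.device,
        )
        dist.all_gather_into_tensor(unsharded_tensor, tensor.contiguous(), group=group)
        return unsharded_tensor

    @staticmethod
    def backward(ctx, grad_output):
        group = ctx.group
        unsharded_batch_size, *rest_size = grad_output.shape
        sharded_grad = torch.empty(
            (unsharded_batch_size // group.size(), *rest_size),
            dtype=grad_output.dtype, device=grad_output.device,
        )
        # sum the replicated gradients across ranks, keep only this rank's sequence chunk
        dist.reduce_scatter_tensor(sharded_grad, grad_output.contiguous(),
                                   op=dist.ReduceOp.SUM, group=group)
        return sharded_grad, None
```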
Differentiable_Reduce_Scatter
- After a `RowLinear` operation (end of MLP)
- Forward: `f(x) = reduce_scatter_sum(x)`, i.e. we need to accumulate the partial results from all ranks to get the correct output, and then scatter by splitting along the sequence dimension
- Backward: `b(x) = all_gather(x)`, i.e. the gradient must be gathered along the sequence dimension and then replicated over all processes, as they all participated
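This is exactly the mirror of `Differentiable_All_Gather`; a sketch under the same assumptions:

```python
import torch
import torch.distributed as dist

class DifferentiableReduceScatterSum(torch.autograd.Function):
    """reduce_scatter(sum) along the first (sequence) dimension in forward,
    all_gather in backward."""

    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        unsharded_batch_size, *rest_size = tensor.shape
        sharded_tensor = torch.empty(
            (unsharded_batch_size // group.size(), *rest_size),
            dtype=tensor.dtype, device=tensor.device,
        )
        dist.reduce_scatter_tensor(sharded_tensor, tensor.contiguous(),
                                   op=dist.ReduceOp.SUM, group=group)
        return sharded_tensor

    @staticmethod
    def backward(ctx, grad_output):
        group = ctx.group
        sharded_batch_size, *rest_size = grad_output.shape
        unsharded_grad = torch.empty(
            (sharded_batch_size * group.size(), *rest_size),
            dtype=grad_output.dtype, device=grad_output.device,
        )
        # gather every rank's sequence chunk of the gradient back to the full sequence
        dist.all_gather_into_tensor(unsharded_grad, grad_output.contiguous(), group=group)
        return unsharded_grad, None
```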
Async Tensor + Sequence Parallelism
Motivation
- If we’re smart, we can overlap communication and computation
  - e.g. in the backward pass of `ColumnLinear`, start the `reduce_scatter` of `grad_tensor` while computing the gradients of the weight and bias, i.e. `grad_weight` and `grad_bias`
- We rely on `CUDA_DEVICE_MAX_CONNECTIONS=1` to ensure that the gather/reduce_scatter is scheduled before the tensor gradient computation in the code.
Details
- `RowLinear` doesn’t support async if the `tp_mode` is `all_reduce` (TP only) instead of `reduce_scatter` (TP + sequence parallelism)
- Must define `_{Row,Column}LinearAsyncCommunication(torch.autograd.Function)` classes
Code
ColumnLinear
Forward
- In `def forward(ctx, tensor, weight, bias, group, tp_mode)`, `tensor` is sharded/split along the sequence dimension
- We define the full input to be `gathered_tensor`
- What you can do is:
  - `ctx.save_for_backward(tensor, weight)` (only save the sharded tensor to reduce activation memory; we will gather it back at backward time)
  - start the async gather: `handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group, async_op=True)`
  - meanwhile, compute the matmul only for the sharded `tensor`: `torch.mm(tensor, weight, out=same_device_shard)`
  - `handle.wait()`
  - compute the rest of the matmul with the rest of the gathered tensor
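A sketch of that forward pass, written as a free function standing in for the `forward` staticmethod. Assumptions (not stated in the notes): a 2D input of shape `(sharded_seq_len, in_features)`, an `nn.Linear`-layout weight of shape `(out_features, in_features)`, and `tp_mode=REDUCE_SCATTER`; names are illustrative.

```python
import torch
import torch.distributed as dist

def column_linear_async_forward(ctx, tensor, weight, bias, group, tp_mode):
    rank, group_size = dist.get_rank(group), group.size()
    sharded_len = tensor.shape[0]

    # Save only the sequence-sharded input (less activation memory); it is
    # re-gathered in the backward pass.
    ctx.save_for_backward(tensor, weight)
    ctx.group, ctx.use_bias = group, bias is not None

    gathered_tensor = torch.empty(
        (sharded_len * group_size, tensor.shape[1]),
        dtype=tensor.dtype, device=tensor.device,
    )
    # 1) Kick off the all-gather of the full-sequence input asynchronously.
    handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)

    # 2) Meanwhile, compute the output rows that only need the local shard
    #    (they live at offset rank * sharded_len in the full output).
    out = torch.empty((sharded_len * group_size, weight.shape[0]),
                      dtype=tensor.dtype, device=tensor.device)
    torch.mm(tensor, weight.t(), out=out[rank * sharded_len:(rank + 1) * sharded_len])

    # 3) Wait for the gather, then fill in the rows coming from the other ranks.
    handle.wait()
    torch.mm(gathered_tensor[:rank * sharded_len], weight.t(),
             out=out[:rank * sharded_len])
    torch.mm(gathered_tensor[(rank + 1) * sharded_len:], weight.t(),
             out=out[(rank + 1) * sharded_len:])
    return out + bias if bias is not None else out
```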
Backward
- `def backward(ctx, grad_output)`
- What you can do is:
  - `tensor, weight = ctx.saved_tensors`
  - async gather `gathered_tensor`: `handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)`
  - compute `grad_tensor = grad_output.matmul(weight)`
  - `handle.wait()`
  - async `reduce_scatter_sum` the `grad_tensor`: `handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)`
  - meanwhile, compute `grad_weight` and `grad_bias`:
    - `grad_weight = grad_output.t().matmul(gathered_tensor)`
    - `grad_bias = grad_output.sum(dim=0) if use_bias else None`
  - `handle.wait()`
  - `return sub_grad_tensor, grad_weight, grad_bias, None, None`
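And the matching backward, under the same assumptions, again as a free function standing in for the `backward` staticmethod:

```python
import torch
import torch.distributed as dist

def column_linear_async_backward(ctx, grad_output):
    tensor, weight = ctx.saved_tensors       # sharded input, weight (out_features, in_features)
    group = ctx.group
    grad_output = grad_output.contiguous()   # (full_seq, out_features)

    # 1) Start re-gathering the sequence-sharded input; it is only needed
    #    later, for grad_weight.
    gathered_tensor = torch.empty(
        (tensor.shape[0] * group.size(), tensor.shape[1]),
        dtype=tensor.dtype, device=tensor.device,
    )
    handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)

    # 2) Meanwhile, compute the full-sequence input gradient.
    grad_tensor = grad_output.matmul(weight)  # (full_seq, in_features)
    handle.wait()

    # 3) Start reduce-scattering grad_tensor back to the sequence-sharded layout...
    sub_grad_tensor = torch.empty_like(tensor)
    handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor,
                                        op=dist.ReduceOp.SUM, group=group, async_op=True)

    # 4) ...and overlap it with the weight/bias gradient computation.
    grad_weight = grad_output.t().matmul(gathered_tensor)      # (out_features, in_features)
    grad_bias = grad_output.sum(dim=0) if ctx.use_bias else None
    handle.wait()
    return sub_grad_tensor, grad_weight, grad_bias, None, None  # one grad per forward input
```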
Row Linear
Forward
- `def forward(ctx, tensor, weight, bias, group, tp_mode)`: nothing much tricky going on
- What you do is:
  - `ctx.save_for_backward(tensor, weight)`
  - `out = F.linear(tensor, weight, bias)`
  - `return differentiable_reduce_scatter_sum(out, group)`
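A sketch of that forward. The notes call the differentiable wrapper; inside a custom `autograd.Function` a raw `reduce_scatter_tensor` also works because the hand-written backward below already does the matching all-gather, so treat this variant as illustrative.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def row_linear_forward(ctx, tensor, weight, bias, group, tp_mode):
    ctx.save_for_backward(tensor, weight)
    ctx.group, ctx.use_bias = group, bias is not None

    # Partial result on this TP rank: tensor is (seq_len, in_features / tp),
    # weight is (out_features, in_features / tp).
    out = F.linear(tensor, weight, bias)

    # Sum the partial results across TP ranks and keep only this rank's
    # sequence chunk; the hand-written backward provides the gradient.
    sharded_out = torch.empty(
        (out.shape[0] // group.size(), *out.shape[1:]),
        dtype=out.dtype, device=out.device,
    )
    dist.reduce_scatter_tensor(sharded_out, out.contiguous(),
                               op=dist.ReduceOp.SUM, group=group)
    return sharded_out
```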
Backward
- `def backward(ctx, grad_output)`
- What you do is (similar to the `ColumnLinear` forward):
  - start the async gather: `handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)`
  - meanwhile, compute the local chunk of the input gradient with the current `grad_output` shard: `torch.mm(grad_output, weight, out=same_device_shard_grad_tensor)`
  - `handle.wait()`
  - compute the rest of `grad_tensor` using `total_grad_output`
  - compute the weight and bias gradients:
    - `grad_weight = total_grad_output.t().matmul(tensor)`
    - `grad_bias = total_grad_output.sum(dim=0) if use_bias else None`
  - `return total_grad_tensor, grad_weight, grad_bias, None, None`
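A sketch of that backward, under the same 2D-shape and weight-layout assumptions as the `ColumnLinear` sketches:

```python
import torch
import torch.distributed as dist

def row_linear_async_backward(ctx, grad_output):
    tensor, weight = ctx.saved_tensors        # input (full_seq, in/tp), weight (out, in/tp)
    group = ctx.group
    rank, group_size = dist.get_rank(group), group.size()
    grad_output = grad_output.contiguous()    # sequence-sharded: (sharded_seq, out_features)
    sharded_len = grad_output.shape[0]

    # 1) Start gathering the sequence-sharded output gradient asynchronously.
    total_grad_output = torch.empty(
        (sharded_len * group_size, grad_output.shape[1]),
        dtype=grad_output.dtype, device=grad_output.device,
    )
    handle = dist.all_gather_into_tensor(total_grad_output, grad_output,
                                         group=group, async_op=True)

    # 2) Meanwhile, compute the rows of the input gradient that only need the local chunk.
    total_grad_tensor = torch.empty_like(tensor)
    torch.mm(grad_output, weight,
             out=total_grad_tensor[rank * sharded_len:(rank + 1) * sharded_len])

    # 3) Wait, then fill in the remaining rows and compute the weight/bias gradients.
    handle.wait()
    torch.mm(total_grad_output[:rank * sharded_len], weight,
             out=total_grad_tensor[:rank * sharded_len])
    torch.mm(total_grad_output[(rank + 1) * sharded_len:], weight,
             out=total_grad_tensor[(rank + 1) * sharded_len:])
    grad_weight = total_grad_output.t().matmul(tensor)
    grad_bias = total_grad_output.sum(dim=0) if ctx.use_bias else None
    return total_grad_tensor, grad_weight, grad_bias, None, None
```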