torchtitan/distributed/pipeline_parallel.py (170 additions, 1 deletion)
@@ -68,7 +68,8 @@ def backward(ctx, grad_output):
class MmPassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
-        return torch.mm(x, y)
+        with torch._C._AutoDispatchBelowAutograd():
+            return torch.mm(x, y)

    @staticmethod
    def backward(ctx, gO):
@@ -83,8 +84,176 @@ def split_mm(i, w):
    i1 = MmSeparateInputGrad.apply(i, w.detach())
    return MmPassThrough.apply(i1, w1)

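# Note on the pattern used throughout this file: each Separate*Grad function is
# an identity on one operand in forward but computes that operand's gradient in
# backward, while the *PassThrough function runs the real op (dispatched below
# the Autograd key, so these overrides are not re-entered) and simply forwards
# its incoming grad_output to each Separate* node. The .detach() calls keep
# each operand's gradient from being counted twice. The net effect is that the
# input gradient and the weight gradient live in separate autograd nodes, so a
# pipeline schedule can run dX promptly and defer dW.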
# addmm operator: out = beta * input + alpha * (mat1 @ mat2)
class AddmmSeparateMat2Grad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mat1, mat2, alpha):
        ctx.save_for_backward(mat1)
        ctx.alpha = alpha
        return mat2

    @staticmethod
    def backward(ctx, grad_output):
        (mat1,) = ctx.saved_tensors
        # Gradient w.r.t. mat2: alpha * mat1.T @ grad_output
        grad_mat2 = mat1.t().mm(grad_output) * ctx.alpha
        return None, grad_mat2, None

class AddmmSeparateMat1Grad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mat1, mat2, alpha):
        ctx.save_for_backward(mat2)
        ctx.alpha = alpha
        return mat1

    @staticmethod
    def backward(ctx, grad_output):
        (mat2,) = ctx.saved_tensors
        # Gradient w.r.t. mat1: alpha * grad_output @ mat2.T
        grad_mat1 = grad_output.mm(mat2.t()) * ctx.alpha
        return grad_mat1, None, None

class AddmmSeparateBiasGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, bias, beta):
        ctx.beta = beta
        return bias

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient w.r.t. bias: beta * sum(grad_output, dim=0)
        grad_bias = grad_output.sum(dim=0) * ctx.beta
        return grad_bias, None

class AddmmPassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, bias, mat1, mat2, beta, alpha):
        with torch._C._AutoDispatchBelowAutograd():
            return torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha)

    @staticmethod
    def backward(ctx, gO):
        return gO, gO, gO, None, None

def split_addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    print("split addmm")
    mat2_1 = AddmmSeparateMat2Grad.apply(mat1.detach(), mat2, alpha)
    mat1_1 = AddmmSeparateMat1Grad.apply(mat1, mat2.detach(), alpha)
    bias_1 = AddmmSeparateBiasGrad.apply(bias, beta)
    return AddmmPassThrough.apply(bias_1, mat1_1, mat2_1, beta, alpha)

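# Shape check for the addmm split (illustrative; nn.Linear-style shapes with a
# 1-D bias are assumed): bias (p,), mat1 (n, k), mat2 (k, p), out (n, p).
#   grad_mat2 = mat1.T (k, n) @ grad_output (n, p) -> (k, p), scaled by alpha
#   grad_mat1 = grad_output (n, p) @ mat2.T (p, k) -> (n, k), scaled by alpha
#   grad_bias = grad_output.sum(dim=0) -> (p,), scaled by beta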
# _fused_rms_norm operator: RMS normalization
class FusedRmsNormSeparateWeightGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, normalized_shape, eps):
        ctx.save_for_backward(input)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return weight

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Recompute the normalized input for the weight gradient
        variance = input.pow(2).mean(-1, keepdim=True)
        rstd = torch.rsqrt(variance + ctx.eps)
        normalized = input * rstd
        # Gradient w.r.t. weight: sum over all leading (batch) dimensions
        grad_weight = (grad_output * normalized).sum(
            dim=tuple(range(grad_output.ndim - 1))
        )
        return None, grad_weight, None, None
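# Why the weight gradient reduces to a sum: the RMSNorm output is
# (input * rstd) * weight elementwise, so with input held constant,
# dL/dweight = sum over all leading dims of grad_output * (input * rstd),
# which is exactly the reduction computed above.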

class FusedRmsNormSeparateInputGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, normalized_shape, eps):
        ctx.save_for_backward(weight)
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        return input

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        # Placeholder: the real input-gradient computation is meant to live in
        # the pass-through path; for now grad_output is passed straight through
        # unchanged.
        return grad_output, None, None, None

class FusedRmsNormPassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, normalized_shape, eps):
        with torch._C._AutoDispatchBelowAutograd():
            return torch.ops.aten._fused_rms_norm(
                input, weight, normalized_shape, eps
            )

    @staticmethod
    def backward(ctx, gO):
        return gO, gO, None, None

def split_fused_rms_norm(input, weight, normalized_shape, eps):
    print("split fused_rms_norm")
    weight_1 = FusedRmsNormSeparateWeightGrad.apply(
        input.detach(), weight, normalized_shape, eps
    )
    input_1 = FusedRmsNormSeparateInputGrad.apply(
        input, weight.detach(), normalized_shape, eps
    )
    return FusedRmsNormPassThrough.apply(input_1, weight_1, normalized_shape, eps)

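# Reference formula (assuming normalization over the last dimension, as the
# weight-grad recomputation above does): _fused_rms_norm computes
#   y = input * rsqrt(mean(input**2, dim=-1, keepdim=True) + eps) * weight
# Only dW is fully separated so far; the dX side is still the placeholder
# noted in FusedRmsNormSeparateInputGrad.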
# _grouped_mm operator: Grouped matrix multiplication for MoE
class GroupedMmSeparateMat2Grad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, mat2):
        ctx.save_for_backward(input)
        return mat2

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Gradient w.r.t. mat2 for grouped mm
        # This is simplified - actual implementation may need group-wise computation
        grad_mat2 = torch.ops.aten._grouped_mm.default(
            input.transpose(-1, -2), grad_output, reduce="sum"
        )
        return None, grad_mat2

class GroupedMmSeparateInputGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, mat2):
        ctx.save_for_backward(mat2)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        (mat2,) = ctx.saved_tensors
        # Gradient w.r.t. input for grouped mm
        grad_input = torch.ops.aten._grouped_mm.default(
            grad_output, mat2.transpose(-1, -2), reduce="sum"
        )
        return grad_input, None

class GroupedMmPassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, mat2, reduce="sum"):
        with torch._C._AutoDispatchBelowAutograd():
            return torch.ops.aten._grouped_mm.default(input, mat2, reduce=reduce)

    @staticmethod
    def backward(ctx, gO):
        return gO, gO, None

def split_grouped_mm(input, mat2, reduce="sum"):
    print("split grouped_mm")
    mat2_1 = GroupedMmSeparateMat2Grad.apply(input.detach(), mat2)
    input_1 = GroupedMmSeparateInputGrad.apply(input, mat2.detach())
    return GroupedMmPassThrough.apply(input_1, mat2_1, reduce)

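# Note on the grouped path: _grouped_mm performs one matmul per expert group
# (as used by MoE layers), so the two gradients above mirror the plain mm case
# applied per group: d(input) = grad_output @ mat2^T and
# d(mat2) = input^T @ grad_output. As the comments above caution, this is a
# simplified sketch; any group/offset bookkeeping the real op needs is not
# handled here.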
lib = torch.library.Library("aten", "IMPL")
lib.impl("mm", split_mm, "Autograd")
lib.impl("addmm", split_addmm, "Autograd")
lib.impl("_fused_rms_norm", split_fused_rms_norm, "Autograd")
lib.impl("_grouped_mm", split_grouped_mm, "Autograd")

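# Minimal usage sketch (illustrative only; shapes are made up): with the
# Autograd-key overrides above registered, an ordinary matmul inside the model
# is routed through split_mm, so the input and weight gradients are produced by
# separate autograd nodes that a pipeline schedule can execute at different
# times.
#
#   x = torch.randn(4, 8, requires_grad=True)
#   w = torch.randn(8, 16, requires_grad=True)
#   out = torch.mm(x, w)       # dispatched to split_mm via lib.impl("mm", ...)
#   out.sum().backward()       # x.grad and w.grad come from separate nodes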

def pipeline_llm(