Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Operator] simplify outer backward #201

Merged
merged 4 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,18 @@ Examples:
- support reduction operators: cumsum, layernorm, mean, softmax

### v2.0
- support BLAS operator: mv, outer
- support BLAS operators: mv, outer
- support pointwise operators: bitwise_and, bitwise_not, bitwise_or, cos, clamp, eq, ge, gt, isinf, isnan, le, lt, ne, neg, or, sin, tanh, sigmoid
- support reduction operators: all, any, amax, argmax, max, min, prod, sum, var_mean, vector_norm, cross_entropy_loss, group_norm, log_softmax, rms_norm
- support fused operators: skip_rms_norm, skip_layer_norm, gelu_and_mul, silu_and_mul, apply_rotary_position_embedding

### v2.1
- support Tensor operators: where, arange, repeat, masked_fill, tile, unique, index_select, masked_select, ones, ones_like, zeros, zeros_like, full, full_like, flip, pad
- support neural network operator: embedding
- support basic math operators: allclose, isclose, isfinite, floor_divide, trunc_divide, maximum, minimum
- support distribution operators: normal, uniform_, exponential_, multinomial, nonzero, topk, rand, randn, rand_like, randn_like
- support science operators: erf, resolve_conj, resolve_neg

## Quick Start

### Requirements
Expand Down
7 changes: 7 additions & 0 deletions README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
- 支持reduction类算子: all, any, amax, argmax, max, min, prod, sum, var_mean, vector_norm, cross_entropy_loss, group_norm, log_softmax, rms_norm
- 支持融合算子: skip_rms_norm, skip_layer_norm, gelu_and_mul, silu_and_mul, apply_rotary_position_embedding

### v2.1
- 支持Tensor类算子:where, arange, repeat, masked_fill, tile, unique, index_select, masked_select, ones, ones_like, zeros, zeros_like, full, full_like, flip, pad
- 支持神经网络类算子:embedding
- 支持基础数学算子:allclose, isclose, isfinite, floor_divide, trunc_divide, maximum, minimum
- 支持分布类算子:normal, uniform_, exponential_, multinomial, nonzero, topk, rand, randn, rand_like, randn_like
- 支持科学计算算子:erf, resolve_conj, resolve_neg

## 快速入门

### 依赖
Expand Down
2 changes: 1 addition & 1 deletion benchmark/test_pointwise_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def test_perf_gelu():

def test_perf_gelu_backward():
bench = Benchmark(
op_name="gelu backward",
op_name="gelu",
torch_op=torch.nn.functional.gelu,
arg_func=unary_arg,
dtypes=FLOAT_DTYPES,
Expand Down
2 changes: 1 addition & 1 deletion benchmark/test_reduction_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def test_perf_softmax():

def test_perf_softmax_backward():
bench = Benchmark(
op_name="softmax backward",
op_name="softmax",
torch_op=torch.nn.functional.softmax,
arg_func=unary_arg,
dtypes=FLOAT_DTYPES,
Expand Down
2 changes: 1 addition & 1 deletion src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .fused import * # noqa: F403
from .ops import * # noqa: F403

__version__ = "2.0"
__version__ = "2.1"

aten_lib = torch.library.Library("aten", "IMPL")

Expand Down
13 changes: 3 additions & 10 deletions src/flag_gems/ops/outer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import torch

from .mm import mm
from .mul import mul
from .mv import mv


class Outer(torch.autograd.Function):
Expand All @@ -26,15 +26,8 @@ def backward(ctx, out_grad):

inp, weight = ctx.saved_tensors

inp_shape = inp.shape
inp_grad_mid = mm(out_grad, weight[:, None])
inp_grad = inp_grad_mid.reshape(inp_shape)

weight_shape = weight.shape
inp = inp[None, :]
inp = inp.contiguous()
weight_grad_mid = mm(inp, out_grad)
weight_grad = weight_grad_mid.reshape(weight_shape)
inp_grad = mv(out_grad, weight)
weight_grad = mv(out_grad.t(), inp)

return inp_grad, weight_grad

Expand Down
7 changes: 4 additions & 3 deletions tests/test_blas_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@
@pytest.mark.parametrize("M", MNK_SHAPES)
@pytest.mark.parametrize("N", MNK_SHAPES)
@pytest.mark.parametrize("K", MNK_SHAPES)
@pytest.mark.parametrize("alpha", SCALARS)
@pytest.mark.parametrize("beta", SCALARS)
@pytest.mark.parametrize("scalar", SCALARS)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_addmm(M, N, K, alpha, beta, dtype):
def test_accuracy_addmm(M, N, K, scalar, dtype):
mat1 = torch.randn((M, K), dtype=dtype, device="cuda")
mat2 = torch.randn((K, N), dtype=dtype, device="cuda")
bias = torch.randn((N,), dtype=dtype, device="cuda")
ref_mat1 = to_reference(mat1, True)
ref_mat2 = to_reference(mat2, True)
ref_bias = to_reference(bias, True)

alpha = beta = scalar

ref_out = torch.addmm(ref_bias, ref_mat1, ref_mat2, alpha=alpha, beta=beta)
with flag_gems.use_gems():
res_out = torch.addmm(bias, mat1, mat2, alpha=alpha, beta=beta)
Expand Down