fused_scale_tril / hot fix matmul / softmax broadcast_sub broadcast_div #3980

Merged 14 commits on Dec 6, 2020
8 changes: 4 additions & 4 deletions oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp
@@ -41,10 +41,10 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {

const AMPList& AutoMixedPrecisionLists::ClearList() {
// TODO(niuchong): identity, tuple_identity, keep_header_only?
static AMPList clear_list = {"gather", "max_pool_1d", "max_pool_2d", "max_pool_3d",
"reshape", "relu", "transpose", "random_mask_like",
"concat", "pad", "same_padding", "tril",
"slice"};
static AMPList clear_list = {
"gather", "max_pool_1d", "max_pool_2d", "max_pool_3d", "reshape", "relu",
"transpose", "random_mask_like", "concat", "pad", "same_padding", "tril",
"slice", "fused_scale_tril"};

return clear_list;
}
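
Ops on the clear list are left in whatever precision their inputs already carry, so `fused_scale_tril` can stay in float16 when it consumes the output of a white-listed op such as `matmul` under auto mixed precision. A minimal sketch of that scenario, assuming the 0.x-era `enable_auto_mixed_precision` switch on `FunctionConfig` (names outside this diff are assumptions, not part of the PR):

```python
import numpy as np
import oneflow as flow
import oneflow.typing as tp

func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)
func_config.enable_auto_mixed_precision(True)  # assumed 0.x AMP switch

@flow.global_function(type="train", function_config=func_config)
def amp_tril_job(x: tp.Numpy.Placeholder((4, 4))) -> tp.Numpy:
    with flow.scope.placement("gpu", "0:0"):
        w = flow.get_variable(
            name="w", shape=(4, 4), dtype=flow.float,
            initializer=flow.zeros_initializer(),
        )
        # matmul is cast to float16 by the AMP pass; fused_scale_tril, now on the
        # clear list, keeps that float16 result instead of forcing a cast back.
        y = flow.math.fused_scale_tril(flow.matmul(x + w, x), diagonal=0, scale=0.5)
        flow.optimizer.SGD(
            flow.optimizer.PiecewiseConstantScheduler([], [1e-4]), momentum=0
        ).minimize(y)
    return y

# flow.train.CheckPoint().init()
# amp_tril_job(np.random.rand(4, 4).astype(np.float32))
```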
43 changes: 43 additions & 0 deletions oneflow/core/kernel/util/cuda_blas_interface.cu
@@ -66,6 +66,21 @@ void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE tra
ldc);
}

template<>
void Gemm(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a,
enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const half* alpha,
const half* a, const half* b, const half* beta, half* c) {
const float alpha_f = __half2float(*alpha);
const float beta_f = __half2float(*beta);
int lda, ldb, ldc;
cublasOperation_t cublas_trans_a, cublas_trans_b;
std::tie(lda, ldb, ldc, cublas_trans_a, cublas_trans_b) =
PrepareToCallCublasGemm(trans_a, trans_b, m, n, k);
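// cuBLAS expects column-major matrices while OneFlow stores data row-major, so the
// call below computes C^T = B^T * A^T: the operands and the m/n extents are swapped,
// inputs stay half (CUDA_R_16F) and accumulation happens in float (CUDA_R_32F).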
OF_CUBLAS_CHECK(cublasGemmEx(ctx->cublas_tensor_op_math_handle(), cublas_trans_b, cublas_trans_a,
n, m, k, &alpha_f, b, CUDA_R_16F, ldb, a, CUDA_R_16F, lda, &beta_f,
c, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP));
}

void HGemmWithFloat(DeviceCtx* ctx, const enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans_a,
enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k,
const float* alpha, const half* a, const half* b, const float* beta, half* c) {
@@ -176,6 +191,34 @@ void BatchedGemmImpl(DeviceCtx* ctx, const enum CBLAS_ORDER order,
#endif
}

#if CUDA_VERSION >= 9010
template<>
void BatchedGemmImpl(DeviceCtx* ctx, const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b,
int batch_size, int m, int n, int k, const half* alpha, const half* a,
const half* b, const half* beta, half* c, half** buf) {
float alpha_f = __half2float(*alpha);
float beta_f = __half2float(*beta);

int a_stride, b_stride, c_stride;
int lda, ldb, ldc;
cublasOperation_t cublas_trans_a, cublas_trans_b;
half** dev_a_ptrs;
half** dev_b_ptrs;
half** dev_c_ptrs;
std::tie(a_stride, b_stride, c_stride, lda, ldb, ldc, cublas_trans_a, cublas_trans_b, dev_a_ptrs,
dev_b_ptrs, dev_c_ptrs) =
PrepareToCallBatchedGemm<half>(ctx, trans_a, trans_b, batch_size, m, n, k, a, b, c, buf);
OF_CUBLAS_CHECK(cublasGemmBatchedEx(
ctx->cublas_tensor_op_math_handle(), CblasTrans2CublasTrans(trans_b),
CblasTrans2CublasTrans(trans_a), n, m, k, &alpha_f,
reinterpret_cast<const void**>(const_cast<const half**>(dev_b_ptrs)), CUDA_R_16F, ldb,
reinterpret_cast<const void**>(const_cast<const half**>(dev_a_ptrs)), CUDA_R_16F, lda,
&beta_f, reinterpret_cast<void**>(dev_c_ptrs), CUDA_R_16F, ldc, batch_size, CUDA_R_32F,
CUBLAS_GEMM_DFALT_TENSOR_OP));
}
#endif
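
Both specializations keep the operands in half precision and request float32 accumulation with tensor-op math (`CUDA_R_16F` inputs/outputs, `CUDA_R_32F` compute type, `CUBLAS_GEMM_DFALT_TENSOR_OP`). A rough NumPy sketch of the resulting numerics, per matrix and per batch slice (an illustration, not the kernel):

```python
import numpy as np

def gemm_fp16_fp32_acc(a, b, alpha=1.0, beta=0.0, c=None):
    # float16 operands, float32 accumulation, float16 result:
    # mirrors CUDA_R_16F inputs/outputs with a CUDA_R_32F compute type.
    acc = alpha * (a.astype(np.float32) @ b.astype(np.float32))
    if c is not None:
        acc = acc + beta * c.astype(np.float32)
    return acc.astype(np.float16)

def batched_gemm_fp16_fp32_acc(a, b):
    # the batched path applies the same GEMM to each slice along dim 0
    return np.stack([gemm_fp16_fp32_acc(ai, bi) for ai, bi in zip(a, b)])

a = np.random.rand(2, 4, 8).astype(np.float16)
b = np.random.rand(2, 8, 3).astype(np.float16)
print(batched_gemm_fp16_fp32_acc(a, b).shape)  # (2, 4, 3)
```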

void BatchedHGemmWithFloatImpl(DeviceCtx* ctx, const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE trans_a,
const enum CBLAS_TRANSPOSE trans_b, int batch_size, int m, int n,
46 changes: 46 additions & 0 deletions oneflow/python/ops/math_ops.py
@@ -1924,6 +1924,52 @@ def tril_Job(x: tp.Numpy.Placeholder((4, 4))
)


@oneflow_export("math.fused_scale_tril", "nn.fused_scale_tril")
def fused_scale_tril(
x: remote_blob_util.BlobDef,
diagonal: int = 0,
fill_value: Union[int, float] = 0,
scale: Union[int, float] = 1,
name: Optional[str] = None,
) -> remote_blob_util.BlobDef:
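r"""Computes a scaled lower triangle: entries of ``x`` on and below the
``diagonal``-th diagonal are multiplied by ``scale``, while entries above it
are set to ``fill_value``. Exactly one value of each floating/integer
attribute pair below is used, chosen by the Python type of ``fill_value``
and ``scale``.

Args:
    x: input Blob.
    diagonal: diagonal bounding the kept triangle (0 is the main diagonal).
    fill_value: value written above the diagonal.
    scale: factor applied to the kept entries.
    name: optional name for the op.

Returns:
    remote_blob_util.BlobDef: a Blob with the same shape and data type as ``x``.
"""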

if isinstance(fill_value, float):
is_floating_fill_value = True
floating_fill_value = float(fill_value)
integer_fill_value = int(0)
else:
is_floating_fill_value = False
floating_fill_value = float(0)
integer_fill_value = int(fill_value)

if isinstance(scale, float):
is_floating_scale_value = True
floating_scale_value = float(scale)
integer_scale_value = int(1)
else:
is_floating_scale_value = False
floating_scale_value = float(1)
integer_scale_value = int(scale)
return (
flow.user_op_builder(
name if name is not None else id_util.UniqueStr("FusedScaleTril_")
)
.Op("fused_scale_tril")
.Input("in", [x])
.Attr("diagonal", diagonal)
.Attr("is_floating_fill_value", is_floating_fill_value)
.Attr("floating_fill_value", floating_fill_value)
.Attr("integer_fill_value", integer_fill_value)
.Attr("is_floating_scale_value", is_floating_scale_value)
.Attr("floating_scale_value", floating_scale_value)
.Attr("integer_scale_value", integer_scale_value)
.Output("out")
.Build()
.InferAndTryRun()
.RemoteBlobList()[0]
)
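
The wrapper splits `fill_value` and `scale` into floating/integer attribute pairs because user-op attributes are statically typed; the kernel reads only the pair selected by the corresponding `is_floating_*` flag. In NumPy terms the op computes the following (a sketch mirroring the reference expression in the new test below, not the kernel itself):

```python
import numpy as np

def fused_scale_tril_ref(x, diagonal=0, fill_value=0, scale=1):
    # keep and scale entries on/below the chosen diagonal, fill the rest
    mask = np.tril(np.ones(x.shape), diagonal).astype(bool)
    return np.where(mask, x * scale, np.full(x.shape, fill_value).astype(x.dtype))

x = np.arange(16, dtype=np.float32).reshape(4, 4)
print(fused_scale_tril_ref(x, diagonal=0, fill_value=0.0, scale=2.0))
```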


@oneflow_export("math.polyval")
def polyval(
coeffs: Union[List, Tuple], x: remote_blob_util.BlobDef, name: Optional[str] = None
131 changes: 131 additions & 0 deletions oneflow/python/test/ops/test_fused_scale_tril.py
@@ -0,0 +1,131 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow as flow
from test_util import (
GenArgDict,
test_global_storage,
type_name_to_flow_type,
type_name_to_np_type,
)
import oneflow.typing as oft


def _test_fused_scale_tril_fw_bw(
test_case, device, shape, type_name, diagonal, fill_value, scale
):
flow.clear_default_session()
func_config = flow.FunctionConfig()
func_config.default_data_type(flow.float)

if type_name == "float16":
flow_type = flow.float
np_type = np.float32
else:
flow_type = type_name_to_flow_type[type_name]
np_type = type_name_to_np_type[type_name]

@flow.global_function(type="train", function_config=func_config)
def test_fused_scale_tril_fw_bw_job(
x: oft.Numpy.Placeholder(shape, dtype=flow_type),
):
with flow.scope.placement(device, "0:0"):
x_var = flow.get_variable(
name="xv",
shape=(1,),
dtype=flow.float,
initializer=flow.zeros_initializer(),
)
x += flow.cast(x_var, dtype=flow_type)
if type_name == "float16":
out = flow.cast(
flow.math.fused_scale_tril(
flow.cast(x, flow.float16), diagonal, scale=scale
),
flow.float,
)
else:
out = flow.math.fused_scale_tril(x, diagonal, scale=scale)
flow.optimizer.SGD(
flow.optimizer.PiecewiseConstantScheduler([], [1e-4]), momentum=0
).minimize(out)

flow.watch(x, test_global_storage.Setter("x"))
flow.watch_diff(x, test_global_storage.Setter("x_diff"))
flow.watch(out, test_global_storage.Setter("out"))
flow.watch_diff(out, test_global_storage.Setter("out_diff"))
return out

check_point = flow.train.CheckPoint()
check_point.init()
x = np.random.randint(low=0, high=100, size=shape)
test_fused_scale_tril_fw_bw_job(x.astype(np_type)).get()

np_out = np.where(
np.tril(np.ones(shape), diagonal),
test_global_storage.Get("x") * scale,
np.full(shape, fill_value).astype(np_type),
)
np_x_diff = np.tril(test_global_storage.Get("out_diff"), diagonal) * scale

if type_name == "float16":
tolerance = 1e-3
else:
tolerance = 1e-5
test_case.assertTrue(
np.allclose(
np_out, test_global_storage.Get("out"), rtol=tolerance, atol=tolerance
)
)
test_case.assertTrue(
np.allclose(
np_x_diff, test_global_storage.Get("x_diff"), rtol=tolerance, atol=tolerance
)
)


@flow.unittest.skip_unless_1n1d()
class TestFusedScaleTril(flow.unittest.TestCase):
def test_fused_scale_tril_fw_bw(test_case):
arg_dict = OrderedDict()
arg_dict["device"] = ["gpu"]
arg_dict["type_name"] = [
"float32",
"float16",
"double",
"int32",
"int64",
]
arg_dict["shape"] = [(6, 6), (3, 6, 8)]
arg_dict["diagonal"] = [-8, -1, 0, 1, 8]
arg_dict["fill_value"] = [1.0, 0]
arg_dict["scale"] = [5.0, 3]
for arg in GenArgDict(arg_dict):
if arg["device"] == "cpu" and arg["type_name"] == "float16":
continue
if isinstance(arg["fill_value"], float) and arg_dict["type_name"] not in [
"float32",
"float16",
"double",
]:
continue
_test_fused_scale_tril_fw_bw(test_case, **arg)


if __name__ == "__main__":
unittest.main()