
Commit f4ee125

rongzha1, chunyuan-w, sanchitintel, and WeizhuoZhang-intel authored

llga: no data type promotion needed for binary op of `Double` (#2622)

* llga: no data type promotion needed for binary op when one data type is `Double`
* llga: no promotion for data types not supported by oneDNN Graph
* fix clang-format error

Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
Co-authored-by: sanchitintel <sanchit.jain@intel.com>
Co-authored-by: WeizhuoZhang-intel <weizhuo.zhang@intel.com>

1 parent 55efdde · commit f4ee125
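For context on the motivation: under PyTorch's type-promotion rules, a binary op mixing float32 with a Double (float64) operand is computed in float64, and float64 is one of the dtypes this bridge maps to data_type::undef, so inserting a promotion cast would only add an aten::to node around an op that the LLGA fuser cannot take anyway. A quick illustration of the promotion behavior (a standalone sketch for context, not code from this commit):

```python
import torch

# Mixing float32 with a Double (float64) operand promotes the result to
# float64 under PyTorch's type-promotion rules; this commit keeps the LLGA
# pass from inserting a cast in that case, since llga has no use for it.
print(torch.promote_types(torch.float32, torch.float64))  # torch.float64

x = torch.randn(4, 4)                    # float32 tensor
d = torch.randn(1, dtype=torch.float64)  # Double divisor, as in the new test
print((x / d).dtype)                     # torch.float64
```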

File tree

4 files changed: +57 -24 lines changed

csrc/cpu/jit/codegen/LlgaTensorImpl.cpp
csrc/cpu/jit/codegen/LlgaTensorImpl.h
csrc/cpu/jit/codegen/onednn/prepare_binary.cpp
tests/cpu/test_jit_llga_fuser.py

csrc/cpu/jit/codegen/LlgaTensorImpl.cpp

Lines changed: 20 additions & 21 deletions
@@ -56,6 +56,26 @@ dnnl::graph::tensor llga_from_aten_tensor(const at::Tensor& tensor) {
       tensor.data_ptr()};
 }

+using data_type = dnnl::graph::logical_tensor::data_type;
+data_type getLlgaDataType(at::ScalarType dt) {
+  switch (dt) {
+    case at::ScalarType::Float:
+      return data_type::f32;
+    case at::ScalarType::BFloat16:
+      return data_type::bf16;
+    case at::ScalarType::Bool:
+      return data_type::boolean;
+    case at::kInt:
+      return data_type::s32;
+    case at::ScalarType::QInt8:
+      return data_type::s8;
+    case at::ScalarType::QUInt8:
+      return data_type::u8;
+    default:
+      return data_type::undef;
+  }
+}
+
 at::Tensor LlgaTensorImpl::llga_to_aten_tensor(LlgaTensorImpl* llgaImpl) {
   auto aten_tensor = at::detail::make_tensor<TensorImpl>(
       std::move(llgaImpl->storage_),
@@ -81,27 +101,6 @@ at::Tensor LlgaTensorImpl::llga_to_aten_tensor(
   return aten_tensor;
 }

-using data_type = dnnl::graph::logical_tensor::data_type;
-
-data_type LlgaTensorDesc::getLlgaDataType(at::ScalarType dt) const {
-  switch (dt) {
-    case at::ScalarType::Float:
-      return data_type::f32;
-    case at::ScalarType::BFloat16:
-      return data_type::bf16;
-    case at::ScalarType::Bool:
-      return data_type::boolean;
-    case at::kInt:
-      return data_type::s32;
-    case at::ScalarType::QInt8:
-      return data_type::s8;
-    case at::ScalarType::QUInt8:
-      return data_type::u8;
-    default:
-      return data_type::undef;
-  }
-}
-
 LlgaTensorDesc LlgaTensorDesc::supplementTensorInfo(const at::Tensor& t) const {
   if (t.is_mkldnn()) {
     // if input tensor is of mkldnn, it's originated from an upstream
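The mapping is now a free function so the binary-op preparation pass can query it; every ScalarType outside the listed cases, including Double/float64, falls through to data_type::undef. A rough Python mirror of that switch, purely illustrative (the dict and helper below are invented for this sketch, not part of the IPEX API):

```python
import torch

# Illustrative stand-in for the C++ getLlgaDataType() switch above; the
# string values name the oneDNN Graph data types it returns.
_LLGA_DTYPE = {
    torch.float32: "f32",
    torch.bfloat16: "bf16",
    torch.bool: "boolean",
    torch.int32: "s32",   # at::kInt
    torch.qint8: "s8",
    torch.quint8: "u8",
}

def get_llga_data_type(dt):
    # Anything not in the table (e.g. torch.float64) maps to "undef".
    return _LLGA_DTYPE.get(dt, "undef")

assert get_llga_data_type(torch.bfloat16) == "bf16"
assert get_llga_data_type(torch.float64) == "undef"
```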

csrc/cpu/jit/codegen/LlgaTensorImpl.h

Lines changed: 2 additions & 3 deletions
@@ -13,6 +13,8 @@ namespace jit {
 namespace fuser {
 namespace onednn {

+dnnl::graph::logical_tensor::data_type getLlgaDataType(at::ScalarType dt);
+
 struct LlgaTensorDesc {
   using desc = dnnl::graph::logical_tensor;

@@ -106,9 +108,6 @@ struct LlgaTensorDesc {

   at::ScalarType aten_scalar_type() const;

-  dnnl::graph::logical_tensor::data_type getLlgaDataType(
-      at::ScalarType dt) const;
-
   const std::vector<int64_t>& sizes() const {
     return sizes_;
   }

csrc/cpu/jit/codegen/onednn/prepare_binary.cpp

Lines changed: 14 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "prepare_binary.h"
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/shape_analysis.h>
+#include "../LlgaTensorImpl.h"
 #include "utils.h"

 namespace torch_ipex {
@@ -9,6 +10,7 @@ namespace fuser {
 namespace onednn {

 using namespace torch::jit;
+using data_type = dnnl::graph::logical_tensor::data_type;

 void handleBinaryOpInputs(Node* node, int first_input, int second_input) {
   if (node->input(first_input)->type()->isSubtypeOf(TensorType::get()) &&
@@ -33,6 +35,10 @@ void handleBinaryOpInputs(Node* node, int first_input, int second_input) {
     // https://pytorch.org/docs/stable/tensor_attributes.html#type-promotion-doc
     // clang-format on
     auto promotedDtype = dtypeOfFirstInput;
+    // This tensor won't be added to oneDNN graph due to unsupported data
+    // type, so no need to do promotion for it.
+    if (getLlgaDataType(promotedDtype) == data_type::undef)
+      return;
     utils::convertInputTo0DTensor(node, second_input, promotedDtype);
     // dtype might have changed, so needs to be updated in IR as well
     utils::modifyDtypeOfNode(node, promotedDtype);
@@ -53,6 +59,10 @@ void handleBinaryOpInputs(Node* node, int first_input, int second_input) {
       // Type promotion is required
       auto promotedDtype =
          c10::promoteTypes(dtypeOfFirstInput, dtypeOfSecondInput);
+      // This tensor won't be added to oneDNN graph due to unsupported data
+      // type, so no need to do promotion for it.
+      if (getLlgaDataType(promotedDtype) == data_type::undef)
+        return;
      int input_to_replace;
      if (promotedDtype == dtypeOfFirstInput) {
        input_to_replace = second_input;
@@ -65,6 +75,10 @@ void handleBinaryOpInputs(Node* node, int first_input, int second_input) {
      utils::mark_original_output_dtype(node);
      utils::modifyDtypeOfNode(node, promotedDtype);
    } else {
+     // This tensor won't be added to oneDNN graph due to unsupported data
+     // type, so no need to do promotion for it.
+     if (getLlgaDataType(dtypeOfFirstInput) == data_type::undef)
+       return;
      // both dtypes are same
      // IR info of dtypes is missing sometimes in JIT IR,
      // and we shouldn't treat those tensors as FP32 tensors by default.
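In each of the three promotion branches, the pass now returns before rewriting the graph whenever the dtype it would promote to (or keep) has no LLGA counterpart, so no aten::to cast is inserted and the binary op is simply left to the default executor. A condensed Python sketch of that decision (helper and set names are hypothetical, mirroring the undef check in the C++ above):

```python
import torch

# Dtypes with a oneDNN Graph counterpart in getLlgaDataType(); everything
# else corresponds to data_type::undef in the C++ pass.
_SUPPORTED = {torch.float32, torch.bfloat16, torch.bool,
              torch.int32, torch.qint8, torch.quint8}

def promoted_dtype_or_none(dt_first, dt_second):
    promoted = torch.promote_types(dt_first, dt_second)
    if promoted not in _SUPPORTED:
        return None  # bail out: no aten::to cast is inserted
    return promoted

assert promoted_dtype_or_none(torch.float32, torch.bfloat16) == torch.float32
assert promoted_dtype_or_none(torch.float32, torch.float64) is None  # Double
```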

tests/cpu/test_jit_llga_fuser.py

Lines changed: 21 additions & 0 deletions
@@ -450,6 +450,27 @@ def forward(self, x, y):
         self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
         self.assertFused(graph, ["aten::matmul", "aten::div"])

+    @llga_fp32_bf16_test_env
+    def test_bmm_div_with_double_dt(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.divisor = torch.randn(1, dtype=torch.float64)
+
+            def forward(self, x, y):
+                return x.matmul(y) / self.divisor
+
+        x = torch.randn(128, 16, 384, 64)
+        y = torch.randn(128, 16, 64, 384)
+        m = M()
+
+        graph, _ = self.checkTrace(m, [x, y])
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)
+        # no need to do data type promotion for `Double` which llga doesn't
+        # support
+        self.assertGraphContainsExactly(graph, "aten::to", 0)
+        self.assertFused(graph, ["aten::matmul"])
+
     @llga_fp32_bf16_test_env
     def test_bmm_div_add(self):
         class M(nn.Module):
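The new case traces a matmul followed by division with a float64 buffer and checks both that the matmul still lands in a single LLGA fusion group and that no aten::to node shows up, i.e. the pass really skipped the promotion. If the local environment supports it, an invocation along the lines of python -m pytest tests/cpu/test_jit_llga_fuser.py -k test_bmm_div_with_double_dt (assuming pytest can drive this suite in your setup) should run just this test.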
