Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

importer: add initial support for loading Float16 tensors #1169

Merged
merged 1 commit into from
Aug 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
return torch_upstream::ScalarType::Bool;
if (type.isBF16())
return torch_upstream::ScalarType::BFloat16;
if (type.isF16())
return torch_upstream::ScalarType::Half;
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
}

Expand All @@ -74,6 +76,8 @@ Type Torch::getTypeForScalarType(
return IntegerType::get(context, 1);
case torch_upstream::ScalarType::BFloat16:
return mlir::FloatType::getBF16(context);
case torch_upstream::ScalarType::Half:
return mlir::FloatType::getF16(context);
default:
return Type();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,10 @@ MlirAttribute torch_mlir::convertTensorToMlirElementsAttr(at::Tensor tensor,
case ScalarType::BFloat16:
return mlirDenseElementsAttrBFloat16Get(
shapedType, numElements, static_cast<const uint16_t *>(tensorData));
case ScalarType::Half:
return mlirDenseElementsAttrFloat16Get(
shapedType, numElements, static_cast<const uint16_t *>(tensorData));

default:
throwUnsupportedTensorError();
}
Expand Down
13 changes: 13 additions & 0 deletions test/Conversion/TorchToLinalg/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,16 @@ func.func @torch.aten.neg.bf16(%arg0: !torch.vtensor<[?,?],bf16>) -> !torch.vten
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],bf16> -> !torch.vtensor<[?,?],bf16>
return %0 : !torch.vtensor<[?,?],bf16>
}

// -----

// CHECK-LABEL: func.func @torch.aten.neg.f16
// CHECK: linalg.generic {{.*}} {
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f16, %{{.*}}: f16):
// CHECK-NEXT: %[[NEG:.*]] = arith.negf %[[LHS]] : f16
// CHECK-NEXT: linalg.yield %[[NEG]] : f16
// CHECK-NEXT: } -> tensor<?x?xf16>
func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtensor<[?,?],f16> {
%0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16>
return %0 : !torch.vtensor<[?,?],f16>
}
3 changes: 3 additions & 0 deletions test/python/importer/jit_ir/ivalue_import/tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self):
self.ones_f64 = torch.ones(1, dtype=torch.float64)
self.ones_bool = torch.ones(1, dtype=torch.bool)
self.ones_bf16 = torch.ones(1, dtype=torch.bfloat16)
self.ones_f16 = torch.ones(1, dtype=torch.half)
self.ones_qint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.qint8)
self.ones_quint8 = torch.quantize_per_tensor(torch.ones(1), 1.0, 0, torch.quint8)
self.arange = torch.nn.Parameter(torch.arange(3.0))
Expand All @@ -34,6 +35,7 @@ def __init__(self):
# CHECK: %[[ONES_F64:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf64>) : !torch.tensor<[1],f64>
# CHECK: %[[ONES_BOOL:.*]] = torch.tensor.literal(dense<true> : tensor<1xi1>) : !torch.tensor<[1],i1>
# CHECK: %[[ONES_BF16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xbf16>) : !torch.tensor<[1],bf16>
# CHECK: %[[ONES_F16:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf16>) : !torch.tensor<[1],f16>
# CHECK: %[[ONES_QINT8_DATA:.*]] = torch.tensor.literal(dense<1> : tensor<1xsi8>) : !torch.tensor<[1],si8>
# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
# CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
Expand All @@ -49,6 +51,7 @@ def __init__(self):
# CHECK: torch.slot "ones_f64", %[[ONES_F64]] : !torch.tensor<[1],f64>
# CHECK: torch.slot "ones_bool", %[[ONES_BOOL]] : !torch.tensor<[1],i1>
# CHECK: torch.slot "ones_bf16", %[[ONES_BF16]] : !torch.tensor<[1],bf16>
# CHECK: torch.slot "ones_f16", %[[ONES_F16]] : !torch.tensor<[1],f16>
# CHECK: torch.slot "ones_qint8", %[[ONES_QINT8]] : !torch.tensor<[1],!torch.qint8>
# CHECK: torch.slot "ones_quint8", %[[ONES_QUINT8]] : !torch.tensor<[1],!torch.quint8>
# CHECK: }
Expand Down