Add interpreter for BitcastConvertOp (#1463)
BitcastConvertOp has the following constraints:
```
(I1) operand is a tensor.
(C1) Let `E` and `E'` be the `operand` and `result` element types,
respectively, and let `R = rank(operand)`:
* If `num_bits(E')` = `num_bits(E)`, shape(`result`) = shape(`operand`).
* If `num_bits(E')` < `num_bits(E)`:
  * `rank(result) = R+1`.
  * dim(`result`, `i`) = dim(`operand`, `i`) for all `i` in [0, `R`-1].
  * `dim(result, R) = num_bits(E)/num_bits(E')`.
* If `num_bits(E')` > `num_bits(E)`:
  * `rank(result) = R-1`.
  * dim(`result`, `i`) = dim(`operand`, `i`) for all `i` in [0, `R`-1).
  * `dim(operand, R-1) = num_bits(E')/num_bits(E)`.
(C2) Conversion between complex and non-complex types is not permitted.
```
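
As a rough illustration of C1, the shape rule can be sketched in standalone C++. `expectedResultShape` is a hypothetical helper, not a StableHLO function, and the explicit divisibility check is an assumption (C1 only implies it through the division):

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not part of StableHLO): returns the result shape that
// constraint C1 requires, or std::nullopt when no valid result shape exists.
std::optional<std::vector<int64_t>> expectedResultShape(
    const std::vector<int64_t>& operandShape, int64_t operandBits,
    int64_t resultBits) {
  assert(operandBits > 0 && resultBits > 0);
  std::vector<int64_t> shape = operandShape;

  // num_bits(E') = num_bits(E): shapes are identical.
  if (resultBits == operandBits) return shape;

  if (resultBits < operandBits) {
    // Narrowing: rank grows by one; the new minor dimension has size
    // num_bits(E) / num_bits(E').
    if (operandBits % resultBits != 0) return std::nullopt;  // assumption
    shape.push_back(operandBits / resultBits);
    return shape;
  }

  // Widening: the operand's minor dimension must equal
  // num_bits(E') / num_bits(E), and it is dropped from the result.
  if (resultBits % operandBits != 0) return std::nullopt;  // assumption
  if (shape.empty() || shape.back() != resultBits / operandBits)
    return std::nullopt;
  shape.pop_back();
  return shape;
}
```

For example, a `tensor<2xf32>` operand with an `i8` result gives `{2, 4}`, and a `tensor<f64>` operand with an `f16` result gives `{4}`.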

These constraints will be comprehensively covered by the following
tests:

```
I1: a) operand is not a tensor. (Covered by ODS).
C1: a) If `num_bits(E')` = `num_bits(E)`, shape(`result`) != shape(`operand`).
C1: b) If `num_bits(E')` < `num_bits(E)`: `rank(result) != R+1`.
C1: c) If `num_bits(E')` < `num_bits(E)`: dim(`result`, `i`) != dim(`operand`, `i`) for any `i` in [0, `R`-1].
C1: d) If `num_bits(E')` < `num_bits(E)`: `dim(result, R) != num_bits(E)/num_bits(E')`.
C1: e) If `num_bits(E')` > `num_bits(E)`: `rank(result) != R-1`.
C1: f) If `num_bits(E')` > `num_bits(E)`: dim(`result`, `i`) != dim(`operand`, `i`) for any `i` in [0, `R`-1).
C1: g) If `num_bits(E')` > `num_bits(E)`: `dim(operand, R-1) != num_bits(E')/num_bits(E)`.
C2: a) Conversion between complex and non-complex types is permitted.
```

If we drop the "Covered by ODS" pieces, this will leave us with the
following test cases:

```
C1a: If `num_bits(E')` = `num_bits(E)`:
     shape(`result`) != shape(`operand`).
C1b: If `num_bits(E')` < `num_bits(E)`:
     `rank(result) != R+1`.
C1c: If `num_bits(E')` < `num_bits(E)`:
     dim(`result`, `i`) != dim(`operand`, `i`) for any `i` in [0, `R`-1].
C1d: If `num_bits(E')` < `num_bits(E)`:
     `dim(result, R) != num_bits(E)/num_bits(E')`.
C1e: If `num_bits(E')` > `num_bits(E)`:
     `rank(result) != R-1`.
C1f: If `num_bits(E')` > `num_bits(E)`:
     dim(`result`, `i`) != dim(`operand`, `i`) for any `i` in [0, `R`-1).
C1g: If `num_bits(E')` > `num_bits(E)`:
     `dim(operand, R-1) != num_bits(E')/num_bits(E)`.
C2: Conversion between complex and non-complex types is permitted.
```

Notes:
* The interpreter assumes a little-endian representation; #1460 tracks adding
  support for big-endian architectures.
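
To make the little-endian assumption concrete, the sketch below splits the bit pattern of an f32 into four i8 values, lowest-order byte first. `splitF32IntoI8` is a hypothetical name, and the code uses plain integers rather than the interpreter's `APInt`-based path:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: splits the 32-bit pattern of a float into four i8
// values, lowest-order byte first (the little-endian order assumed here).
std::vector<int8_t> splitF32IntoI8(float value) {
  static_assert(sizeof(float) == sizeof(uint32_t), "expects a 32-bit float");
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));  // reinterpret the float's bits

  std::vector<int8_t> pieces;
  for (int shift = 0; shift < 32; shift += 8) {
    // Byte values above 127 wrap to negative i8 values (guaranteed since
    // C++20, and the behavior of mainstream compilers before that).
    pieces.push_back(static_cast<int8_t>((bits >> shift) & 0xFF));
  }
  assert(pieces.size() == 4);
  return pieces;
}
```

Under these assumptions, `splitF32IntoI8(1.0f)` yields `[0, 0, -128, 63]`, which is the row that the `tensor<2xf32> -> tensor<2x4xi8>` example in the docs/spec.md hunk labels as the little-endian representation of 1.0.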

closes #1098
ghpvnist authored Jul 6, 2023
1 parent ac47650 commit 3220197
Showing 70 changed files with 324 additions and 138 deletions.
15 changes: 7 additions & 8 deletions docs/spec.md
@@ -1284,12 +1284,12 @@ type of the `result` tensor.
More formally, given `E = element_type(operand)`, `E' = element_type(result)`,
and `R = rank(operand)`:

* If `num_bits(E') = num_bits(E)`,
`bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])`.
* If `num_bits(E') < num_bits(E)`,
`bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])`.
* If `num_bits(E') > num_bits(E)`,
`bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])`.
* If `num_bits(E') = num_bits(E)`,
`bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])`.

`bits` returns in-memory representation of a given value, and its behavior
is implementation-defined because the exact representation of tensors is
@@ -1328,14 +1328,13 @@ implementation-defined as well.
#### Examples

```mlir
// %operand: [0.0, 1.0]
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<2xf32>) -> tensor<2x4xi8>
// %result: [
// [0, 0, 0, 0],
// [0, 0, -128, 63] // little-endian representation of 1.0
// ]
// %operand: 0x0123456789ABCDEF
%result = "stablehlo.bitcast_convert"(%operand) : (tensor<f64>) -> tensor<4xf16>
// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation
```

&nbsp;[More Examples](../stablehlo/tests/interpret_bitcast_convert.mlir)

### broadcast_in_dim

#### Semantics
2 changes: 1 addition & 1 deletion docs/status.md
@@ -53,7 +53,7 @@ one of the following tracking labels.
| batch_norm_grad | yes | revisit | yes | no | revisit |
| batch_norm_inference | yes | revisit | yes | no | revisit |
| batch_norm_training | yes | revisit | yes | no | revisit |
| bitcast_convert | yes | yes | infeasible | yes | no |
| bitcast_convert | yes | yes | infeasible | yes | yes |
| broadcast | no | yes\* | yes\* | yes | revisit |
| broadcast_in_dim | yes | yes | infeasible | yes | yes |
| case | yes | revisit | yes | no | yes |
4 changes: 2 additions & 2 deletions stablehlo/dialect/StablehloOps.td
@@ -1823,11 +1823,11 @@ def StableHLO_BitcastConvertOp : StableHLO_ShapedInterfaceOp<"bitcast_convert",

Example:
```mlir
%result = stablehlo.bitcast_convert %operand : (tensor<2xf32>) -> tensor<2x4xi8>
%result = stablehlo.bitcast_convert %operand : (tensor<f64>) -> tensor<4xf16>
```
}];

let arguments = (ins HLO_Tensor:$operand);
let arguments = (ins HLO_Tensor:$operand /*bitcast_convert_i1*/);
let results = (outs HLO_Tensor);
let hasVerifier = 1;

30 changes: 9 additions & 21 deletions stablehlo/dialect/TypeInference.cpp
@@ -3114,21 +3114,12 @@ LogicalResult verifyAllReduceOp(std::optional<Location> location, Value operand,
return success();
}

/*
* We intend to verify the following properties
* P1. We cannot convert between complex and real types (cf xla)
* P3. The dimensions of the operand and the target
* shape must match, except that the shape with the smaller element bitwidth has
* an appropriately-sized additional innermost dimension, e.g.
* ... x f32 => [bitcast_convert] => ... x 4 x i8
* ... x 4 x i8 => [bitcast_convert] => ... x f32
*/
LogicalResult verifyBitcastConvertOp(std::optional<Location> location,
Value operand, Value result) {
auto operandShapedType = operand.getType().cast<ShapedType>();
auto targetShapedType = result.getType().cast<ShapedType>();

// P1.
// bitcast_convert_c2
auto targetElt = targetShapedType.getElementType();
auto operandElt = operandShapedType.getElementType();
if (targetElt.isa<ComplexType>() != operandElt.isa<ComplexType>())
@@ -3139,36 +3130,32 @@ LogicalResult verifyBitcastConvertOp(std::optional<Location> location,
auto targetEltBitwidth = potentiallyComplexBitwidth(targetElt);
auto operandEltBitwidth = potentiallyComplexBitwidth(operandElt);

// P2.
auto operandType = operandShapedType.dyn_cast<RankedTensorType>();
auto targetType = targetShapedType.dyn_cast<RankedTensorType>();
if (!operandType || !targetType) return success();

auto targetShape = targetType.getShape();
auto operandShape = operandType.getShape();
ArrayRef<int64_t> smallerEltShape, biggerEltShape;
Type smallerElt, biggerElt;
if (operandEltBitwidth < targetEltBitwidth) {
smallerEltShape = operandShape;
smallerElt = operandElt;
biggerEltShape = targetShape;
biggerElt = targetElt;
} else {
smallerEltShape = targetShape;
smallerElt = targetElt;
biggerEltShape = operandShape;
biggerElt = operandElt;
}

ArrayRef<int64_t> smallerEltPrefix;
auto smallerEltBitwidth = std::min(targetEltBitwidth, operandEltBitwidth);
auto biggerEltBitwidth = std::max(targetEltBitwidth, operandEltBitwidth);
// bitcast_convert_c1
if (operandEltBitwidth != targetEltBitwidth) {
if (smallerEltShape.empty()) {
return emitOptionalError(location,
"does not allow the smaller element type to be "
"part of a 0d tensor, but got: ",
operandType, " and ", targetType, ".");
if (smallerEltShape.size() != biggerEltShape.size() + 1) {
return emitOptionalError(
location, "rank of smaller element type (", smallerEltShape.size(),
") should be 1 more than rank of larger element type (",
biggerEltShape.size(), "), but ", smallerEltShape.size(),
" != ", biggerEltShape.size(), " + 1.");
}
smallerEltPrefix = smallerEltShape.drop_back();
if (!isDynamicDimSize(smallerEltShape.back()) &&
@@ -3185,6 +3172,7 @@ LogicalResult verifyBitcastConvertOp(std::optional<Location> location,
for (auto it : llvm::zip(smallerEltPrefix, biggerEltShape)) {
auto targetDim = std::get<0>(it);
auto operandDim = std::get<1>(it);
// bitcast_convert_c1
if (!verifyCompatibleDims(targetDim, operandDim))
return emitOptionalError(location,
"operand and result shapes must match except "
86 changes: 86 additions & 0 deletions stablehlo/reference/Element.cpp
@@ -260,6 +260,46 @@ std::complex<APFloat> Element::getComplexValue() const {
return std::complex<APFloat>(floatPair.first, floatPair.second);
}

APInt Element::toBits() const {
if (isSupportedBooleanType(type_))
return APInt(/*numBits=*/1, getBooleanValue() ? 1 : 0);
if (isSupportedIntegerType(type_)) return getIntegerValue();
if (isSupportedFloatType(type_)) return getFloatValue().bitcastToAPInt();
if (isSupportedComplexType(type_)) {
// Package the real part into the low half of the result bits,
// and the imaginary part into the high half of the result bits.
auto realBits = getComplexValue().real().bitcastToAPInt();
auto imagBits = getComplexValue().imag().bitcastToAPInt();
return imagBits.zext(numBits(type_)).shl(numBits(type_) / 2) +
realBits.zext(numBits(type_));
}
report_fatal_error(invalidArgument("Unsupported element type: %s",
debugString(type_).c_str()));
}

Element Element::fromBits(Type type, APInt bits) {
if (numBits(type) != bits.getBitWidth())
llvm::report_fatal_error("numBits(type) != bits.getBitWidth()");
if (isSupportedBooleanType(type)) return Element(type, !bits.isZero());
if (isSupportedIntegerType(type)) return Element(type, bits);
if (isSupportedFloatType(type))
return Element(type,
APFloat(type.cast<FloatType>().getFloatSemantics(), bits));
if (isSupportedComplexType(type)) {
// Interpret the low half of the bits as the real part, and
// the high half of the bits as the imaginary part.
auto elementType = type.cast<ComplexType>().getElementType();
auto realBits = bits.extractBits(numBits(type) / 2, 0);
auto realElement = fromBits(elementType, realBits);
auto imagBits = bits.extractBits(numBits(type) / 2, numBits(type) / 2);
auto imagElement = fromBits(elementType, imagBits);
return Element(type,
{realElement.getFloatValue(), imagElement.getFloatValue()});
}
report_fatal_error(invalidArgument("Unsupported element type: %s",
debugString(type).c_str()));
}

Element Element::operator!() const {
return Element(IntegerType::get(getType().getContext(), 1),
!getBooleanValue());
@@ -595,6 +635,52 @@ Element atan2(const Element &e1, const Element &e2) {
debugString(type).c_str()));
}

Element bitcastConvertManyToOne(Type type, ArrayRef<Element> elements) {
SmallVector<Element> results;

auto resultNumBits = numBits(type);
auto operandNumBits = numBits(elements[0].getType());
if (resultNumBits % operandNumBits != 0)
report_fatal_error(invalidArgument(
"Unsupported bitcast conversion from %s to %s",
debugString(elements[0].getType()).c_str(), debugString(type).c_str()));

APInt resultBits(resultNumBits, 0);
for (auto element : llvm::reverse(elements)) {
if (operandNumBits != numBits(element.getType()))
llvm::report_fatal_error("All elements must have the same numBits");
auto operandBits = element.toBits();
resultBits =
resultBits.shl(operandNumBits) + operandBits.zext(resultNumBits);
}
return Element::fromBits(type, resultBits);
}

SmallVector<Element> bitcastConvertOneToMany(Type type, const Element &el) {
SmallVector<Element> results;

auto resultNumBits = numBits(type);
auto operandNumBits = numBits(el.getType());
if (operandNumBits % resultNumBits != 0)
report_fatal_error(invalidArgument(
"Unsupported bitcast conversion from %s to %s",
debugString(el.getType()).c_str(), debugString(type).c_str()));

for (auto i = 0; i < operandNumBits; i += resultNumBits) {
auto resultBits = el.toBits().extractBits(resultNumBits, i);
results.push_back(Element::fromBits(type, resultBits));
}
return results;
}

Element bitcastConvertOneToOne(Type type, const Element &el) {
if (numBits(type) != numBits(el.getType()))
report_fatal_error(invalidArgument(
"Unsupported bitcast conversion from %s to %s",
debugString(el.getType()).c_str(), debugString(type).c_str()));
return Element::fromBits(type, el.toBits());
}

Element cbrt(const Element &el) {
return mapWithUpcastToDouble(
el, [](double e) { return std::cbrt(e); },
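
The many-to-one and one-to-many helpers in this file can be mirrored with fixed-width integers. The following is a minimal sketch, assuming 16-bit pieces and a 64-bit whole; `packManyToOne` and `splitOneToMany` are hypothetical names and do not use `llvm::APInt`:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical counterpart of bitcastConvertManyToOne for four 16-bit pieces.
uint64_t packManyToOne(const std::vector<uint16_t>& pieces) {
  assert(pieces.size() == 4 && "expects exactly 64 / 16 pieces");
  uint64_t bits = 0;
  // Walk the pieces in reverse so that pieces[0] lands in the lowest bits,
  // matching the shift-and-add loop over llvm::reverse(elements).
  for (auto it = pieces.rbegin(); it != pieces.rend(); ++it)
    bits = (bits << 16) | *it;
  return bits;
}

// Hypothetical counterpart of bitcastConvertOneToMany: extract 16-bit pieces
// starting from bit 0, so the lowest-order piece comes first.
std::vector<uint16_t> splitOneToMany(uint64_t bits) {
  std::vector<uint16_t> pieces;
  for (int shift = 0; shift < 64; shift += 16)
    pieces.push_back(static_cast<uint16_t>(bits >> shift));
  return pieces;
}
```

With this ordering, `splitOneToMany(0x0123456789ABCDEF)` gives `[0xCDEF, 0x89AB, 0x4567, 0x0123]`, the same piece order as the `tensor<f64> -> tensor<4xf16>` example in docs/spec.md, and `packManyToOne` inverts it.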
11 changes: 11 additions & 0 deletions stablehlo/reference/Element.h
@@ -75,6 +75,12 @@ class Element {
/// complex type.
std::complex<APFloat> getComplexValue() const;

/// Returns the implementation-defined bits of the underlying value.
APInt toBits() const;

/// Creates an Element from implementation-defined bits.
static Element fromBits(Type type, APInt bits);

/// Overloaded not (logical) operator.
Element operator!() const;

@@ -149,6 +155,11 @@ Element atan2(const Element &e1, const Element &e2);
/// individually equal modulo the tolerance.
Element areApproximatelyEqual(const Element &e1, const Element &e2);

/// Various flavors of bitcast conversion as defined in the specification.
Element bitcastConvertOneToOne(Type type, const Element &e);
SmallVector<Element> bitcastConvertOneToMany(Type type, const Element &e);
Element bitcastConvertManyToOne(Type type, ArrayRef<Element> es);

/// Returns cube root of Element object.
Element cbrt(const Element &e);
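
In Element.cpp, the complex branches of `toBits` and `fromBits` place the real part in the low half of the bits and the imaginary part in the high half. A minimal standalone sketch for `complex<f32>`, assuming a 32-bit IEEE-754 `float`; `complexToBits` and `complexFromBits` are hypothetical names:

```cpp
#include <cassert>
#include <complex>
#include <cstdint>
#include <cstring>

// Hypothetical helper: packs complex<f32> into 64 bits, real part in the low
// 32 bits and imaginary part in the high 32 bits.
uint64_t complexToBits(std::complex<float> value) {
  static_assert(sizeof(float) == sizeof(uint32_t), "expects a 32-bit float");
  float re = value.real();
  float im = value.imag();
  uint32_t reBits, imBits;
  std::memcpy(&reBits, &re, sizeof(reBits));
  std::memcpy(&imBits, &im, sizeof(imBits));
  return (static_cast<uint64_t>(imBits) << 32) | reBits;
}

// Hypothetical inverse: low half becomes the real part, high half the
// imaginary part.
std::complex<float> complexFromBits(uint64_t bits) {
  uint32_t reBits = static_cast<uint32_t>(bits);
  uint32_t imBits = static_cast<uint32_t>(bits >> 32);
  float re, im;
  std::memcpy(&re, &reBits, sizeof(re));
  std::memcpy(&im, &imBits, sizeof(im));
  return {re, im};
}
```

Because the two functions use the same half-and-half layout, converting a value to bits and back returns the original value bit for bit.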

42 changes: 42 additions & 0 deletions stablehlo/reference/Ops.cpp
@@ -184,6 +184,10 @@ SmallVector<InterpreterValue> eval(
failOnDecomposableOp(op);
} else if (auto batchNormTrainingOp = dyn_cast<BatchNormTrainingOp>(op)) {
failOnDecomposableOp(op);
} else if (auto bitcastConvertOp = dyn_cast<BitcastConvertOp>(op)) {
auto operand = scope.findTensor(bitcastConvertOp.getOperand());
auto result = evalBitcastConvertOp(operand, bitcastConvertOp.getType());
scope.add(bitcastConvertOp.getResult(), result);
} else if (auto broadcastInDimOp = dyn_cast<BroadcastInDimOp>(op)) {
auto operand = scope.findTensor(broadcastInDimOp.getOperand());
auto broadcastDimensions =
@@ -682,6 +686,44 @@ Tensor evalAtan2Op(const Tensor &lhs, const Tensor &rhs,
return result;
}

Tensor evalBitcastConvertOp(const Tensor &operand, ShapedType resultType) {
Tensor result(resultType);

auto resultElementType = result.getElementType();
auto resultNumBits = numBits(result.getElementType());
auto operandNumBits = numBits(operand.getElementType());

if (resultNumBits < operandNumBits) {
auto resultIt = result.index_begin();
for (auto operandIt = operand.index_begin();
operandIt != operand.index_end(); ++operandIt) {
auto resultElements =
bitcastConvertOneToMany(resultElementType, operand.get(*operandIt));
for (auto resultElement : resultElements)
result.set(*resultIt++, resultElement);
}
return result;
}

if (resultNumBits > operandNumBits) {
auto operandIt = operand.index_begin();
for (auto resultIt = result.index_begin(); resultIt != result.index_end();
++resultIt) {
SmallVector<Element> operandElements;
for (auto i = 0; i < resultNumBits / operandNumBits; ++i)
operandElements.push_back(operand.get(*operandIt++));
result.set(*resultIt,
bitcastConvertManyToOne(resultElementType, operandElements));
}
return result;
}

for (auto it = result.index_begin(); it != result.index_end(); ++it)
result.set(*it,
bitcastConvertOneToOne(resultElementType, operand.get(*it)));
return result;
}

Tensor evalBroadcastInDimOp(const Tensor &operand,
const Axes &broadcastDimensions,
ShapedType resultType) {
1 change: 1 addition & 0 deletions stablehlo/reference/Ops.h
@@ -33,6 +33,7 @@ Tensor evalAddOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType);
Token evalAfterAllOp(ArrayRef<Token> inputs, MLIRContext *context);
Tensor evalAndOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType);
Tensor evalAtan2Op(const Tensor &lhs, const Tensor &rhs, ShapedType resultType);
Tensor evalBitcastConvertOp(const Tensor &operand, ShapedType resultType);
Tensor evalBroadcastInDimOp(const Tensor &operand,
const Axes &broadcastDimensions,
ShapedType resultType);
6 changes: 6 additions & 0 deletions stablehlo/reference/Types.cpp
@@ -58,5 +58,11 @@ bool isSupportedComplexType(Type type) {
return complexElemTy.isF32() || complexElemTy.isF64();
}

int64_t numBits(Type type) {
if (isSupportedComplexType(type))
return numBits(type.cast<ComplexType>().getElementType()) * 2;
return type.getIntOrFloatBitWidth();
}

} // namespace stablehlo
} // namespace mlir
8 changes: 8 additions & 0 deletions stablehlo/reference/Types.h
@@ -46,6 +46,14 @@ bool isSupportedFloatType(Type type);
/// StableHLO specification. Such types are: complex<f32> and complex<f64>.
bool isSupportedComplexType(Type type);

/// Returns the number of bits in the representation of an element type.
/// * For boolean type: 1.
/// * For integer types: bit width (e.g. 32 for si32).
/// * For floating-point types: bit width (e.g. 32 for f32).
/// * For complex types: 2x the bit width of element type (e.g. 64 for
/// complex<f32>).
int64_t numBits(Type type);

} // namespace stablehlo
} // namespace mlir

@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

module @jit_testcase {
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

module @jit_testcase {
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

module @jit_testcase {
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

module @jit_testcase {
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

module @jit_testcase {
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

module @jit_testcase {
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

module @jit_testcase {
@@ -1,4 +1,4 @@
// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret
// RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s)

module @jit_testcase {