diff --git a/docs/spec.md b/docs/spec.md index 53d7dc3e8e0..87853d0bd93 100644 --- a/docs/spec.md +++ b/docs/spec.md @@ -1284,12 +1284,12 @@ type of the `result` tensor. More formally, given `E = element_type(operand)`, `E' = element_type(result)`, and `R = rank(operand)`: -* If `num_bits(E') = num_bits(E)`, - `bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])`. * If `num_bits(E') < num_bits(E)`, `bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1])`. * If `num_bits(E') > num_bits(E)`, `bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :])`. +* If `num_bits(E') = num_bits(E)`, + `bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1])`. `bits` returns in-memory representation of a given value, and its behavior is implementation-defined because the exact representation of tensors is @@ -1328,14 +1328,13 @@ implementation-defined as well. #### Examples ```mlir -// %operand: [0.0, 1.0] -%result = "stablehlo.bitcast_convert"(%operand) : (tensor<2xf32>) -> tensor<2x4xi8> -// %result: [ -// [0, 0, 0, 0], -// [0, 0, -128, 63] // little-endian representation of 1.0 -// ] +// %operand: 0x0123456789ABCDEF +%result = "stablehlo.bitcast_convert"(%operand) : (tensor) -> tensor<4xf16> +// %result: [0xCDEF, 0x89AB, 0x4567, 0x0123] // little-endian representation ``` + [More Examples](../stablehlo/tests/interpret_bitcast_convert.mlir) + ### broadcast_in_dim #### Semantics diff --git a/docs/status.md b/docs/status.md index 7d019ef70fd..95cb5d33096 100644 --- a/docs/status.md +++ b/docs/status.md @@ -53,7 +53,7 @@ one of the following tracking labels. | batch_norm_grad | yes | revisit | yes | no | revisit | | batch_norm_inference | yes | revisit | yes | no | revisit | | batch_norm_training | yes | revisit | yes | no | revisit | -| bitcast_convert | yes | yes | infeasible | yes | no | +| bitcast_convert | yes | yes | infeasible | yes | yes | | broadcast | no | yes\* | yes\* | yes | revisit | | broadcast_in_dim | yes | yes | infeasible | yes | yes | | case | yes | revisit | yes | no | yes | diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 4ce6eb00573..e263a6bddd7 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -1823,11 +1823,11 @@ def StableHLO_BitcastConvertOp : StableHLO_ShapedInterfaceOp<"bitcast_convert", Example: ```mlir - %result = stablehlo.bitcast_convert %operand : (tensor<2xf32>) -> tensor<2x4xi8> + %result = stablehlo.bitcast_convert %operand : (tensor) -> tensor<4xf16> ``` }]; - let arguments = (ins HLO_Tensor:$operand); + let arguments = (ins HLO_Tensor:$operand /*bitcast_convert_i1*/); let results = (outs HLO_Tensor); let hasVerifier = 1; diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 7f4bf66b6d3..fa6228bef44 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -3114,21 +3114,12 @@ LogicalResult verifyAllReduceOp(std::optional location, Value operand, return success(); } -/* - * We intend to verify the following properties - * P1. We cannot convert between complex and real types (cf xla) - * P3. The dimensions of the operand and the target - * shape must match, except that the shape with the smaller element bitwidth has - * an appropriately-sized additional innermost dimension, e.g. - * ... x f32 => [bitcast_convert] => ... x 4 x i8 - * ... x 4 x i8 => [bitcast_convert] => ... x f32 - */ LogicalResult verifyBitcastConvertOp(std::optional location, Value operand, Value result) { auto operandShapedType = operand.getType().cast(); auto targetShapedType = result.getType().cast(); - // P1. + // bitcast_convert_c2 auto targetElt = targetShapedType.getElementType(); auto operandElt = operandShapedType.getElementType(); if (targetElt.isa() != operandElt.isa()) @@ -3139,7 +3130,6 @@ LogicalResult verifyBitcastConvertOp(std::optional location, auto targetEltBitwidth = potentiallyComplexBitwidth(targetElt); auto operandEltBitwidth = potentiallyComplexBitwidth(operandElt); - // P2. auto operandType = operandShapedType.dyn_cast(); auto targetType = targetShapedType.dyn_cast(); if (!operandType || !targetType) return success(); @@ -3147,28 +3137,25 @@ LogicalResult verifyBitcastConvertOp(std::optional location, auto targetShape = targetType.getShape(); auto operandShape = operandType.getShape(); ArrayRef smallerEltShape, biggerEltShape; - Type smallerElt, biggerElt; if (operandEltBitwidth < targetEltBitwidth) { smallerEltShape = operandShape; - smallerElt = operandElt; biggerEltShape = targetShape; - biggerElt = targetElt; } else { smallerEltShape = targetShape; - smallerElt = targetElt; biggerEltShape = operandShape; - biggerElt = operandElt; } ArrayRef smallerEltPrefix; auto smallerEltBitwidth = std::min(targetEltBitwidth, operandEltBitwidth); auto biggerEltBitwidth = std::max(targetEltBitwidth, operandEltBitwidth); + // bitcast_convert_c1 if (operandEltBitwidth != targetEltBitwidth) { - if (smallerEltShape.empty()) { - return emitOptionalError(location, - "does not allow the smaller element type to be " - "part of a 0d tensor, but got: ", - operandType, " and ", targetType, "."); + if (smallerEltShape.size() != biggerEltShape.size() + 1) { + return emitOptionalError( + location, "rank of smaller element type (", smallerEltShape.size(), + ") should be 1 more than rank of larger element type (", + biggerEltShape.size(), "), but ", smallerEltShape.size(), + " != ", biggerEltShape.size(), " + 1."); } smallerEltPrefix = smallerEltShape.drop_back(); if (!isDynamicDimSize(smallerEltShape.back()) && @@ -3185,6 +3172,7 @@ LogicalResult verifyBitcastConvertOp(std::optional location, for (auto it : llvm::zip(smallerEltPrefix, biggerEltShape)) { auto targetDim = std::get<0>(it); auto operandDim = std::get<1>(it); + // bitcast_convert_c1 if (!verifyCompatibleDims(targetDim, operandDim)) return emitOptionalError(location, "operand and result shapes must match except " diff --git a/stablehlo/reference/Element.cpp b/stablehlo/reference/Element.cpp index 0635e8d3379..e29c657b7a1 100644 --- a/stablehlo/reference/Element.cpp +++ b/stablehlo/reference/Element.cpp @@ -260,6 +260,46 @@ std::complex Element::getComplexValue() const { return std::complex(floatPair.first, floatPair.second); } +APInt Element::toBits() const { + if (isSupportedBooleanType(type_)) + return APInt(/*numBits=*/1, getBooleanValue() ? 1 : 0); + if (isSupportedIntegerType(type_)) return getIntegerValue(); + if (isSupportedFloatType(type_)) return getFloatValue().bitcastToAPInt(); + if (isSupportedComplexType(type_)) { + // Package the real part into the low half of the result bits, + // and the imaginary part into the high half of the result bits. + auto realBits = getComplexValue().real().bitcastToAPInt(); + auto imagBits = getComplexValue().imag().bitcastToAPInt(); + return imagBits.zext(numBits(type_)).shl(numBits(type_) / 2) + + realBits.zext(numBits(type_)); + } + report_fatal_error(invalidArgument("Unsupported element type: %s", + debugString(type_).c_str())); +} + +Element Element::fromBits(Type type, APInt bits) { + if (numBits(type) != bits.getBitWidth()) + llvm::report_fatal_error("numBits(type) != bits.getBitWidth()"); + if (isSupportedBooleanType(type)) return Element(type, !bits.isZero()); + if (isSupportedIntegerType(type)) return Element(type, bits); + if (isSupportedFloatType(type)) + return Element(type, + APFloat(type.cast().getFloatSemantics(), bits)); + if (isSupportedComplexType(type)) { + // Interpret the low half of the bits as the real part, and + // the high half of the bits as the imaginary part. + auto elementType = type.cast().getElementType(); + auto realBits = bits.extractBits(numBits(type) / 2, 0); + auto realElement = fromBits(elementType, realBits); + auto imagBits = bits.extractBits(numBits(type) / 2, numBits(type) / 2); + auto imagElement = fromBits(elementType, imagBits); + return Element(type, + {realElement.getFloatValue(), imagElement.getFloatValue()}); + } + report_fatal_error(invalidArgument("Unsupported element type: %s", + debugString(type).c_str())); +} + Element Element::operator!() const { return Element(IntegerType::get(getType().getContext(), 1), !getBooleanValue()); @@ -595,6 +635,52 @@ Element atan2(const Element &e1, const Element &e2) { debugString(type).c_str())); } +Element bitcastConvertManyToOne(Type type, ArrayRef elements) { + SmallVector results; + + auto resultNumBits = numBits(type); + auto operandNumBits = numBits(elements[0].getType()); + if (resultNumBits % operandNumBits != 0) + report_fatal_error(invalidArgument( + "Unsupported bitcast conversion from %s to %s", + debugString(elements[0].getType()).c_str(), debugString(type).c_str())); + + APInt resultBits(resultNumBits, 0); + for (auto element : llvm::reverse(elements)) { + if (operandNumBits != numBits(element.getType())) + llvm::report_fatal_error("All elements must have the same numBits"); + auto operandBits = element.toBits(); + resultBits = + resultBits.shl(operandNumBits) + operandBits.zext(resultNumBits); + } + return Element::fromBits(type, resultBits); +} + +SmallVector bitcastConvertOneToMany(Type type, const Element &el) { + SmallVector results; + + auto resultNumBits = numBits(type); + auto operandNumBits = numBits(el.getType()); + if (operandNumBits % resultNumBits != 0) + report_fatal_error(invalidArgument( + "Unsupported bitcast conversion from %s to %s", + debugString(el.getType()).c_str(), debugString(type).c_str())); + + for (auto i = 0; i < operandNumBits; i += resultNumBits) { + auto resultBits = el.toBits().extractBits(resultNumBits, i); + results.push_back(Element::fromBits(type, resultBits)); + } + return results; +} + +Element bitcastConvertOneToOne(Type type, const Element &el) { + if (numBits(type) != numBits(el.getType())) + report_fatal_error(invalidArgument( + "Unsupported bitcast conversion from %s to %s", + debugString(el.getType()).c_str(), debugString(type).c_str())); + return Element::fromBits(type, el.toBits()); +} + Element cbrt(const Element &el) { return mapWithUpcastToDouble( el, [](double e) { return std::cbrt(e); }, diff --git a/stablehlo/reference/Element.h b/stablehlo/reference/Element.h index d43eabdbe84..da98940aedf 100644 --- a/stablehlo/reference/Element.h +++ b/stablehlo/reference/Element.h @@ -75,6 +75,12 @@ class Element { /// complex type. std::complex getComplexValue() const; + /// Returns the implementation-defined bits of the underlying value. + APInt toBits() const; + + /// Creates an Element from implementation-defined bits. + static Element fromBits(Type type, APInt bits); + /// Overloaded not (logical) operator. Element operator!() const; @@ -149,6 +155,11 @@ Element atan2(const Element &e1, const Element &e2); /// individually equal modulo the tolerance. Element areApproximatelyEqual(const Element &e1, const Element &e2); +/// Various flavors of bitcast conversion as defined in the specification. +Element bitcastConvertOneToOne(Type type, const Element &e); +SmallVector bitcastConvertOneToMany(Type type, const Element &e); +Element bitcastConvertManyToOne(Type type, ArrayRef es); + /// Returns cube root of Element object. Element cbrt(const Element &e); diff --git a/stablehlo/reference/Ops.cpp b/stablehlo/reference/Ops.cpp index 2202f1e9f28..cd0e8d638f6 100644 --- a/stablehlo/reference/Ops.cpp +++ b/stablehlo/reference/Ops.cpp @@ -184,6 +184,10 @@ SmallVector eval( failOnDecomposableOp(op); } else if (auto batchNormTrainingOp = dyn_cast(op)) { failOnDecomposableOp(op); + } else if (auto bitcastConvertOp = dyn_cast(op)) { + auto operand = scope.findTensor(bitcastConvertOp.getOperand()); + auto result = evalBitcastConvertOp(operand, bitcastConvertOp.getType()); + scope.add(bitcastConvertOp.getResult(), result); } else if (auto broadcastInDimOp = dyn_cast(op)) { auto operand = scope.findTensor(broadcastInDimOp.getOperand()); auto broadcastDimensions = @@ -682,6 +686,44 @@ Tensor evalAtan2Op(const Tensor &lhs, const Tensor &rhs, return result; } +Tensor evalBitcastConvertOp(const Tensor &operand, ShapedType resultType) { + Tensor result(resultType); + + auto resultElementType = result.getElementType(); + auto resultNumBits = numBits(result.getElementType()); + auto operandNumBits = numBits(operand.getElementType()); + + if (resultNumBits < operandNumBits) { + auto resultIt = result.index_begin(); + for (auto operandIt = operand.index_begin(); + operandIt != operand.index_end(); ++operandIt) { + auto resultElements = + bitcastConvertOneToMany(resultElementType, operand.get(*operandIt)); + for (auto resultElement : resultElements) + result.set(*resultIt++, resultElement); + } + return result; + } + + if (resultNumBits > operandNumBits) { + auto operandIt = operand.index_begin(); + for (auto resultIt = result.index_begin(); resultIt != result.index_end(); + ++resultIt) { + SmallVector operandElements; + for (auto i = 0; i < resultNumBits / operandNumBits; ++i) + operandElements.push_back(operand.get(*operandIt++)); + result.set(*resultIt, + bitcastConvertManyToOne(resultElementType, operandElements)); + } + return result; + } + + for (auto it = result.index_begin(); it != result.index_end(); ++it) + result.set(*it, + bitcastConvertOneToOne(resultElementType, operand.get(*it))); + return result; +} + Tensor evalBroadcastInDimOp(const Tensor &operand, const Axes &broadcastDimensions, ShapedType resultType) { diff --git a/stablehlo/reference/Ops.h b/stablehlo/reference/Ops.h index cfbdeebda0f..d081b66ac7d 100644 --- a/stablehlo/reference/Ops.h +++ b/stablehlo/reference/Ops.h @@ -33,6 +33,7 @@ Tensor evalAddOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); Token evalAfterAllOp(ArrayRef inputs, MLIRContext *context); Tensor evalAndOp(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); Tensor evalAtan2Op(const Tensor &lhs, const Tensor &rhs, ShapedType resultType); +Tensor evalBitcastConvertOp(const Tensor &operand, ShapedType resultType); Tensor evalBroadcastInDimOp(const Tensor &operand, const Axes &broadcastDimensions, ShapedType resultType); diff --git a/stablehlo/reference/Types.cpp b/stablehlo/reference/Types.cpp index 80c7db78434..0f2ca1a99e8 100644 --- a/stablehlo/reference/Types.cpp +++ b/stablehlo/reference/Types.cpp @@ -58,5 +58,11 @@ bool isSupportedComplexType(Type type) { return complexElemTy.isF32() || complexElemTy.isF64(); } +int64_t numBits(Type type) { + if (isSupportedComplexType(type)) + return numBits(type.cast().getElementType()) * 2; + return type.getIntOrFloatBitWidth(); +} + } // namespace stablehlo } // namespace mlir diff --git a/stablehlo/reference/Types.h b/stablehlo/reference/Types.h index 27a5bc1c987..e8b1467cd0c 100644 --- a/stablehlo/reference/Types.h +++ b/stablehlo/reference/Types.h @@ -46,6 +46,14 @@ bool isSupportedFloatType(Type type); /// StableHLO specification. Such types are: complex and complex. bool isSupportedComplexType(Type type); +/// Returns the number of bits in the representation of an element type. +/// * For boolean type: 1. +/// * For integer types: bit width (e.g. 32 for si32). +/// * For floating-point types: bit width (e.g. 32 for f32). +/// * For complex types: 2x the bit width of element type (e.g. 64 for +/// complex). +int64_t numBits(Type type); + } // namespace stablehlo } // namespace mlir diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_bfloat16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_bfloat16.mlir index 1611b7f853e..d7eb0d0f7d5 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_bfloat16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_bfloat16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_float16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_float16.mlir index 3b63998983d..97bef492bfb 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_float16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_float16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_int16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_int16.mlir index 140bab87a91..74cd1df16ed 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_int16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_int16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_uint16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_uint16.mlir index 74bf3445185..596805d2467 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_uint16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bfloat16_2_3__newdtype_uint16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bool_2_3__newdtype_bool.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bool_2_3__newdtype_bool.mlir index e58958519d2..40aa760bd77 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bool_2_3__newdtype_bool.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_bool_2_3__newdtype_bool.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_complex64_2_3__newdtype_complex64.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_complex64_2_3__newdtype_complex64.mlir index f508298a25f..000d300ec7b 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_complex64_2_3__newdtype_complex64.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_complex64_2_3__newdtype_complex64.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_bfloat16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_bfloat16.mlir index 21cd9e87830..60e5b0f3c92 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_bfloat16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_bfloat16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_float16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_float16.mlir index f22f7f37860..2b7876aec19 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_float16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_float16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_int16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_int16.mlir index 1f15a0686f7..808c5063866 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_int16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_int16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_uint16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_uint16.mlir index 6c52dd2f112..91e55e24e6b 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_uint16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float16_2_3__newdtype_uint16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float32_2_3__newdtype_float32.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float32_2_3__newdtype_float32.mlir index e0ca6ea0f39..846d5d4a786 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float32_2_3__newdtype_float32.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float32_2_3__newdtype_float32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float32_2_3__newdtype_int32.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float32_2_3__newdtype_int32.mlir index 5005b873421..26d9e7a52eb 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float32_2_3__newdtype_int32.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float32_2_3__newdtype_int32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float32_2_3__newdtype_uint32.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float32_2_3__newdtype_uint32.mlir index ebf37119543..48b6b78e5d3 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float32_2_3__newdtype_uint32.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_float32_2_3__newdtype_uint32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_bfloat16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_bfloat16.mlir index 2efd8b67ffd..5251fde2da3 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_bfloat16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_bfloat16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_float16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_float16.mlir index 28f0ffbbb2d..41decaab22d 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_float16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_float16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_int16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_int16.mlir index 0107f7c2afe..53e3869fa21 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_int16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_int16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_uint16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_uint16.mlir index bf05b206681..febf36280cd 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_uint16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int16_2_3__newdtype_uint16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int32_2_3__newdtype_float32.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int32_2_3__newdtype_float32.mlir index 49a3cbf7d64..2103c08a561 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int32_2_3__newdtype_float32.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int32_2_3__newdtype_float32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int32_2_3__newdtype_int32.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int32_2_3__newdtype_int32.mlir index 847e218c3b5..eb2a4a9e0d8 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int32_2_3__newdtype_int32.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int32_2_3__newdtype_int32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int32_2_3__newdtype_uint32.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int32_2_3__newdtype_uint32.mlir index 81b4313c56b..98197c72e04 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int32_2_3__newdtype_uint32.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int32_2_3__newdtype_uint32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int8_2_3__newdtype_int8.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int8_2_3__newdtype_int8.mlir index 517628c8481..acb71079116 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int8_2_3__newdtype_int8.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int8_2_3__newdtype_int8.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int8_2_3__newdtype_uint8.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int8_2_3__newdtype_uint8.mlir index 8831c97bf77..066c075dce4 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int8_2_3__newdtype_uint8.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_int8_2_3__newdtype_uint8.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_bfloat16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_bfloat16.mlir index bdbd0edd1ba..572d6c07c3a 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_bfloat16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_bfloat16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_float16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_float16.mlir index 89d0afdf929..3774bf2a8f3 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_float16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_float16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_int16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_int16.mlir index f22b7c62dcc..580f17c9a5d 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_int16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_int16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_uint16.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_uint16.mlir index 3026aa8ebb1..e641a75c64d 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_uint16.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint16_2_3__newdtype_uint16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint32_2_3__newdtype_float32.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint32_2_3__newdtype_float32.mlir index cee159f91d0..abb7b6f5804 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint32_2_3__newdtype_float32.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint32_2_3__newdtype_float32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint32_2_3__newdtype_int32.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint32_2_3__newdtype_int32.mlir index be234c5a6d7..4034b29ee46 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint32_2_3__newdtype_int32.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint32_2_3__newdtype_int32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint32_2_3__newdtype_uint32.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint32_2_3__newdtype_uint32.mlir index 9b38dd9e931..f5c91777293 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint32_2_3__newdtype_uint32.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint32_2_3__newdtype_uint32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint8_2_3__newdtype_int8.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint8_2_3__newdtype_int8.mlir index 864fdfeee5d..a40183d3756 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint8_2_3__newdtype_int8.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint8_2_3__newdtype_int8.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint8_2_3__newdtype_uint8.mlir b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint8_2_3__newdtype_uint8.mlir index 1313455f06d..a52289b6bbd 100644 --- a/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint8_2_3__newdtype_uint8.mlir +++ b/stablehlo/testdata/bitcast_convert_type_dtypes_to_new_dtypes_shape_uint8_2_3__newdtype_uint8.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/nextafter_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir b/stablehlo/testdata/nextafter_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir index 5b2adec0202..4442ab03126 100644 --- a/stablehlo/testdata/nextafter_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir +++ b/stablehlo/testdata/nextafter_broadcasting_lhs_float32_1_20__rhs_float32_20_20.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/nextafter_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir b/stablehlo/testdata/nextafter_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir index 89e803d71ca..29c2fe49d8e 100644 --- a/stablehlo/testdata/nextafter_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir +++ b/stablehlo/testdata/nextafter_broadcasting_lhs_float32_20_20__rhs_float32_1_20.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/nextafter_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir b/stablehlo/testdata/nextafter_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir index 5229f933343..dc88f076c0e 100644 --- a/stablehlo/testdata/nextafter_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir +++ b/stablehlo/testdata/nextafter_dtypes_lhs_bfloat16_20_20__rhs_bfloat16_20_20.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/nextafter_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir b/stablehlo/testdata/nextafter_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir index b318e2de2e2..ee45a30299a 100644 --- a/stablehlo/testdata/nextafter_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir +++ b/stablehlo/testdata/nextafter_dtypes_lhs_float16_20_20__rhs_float16_20_20.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/nextafter_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir b/stablehlo/testdata/nextafter_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir index 55076906b68..b8d03debe98 100644 --- a/stablehlo/testdata/nextafter_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir +++ b/stablehlo/testdata/nextafter_dtypes_lhs_float32_20_20__rhs_float32_20_20.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_categorical_shape_bfloat16_5_4__axis_0.mlir b/stablehlo/testdata/random_categorical_shape_bfloat16_5_4__axis_0.mlir index cad493a7d20..653217b9a8e 100644 --- a/stablehlo/testdata/random_categorical_shape_bfloat16_5_4__axis_0.mlir +++ b/stablehlo/testdata/random_categorical_shape_bfloat16_5_4__axis_0.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_categorical_shape_bfloat16_5_4__axis_1.mlir b/stablehlo/testdata/random_categorical_shape_bfloat16_5_4__axis_1.mlir index bbae25f724e..dca0a92efe9 100644 --- a/stablehlo/testdata/random_categorical_shape_bfloat16_5_4__axis_1.mlir +++ b/stablehlo/testdata/random_categorical_shape_bfloat16_5_4__axis_1.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_categorical_shape_bfloat16_8__axis_0.mlir b/stablehlo/testdata/random_categorical_shape_bfloat16_8__axis_0.mlir index c74b4b61c48..b36befe6b75 100644 --- a/stablehlo/testdata/random_categorical_shape_bfloat16_8__axis_0.mlir +++ b/stablehlo/testdata/random_categorical_shape_bfloat16_8__axis_0.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_categorical_shape_float16_5_4__axis_0.mlir b/stablehlo/testdata/random_categorical_shape_float16_5_4__axis_0.mlir index f8625444aaa..476570cfcf7 100644 --- a/stablehlo/testdata/random_categorical_shape_float16_5_4__axis_0.mlir +++ b/stablehlo/testdata/random_categorical_shape_float16_5_4__axis_0.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_categorical_shape_float16_5_4__axis_1.mlir b/stablehlo/testdata/random_categorical_shape_float16_5_4__axis_1.mlir index 42ebc2712d0..e814fdf3ed0 100644 --- a/stablehlo/testdata/random_categorical_shape_float16_5_4__axis_1.mlir +++ b/stablehlo/testdata/random_categorical_shape_float16_5_4__axis_1.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_categorical_shape_float16_8__axis_0.mlir b/stablehlo/testdata/random_categorical_shape_float16_8__axis_0.mlir index 603e0d6c448..eaf258231ad 100644 --- a/stablehlo/testdata/random_categorical_shape_float16_8__axis_0.mlir +++ b/stablehlo/testdata/random_categorical_shape_float16_8__axis_0.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_categorical_shape_float32_5_4__axis_0.mlir b/stablehlo/testdata/random_categorical_shape_float32_5_4__axis_0.mlir index c3c47542819..71acfc0b4f4 100644 --- a/stablehlo/testdata/random_categorical_shape_float32_5_4__axis_0.mlir +++ b/stablehlo/testdata/random_categorical_shape_float32_5_4__axis_0.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_categorical_shape_float32_5_4__axis_1.mlir b/stablehlo/testdata/random_categorical_shape_float32_5_4__axis_1.mlir index 6ca0182a3fb..d7b1b7ec79f 100644 --- a/stablehlo/testdata/random_categorical_shape_float32_5_4__axis_1.mlir +++ b/stablehlo/testdata/random_categorical_shape_float32_5_4__axis_1.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_categorical_shape_float32_8__axis_0.mlir b/stablehlo/testdata/random_categorical_shape_float32_8__axis_0.mlir index 11fdbe5c324..2f63a93cb52 100644 --- a/stablehlo/testdata/random_categorical_shape_float32_8__axis_0.mlir +++ b/stablehlo/testdata/random_categorical_shape_float32_8__axis_0.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_gamma_shape_float32.mlir b/stablehlo/testdata/random_gamma_shape_float32.mlir index 21ba2a4bba9..0ae9e00bfed 100644 --- a/stablehlo/testdata/random_gamma_shape_float32.mlir +++ b/stablehlo/testdata/random_gamma_shape_float32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_gamma_shape_float32_3.mlir b/stablehlo/testdata/random_gamma_shape_float32_3.mlir index 5d07675c881..512c4d10bd6 100644 --- a/stablehlo/testdata/random_gamma_shape_float32_3.mlir +++ b/stablehlo/testdata/random_gamma_shape_float32_3.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_gamma_shape_float64.mlir b/stablehlo/testdata/random_gamma_shape_float64.mlir index 15e5a29f32b..334fbd4f753 100644 --- a/stablehlo/testdata/random_gamma_shape_float64.mlir +++ b/stablehlo/testdata/random_gamma_shape_float64.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_gamma_shape_float64_3.mlir b/stablehlo/testdata/random_gamma_shape_float64_3.mlir index 937b66dc90c..cda447aa42f 100644 --- a/stablehlo/testdata/random_gamma_shape_float64_3.mlir +++ b/stablehlo/testdata/random_gamma_shape_float64_3.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_uniform_shape_bfloat16.mlir b/stablehlo/testdata/random_uniform_shape_bfloat16.mlir index 8fb769833f9..5248638fb22 100644 --- a/stablehlo/testdata/random_uniform_shape_bfloat16.mlir +++ b/stablehlo/testdata/random_uniform_shape_bfloat16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_uniform_shape_bfloat16_32.mlir b/stablehlo/testdata/random_uniform_shape_bfloat16_32.mlir index 4e7e291b535..b6b3c70f935 100644 --- a/stablehlo/testdata/random_uniform_shape_bfloat16_32.mlir +++ b/stablehlo/testdata/random_uniform_shape_bfloat16_32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_uniform_shape_bfloat16_5_4.mlir b/stablehlo/testdata/random_uniform_shape_bfloat16_5_4.mlir index 79b0883691a..d339a70cf95 100644 --- a/stablehlo/testdata/random_uniform_shape_bfloat16_5_4.mlir +++ b/stablehlo/testdata/random_uniform_shape_bfloat16_5_4.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_uniform_shape_float16.mlir b/stablehlo/testdata/random_uniform_shape_float16.mlir index d94f5ea4d6f..08e4e3bfd41 100644 --- a/stablehlo/testdata/random_uniform_shape_float16.mlir +++ b/stablehlo/testdata/random_uniform_shape_float16.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_uniform_shape_float16_32.mlir b/stablehlo/testdata/random_uniform_shape_float16_32.mlir index 8aa0a53b195..e000dc66679 100644 --- a/stablehlo/testdata/random_uniform_shape_float16_32.mlir +++ b/stablehlo/testdata/random_uniform_shape_float16_32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_uniform_shape_float16_5_4.mlir b/stablehlo/testdata/random_uniform_shape_float16_5_4.mlir index b7a3881aacc..2865d0d1110 100644 --- a/stablehlo/testdata/random_uniform_shape_float16_5_4.mlir +++ b/stablehlo/testdata/random_uniform_shape_float16_5_4.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_uniform_shape_float32.mlir b/stablehlo/testdata/random_uniform_shape_float32.mlir index 243771c5427..f5a2012f455 100644 --- a/stablehlo/testdata/random_uniform_shape_float32.mlir +++ b/stablehlo/testdata/random_uniform_shape_float32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_uniform_shape_float32_32.mlir b/stablehlo/testdata/random_uniform_shape_float32_32.mlir index b7679bc29fa..713aba18998 100644 --- a/stablehlo/testdata/random_uniform_shape_float32_32.mlir +++ b/stablehlo/testdata/random_uniform_shape_float32_32.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/testdata/random_uniform_shape_float32_5_4.mlir b/stablehlo/testdata/random_uniform_shape_float32_5_4.mlir index 43efb8b47b6..752454a2e4d 100644 --- a/stablehlo/testdata/random_uniform_shape_float32_5_4.mlir +++ b/stablehlo/testdata/random_uniform_shape_float32_5_4.mlir @@ -1,4 +1,4 @@ -// RUN-DISABLED: stablehlo-opt -inline %s | stablehlo-translate --interpret +// RUN: stablehlo-opt -inline %s | stablehlo-translate --interpret // RUN: diff <(stablehlo-translate --serialize --target=current %s | stablehlo-translate --deserialize | stablehlo-opt) <(stablehlo-opt %s) module @jit_testcase { diff --git a/stablehlo/tests/interpret_bitcast_convert.mlir b/stablehlo/tests/interpret_bitcast_convert.mlir new file mode 100644 index 00000000000..686ac08a61d --- /dev/null +++ b/stablehlo/tests/interpret_bitcast_convert.mlir @@ -0,0 +1,65 @@ +// RUN: stablehlo-translate --interpret -split-input-file %s + +func.func @bitcast_convert_op_test_i1_to_i64() { + %operand = stablehlo.constant dense<[true, true, true, true, + false, true, true, true, + true, false, true, true, + false, false, true, true, + true, true, false, true, + false, true, false, true, + true, false, false, true, + false, false, false, true, + true, true, true, false, + false, true, true, false, + true, false, true, false, + false, false, true, false, + true, true, false, false, + false, true, false, false, + true, false, false, false, + false, false, false, false]> : tensor<64xi1> + %result = stablehlo.bitcast_convert %operand : (tensor<64xi1>) -> tensor + check.expect_eq_const %result, dense<0x0123456789ABCDEF> : tensor + func.return +} + +// ----- + +func.func @bitcast_convert_op_test_i64_to_f64() { + %operand = stablehlo.constant dense<0x0123456789ABCDEF> : tensor + %result = stablehlo.bitcast_convert %operand : (tensor) -> tensor + check.expect_almost_eq_const %result, dense<0x0123456789ABCDEF> : tensor + func.return +} + +// ----- + +func.func @bitcast_convert_op_test_f64_to_i1() { + %operand = stablehlo.constant dense<0x0123456789ABCDEF> : tensor + %result = stablehlo.bitcast_convert %operand : (tensor) -> tensor<64xi1> + check.expect_eq_const %result, dense<[true, true, true, true, + false, true, true, true, + true, false, true, true, + false, false, true, true, + true, true, false, true, + false, true, false, true, + true, false, false, true, + false, false, false, true, + true, true, true, false, + false, true, true, false, + true, false, true, false, + false, false, true, false, + true, true, false, false, + false, true, false, false, + true, false, false, false, + false, false, false, false]> : tensor<64xi1> + func.return +} + +// ----- + +func.func @bitcast_convert_op_test_c128_to_c64() { + %operand = stablehlo.constant dense<(0x0123456789ABCDEF, 0x0000000011111111)> : tensor> + %result = stablehlo.bitcast_convert %operand : (tensor>) -> tensor<2xcomplex> + check.expect_eq_const %result, dense<[(0x89ABCDEF, 0x01234567), (0x11111111, 0x00000000)]> : tensor<2xcomplex> + func.return +} diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index 7306835dfd1..463945fd521 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -3273,51 +3273,39 @@ func.func @cbrt(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> { // ----- -func.func @bitcast_convert_int(%arg: tensor<2xf32>) -> tensor<2x4xi8> { +// CHECK-LABEL: func @bitcast_convert +func.func @bitcast_convert(%arg: tensor<2xf32>) -> tensor<2x4xi8> { %0 = "stablehlo.bitcast_convert"(%arg) : (tensor<2xf32>) -> tensor<2x4xi8> return %0 : tensor<2x4xi8> } // ----- -func.func @bitcast_convert_from_int(%arg: tensor<2x4xi8>) -> tensor<2xf32> { +// CHECK-LABEL: func @bitcast_convert +func.func @bitcast_convert(%arg: tensor<2x4xi8>) -> tensor<2xf32> { %0 = "stablehlo.bitcast_convert"(%arg) : (tensor<2x4xi8>) -> tensor<2xf32> return %0 : tensor<2xf32> } // ----- - -func.func @bitcast_convert_complex(%arg: tensor>) -> tensor<2xcomplex> { +// CHECK-LABEL: func @bitcast_convert +func.func @bitcast_convert(%arg: tensor>) -> tensor<2xcomplex> { %0 = "stablehlo.bitcast_convert"(%arg) : (tensor>) -> tensor<2xcomplex> return %0 : tensor<2xcomplex> } // ----- -func.func @invalid_bitcast_convert_decomplex(%arg: tensor<2x4xcomplex>) -> tensor<2x2xf64> { - // expected-error@+1 {{cannot convert between real and complex types}} - %0 = "stablehlo.bitcast_convert"(%arg) : (tensor<2x4xcomplex>) -> tensor<2x2xf64> - return %0 : tensor<2x2xf64> -} - -// ----- - -func.func @bitcast_convert_scalar(%arg: tensor) -> tensor { +// CHECK-LABEL: func @bitcast_convert +func.func @bitcast_convert(%arg: tensor) -> tensor { %0 = "stablehlo.bitcast_convert"(%arg) : (tensor) -> tensor return %0 : tensor } // ----- -func.func @bitcast_convert_invalid_scalar(%arg: tensor) -> tensor { - // expected-error@+1 {{does not allow the smaller element type to be part of a 0d tensor, but got: 'tensor' and 'tensor'.}} - %0 = "stablehlo.bitcast_convert"(%arg) : (tensor) -> tensor - return %0 : tensor -} - -// ----- - +// CHECK-LABEL: func @bitcast_convert func.func @bitcast_convert(%arg: tensor<*xf32>) -> tensor<*xf32> { %0 = "stablehlo.bitcast_convert"(%arg) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> @@ -3325,50 +3313,42 @@ func.func @bitcast_convert(%arg: tensor<*xf32>) -> tensor<*xf32> { // ----- -func.func @invalid_bitcast_convert_width_mismatch(%arg: tensor<2x4xf64>) -> tensor<2x4xf32> { - // expected-error@+1 {{requires compatible bitwidths. Got: 'tensor<2x4xf64>' and 'tensor<2x4xf32>', but 32 * 4 != 64.}} - %0 = "stablehlo.bitcast_convert"(%arg) : (tensor<2x4xf64>) -> tensor<2x4xf32> - return %0 : tensor<2x4xf32> -} - -// ----- - -func.func @bitcast_convert_width_mismatch(%arg: tensor) -> tensor { - // expected-error@+1 {{does not allow the smaller element type to be part of a 0d tensor, but got: 'tensor' and 'tensor'.}} - %0 = "stablehlo.bitcast_convert"(%arg) : (tensor) -> tensor - return %0 : tensor +func.func @bitcast_convert_c1(%arg: tensor<2xf64>) -> tensor<3xi64> { + // expected-error@+1 {{operand and result shapes must match except for the innermost dimension of the shape with the smaller element type. Got: 'tensor<2xf64>' and 'tensor<3xi64>'.}} + %0 = "stablehlo.bitcast_convert"(%arg) : (tensor<2xf64>) -> tensor<3xi64> + return %0 : tensor<3xi64> } // ----- -func.func @bitcast_convert_empty_target(%arg: tensor<1xf64>) -> tensor { - // expected-error@+1 {{does not allow the smaller element type to be part of a 0d tensor, but got: 'tensor<1xf64>' and 'tensor'.}} - %0 = "stablehlo.bitcast_convert"(%arg) : (tensor<1xf64>) -> tensor +func.func @bitcast_convert_c1(%arg: tensor) -> tensor { + // expected-error@+1 {{rank of smaller element type (0) should be 1 more than rank of larger element type (0), but 0 != 0 + 1.}} + %0 = "stablehlo.bitcast_convert"(%arg) : (tensor) -> tensor return %0 : tensor } // ----- -func.func @bitcast_convert_empty_operand(%arg: tensor) -> tensor<1xf64> { - // expected-error@+1 {{does not allow the smaller element type to be part of a 0d tensor, but got: 'tensor' and 'tensor<1xf64>'.}} - %0 = "stablehlo.bitcast_convert"(%arg) : (tensor) -> tensor<1xf64> - return %0 : tensor<1xf64> +func.func @bitcast_convert_c1(%arg: tensor<2xf64>) -> tensor<4x2xf32> { + // expected-error@+1 {{operand and result shapes must match except for the innermost dimension of the shape with the smaller element type. Got: 'tensor<2xf64>' and 'tensor<4x2xf32>'.}} + %0 = "stablehlo.bitcast_convert"(%arg) : (tensor<2xf64>) -> tensor<4x2xf32> + return %0 : tensor<4x2xf32> } // ----- -func.func @invalid_bitcast_convert_width_mismatch(%arg: tensor<2x4xf32>) -> tensor<2x4xf64> { - // expected-error@+1 {{requires compatible bitwidths. Got: 'tensor<2x4xf32>' and 'tensor<2x4xf64>', but 32 * 4 != 64.}} - %0 = "stablehlo.bitcast_convert"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf64> - return %0 : tensor<2x4xf64> +func.func @bitcast_convert_c1(%arg: tensor<2xf64>) -> tensor<2x4xf32> { + // expected-error@+1 {{requires compatible bitwidths. Got: 'tensor<2xf64>' and 'tensor<2x4xf32>', but 32 * 4 != 64.}} + %0 = "stablehlo.bitcast_convert"(%arg) : (tensor<2xf64>) -> tensor<2x4xf32> + return %0 : tensor<2x4xf32> } // ----- -func.func @invalid_bitcast_convert_shape_mismatch(%arg: tensor<2x4xf32>) -> tensor<4x4xf32> { - // expected-error@+1 {{operand and result shapes must match except for the innermost dimension of the shape with the smaller element type. Got: 'tensor<2x4xf32>' and 'tensor<4x4xf32>'.}} - %0 = "stablehlo.bitcast_convert"(%arg) : (tensor<2x4xf32>) -> tensor<4x4xf32> - return %0 : tensor<4x4xf32> +func.func @bitcast_convert_c2(%arg: tensor<2x4xcomplex>) -> tensor<2x2xf64> { + // expected-error@+1 {{cannot convert between real and complex types}} + %0 = "stablehlo.bitcast_convert"(%arg) : (tensor<2x4xcomplex>) -> tensor<2x2xf64> + return %0 : tensor<2x2xf64> } // -----