66//
77// ===----------------------------------------------------------------------===//
88
9- #include < iostream>
10-
119#include < optional>
1210#include < string_view>
1311#include < utility>
14- #include < memory>
1512
1613#include " IRModule.h"
1714
@@ -995,7 +992,7 @@ class PyDenseElementsAttribute
995992 } else if (format == " ?" ) {
996993 // i1
997994 // The i1 type needs to be bit-packed, so we will handle it seperately
998- return getAttributeFromBufferBoolean (view, explicitShape, context);
995+ return getAttributeFromBufferBoolBitpack (view, explicitShape, context);
999996 } else if (isSignedIntegerFormat (format)) {
1000997 if (view.itemsize == 4 ) {
1001998 // i32
@@ -1047,17 +1044,21 @@ class PyDenseElementsAttribute
10471044 }
10481045 }
10491046
1050- return mlirDenseElementsAttrRawBufferGet (getShapedType (bulkLoadElementType, explicitShape, view), view.len , view.buf );
1047+ MlirType type = getShapedType (bulkLoadElementType, explicitShape, view);
1048+ return mlirDenseElementsAttrRawBufferGet (type, view.len , view.buf );
10511049 }
10521050
1053- static MlirAttribute getAttributeFromBufferBoolean (Py_buffer& view,
1054- std::optional<std::vector<int64_t >> explicitShape,
1055- MlirContext& context) {
1051+ // There is a complication for boolean numpy arrays, as numpy represent them as
1052+ // 8 bits per boolean, whereas MLIR bitpacks them into 8 booleans per byte.
1053+ // This function does the bit-packing respecting endianess.
1054+ static MlirAttribute getAttributeFromBufferBoolBitpack (Py_buffer& view,
1055+ std::optional<std::vector<int64_t >> explicitShape,
1056+ MlirContext& context) {
10561057 // First read the content of the python buffer as u8's, to correct for endianess
1057- MlirAttribute intermediateAttr = mlirDenseElementsAttrRawBufferGet (
1058- getShapedType ( mlirIntegerTypeUnsignedGet (context, 8 ), explicitShape, view) , view.len , view.buf );
1058+ MlirType byteType = getShapedType ( mlirIntegerTypeUnsignedGet (context, 8 ), explicitShape, view);
1059+ MlirAttribute intermediateAttr = mlirDenseElementsAttrRawBufferGet (byteType , view.len , view.buf );
10591060
1060- // Pack the boolean array according to the i8 bitpacking layout
1061+ // Pack the boolean array according to the i1 bitpacking layout
10611062 const int numPackedBytes = (view.len + 7 ) / 8 ;
10621063 SmallVector<uint8_t , 8 > bitpacked (numPackedBytes);
10631064 for (int byteNum = 0 ; byteNum < numPackedBytes; byteNum++) {
@@ -1070,8 +1071,8 @@ class PyDenseElementsAttribute
10701071 bitpacked[byteNum] = byte;
10711072 }
10721073
1073- return mlirDenseElementsAttrRawBufferGet ( getShapedType (
1074- mlirIntegerTypeGet (context, 1 ), explicitShape, view) , numPackedBytes, bitpacked.data ());
1074+ MlirType bitpackedType = getShapedType (mlirIntegerTypeGet (context, 1 ), explicitShape, view);
1075+ return mlirDenseElementsAttrRawBufferGet (bitpackedType , numPackedBytes, bitpacked.data ());
10751076 }
10761077
10771078 template <typename Type>
@@ -1145,7 +1146,6 @@ class PyDenseIntElementsAttribute
11451146 bool isUnsigned = mlirIntegerTypeIsUnsigned (type);
11461147 if (isUnsigned) {
11471148 if (width == 1 ) {
1148- std::cerr << " Loading unsigned i1 values at position: " << pos << std::endl;
11491149 return mlirDenseElementsAttrGetBoolValue (*this , pos);
11501150 }
11511151 if (width == 8 ) {
@@ -1162,7 +1162,6 @@ class PyDenseIntElementsAttribute
11621162 }
11631163 } else {
11641164 if (width == 1 ) {
1165- std::cerr << " Loading signed i1 values at position: " << pos << std::endl;
11661165 return mlirDenseElementsAttrGetBoolValue (*this , pos);
11671166 }
11681167 if (width == 8 ) {
0 commit comments