Skip to content

Commit

Permalink
Cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
kasper0406 committed Oct 19, 2024
1 parent d5da538 commit 52b49ac
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 26 deletions.
29 changes: 14 additions & 15 deletions mlir/lib/Bindings/Python/IRAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@
//
//===----------------------------------------------------------------------===//

#include <iostream>

#include <optional>
#include <string_view>
#include <utility>
#include <memory>

#include "IRModule.h"

Expand Down Expand Up @@ -995,7 +992,7 @@ class PyDenseElementsAttribute
} else if (format == "?") {
// i1
// The i1 type needs to be bit-packed, so we will handle it seperately
return getAttributeFromBufferBoolean(view, explicitShape, context);
return getAttributeFromBufferBoolBitpack(view, explicitShape, context);
} else if (isSignedIntegerFormat(format)) {
if (view.itemsize == 4) {
// i32
Expand Down Expand Up @@ -1047,17 +1044,21 @@ class PyDenseElementsAttribute
}
}

return mlirDenseElementsAttrRawBufferGet(getShapedType(bulkLoadElementType, explicitShape, view), view.len, view.buf);
MlirType type = getShapedType(bulkLoadElementType, explicitShape, view);
return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf);
}

static MlirAttribute getAttributeFromBufferBoolean(Py_buffer& view,
std::optional<std::vector<int64_t>> explicitShape,
MlirContext& context) {
// There is a complication for boolean numpy arrays, as numpy represent them as
// 8 bits per boolean, whereas MLIR bitpacks them into 8 booleans per byte.
// This function does the bit-packing respecting endianess.
static MlirAttribute getAttributeFromBufferBoolBitpack(Py_buffer& view,
std::optional<std::vector<int64_t>> explicitShape,
MlirContext& context) {
// First read the content of the python buffer as u8's, to correct for endianess
MlirAttribute intermediateAttr = mlirDenseElementsAttrRawBufferGet(
getShapedType(mlirIntegerTypeUnsignedGet(context, 8), explicitShape, view), view.len, view.buf);
MlirType byteType = getShapedType(mlirIntegerTypeUnsignedGet(context, 8), explicitShape, view);
MlirAttribute intermediateAttr = mlirDenseElementsAttrRawBufferGet(byteType, view.len, view.buf);

// Pack the boolean array according to the i8 bitpacking layout
// Pack the boolean array according to the i1 bitpacking layout
const int numPackedBytes = (view.len + 7) / 8;
SmallVector<uint8_t, 8> bitpacked(numPackedBytes);
for (int byteNum = 0; byteNum < numPackedBytes; byteNum++) {
Expand All @@ -1070,8 +1071,8 @@ class PyDenseElementsAttribute
bitpacked[byteNum] = byte;
}

return mlirDenseElementsAttrRawBufferGet(getShapedType(
mlirIntegerTypeGet(context, 1), explicitShape, view), numPackedBytes, bitpacked.data());
MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
return mlirDenseElementsAttrRawBufferGet(bitpackedType, numPackedBytes, bitpacked.data());
}

template <typename Type>
Expand Down Expand Up @@ -1145,7 +1146,6 @@ class PyDenseIntElementsAttribute
bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
if (isUnsigned) {
if (width == 1) {
std::cerr << "Loading unsigned i1 values at position: " << pos << std::endl;
return mlirDenseElementsAttrGetBoolValue(*this, pos);
}
if (width == 8) {
Expand All @@ -1162,7 +1162,6 @@ class PyDenseIntElementsAttribute
}
} else {
if (width == 1) {
std::cerr << "Loading signed i1 values at position: " << pos << std::endl;
return mlirDenseElementsAttrGetBoolValue(*this, pos);
}
if (width == 8) {
Expand Down
9 changes: 2 additions & 7 deletions mlir/lib/CAPI/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
//
//===----------------------------------------------------------------------===//

#include <iostream>

#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/AffineMap.h"
Expand Down Expand Up @@ -529,11 +527,8 @@ MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType,
rawBufferSize);
bool isSplat = false;
if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp,
isSplat)) {
std::cerr << "NULL POINTER!!!" << std::endl;
isSplat))
return mlirAttributeGetNull();
}
std::cerr << "Pointer looks ok..." << std::endl;
return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp));
}

Expand Down Expand Up @@ -593,7 +588,7 @@ MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
const int *elements) {
SmallVector<bool, 8> values(elements, elements + numElements);
return wrap(DenseElementsAttr::get(llvm::cast<ShapedType>(unwrap(shapedType)),
values));
values));
}

/// Creates a dense attribute with elements of the type deduced by templates.
Expand Down
4 changes: 0 additions & 4 deletions mlir/lib/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
//
//===----------------------------------------------------------------------===//

#include <iostream>

#include "mlir/IR/BuiltinAttributes.h"
#include "AttributeDetail.h"
#include "mlir/IR/AffineMap.h"
Expand Down Expand Up @@ -1090,8 +1088,6 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
}

// This is a valid non-splat buffer if it has the right size.
std::cerr << "Raw buffer width: " << rawBufferWidth << std::endl;
std::cerr << "Aligned to width: " << llvm::alignTo<8>(numElements) << std::endl;
return rawBufferWidth == llvm::alignTo<8>(numElements);
}

Expand Down

0 comments on commit 52b49ac

Please sign in to comment.