Skip to content

Commit

Permalink
torch-mlir change for dense resource implementation (llvm#2513)
Browse files Browse the repository at this point in the history
Co-authored-by: Avinash Sharma <avinash@nod-labs.com>
  • Loading branch information
saienduri and aviator19941 authored Nov 3, 2023
1 parent 1b9fb1b commit 88adf38
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions lib/Conversion/TorchToArith/TorchToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,18 @@ class ConvertTorchTensorLiteralOp
ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = op->getContext();
if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
Type elemTy = op.getValueAttr().getElementType();
unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
return APInt(bitWidth, v.getSExtValue());
}));
return success();
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
Type elemTy = op.getValueAttr().getElementType();
unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
auto shapedType =
RankedTensorType::get(type.getShape(), builtinTensorElemTy);
auto rawData = elements.getRawData();
DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
shapedType, rawData);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
return success();
}
}
if (auto elements = op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
Expand All @@ -190,7 +194,8 @@ class ConvertTorchTensorLiteralOp
AsmResourceBlob *blob = elements.getRawHandle().getBlob();
assert(blob && "Expecting dense resource with a valid blob");
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, DenseElementsAttr::get(shapedType, blob->getData()));
op, DenseResourceElementsAttr::get(shapedType,
elements.getRawHandle()));
return success();
}
}
Expand Down

0 comments on commit 88adf38

Please sign in to comment.