Skip to content

Commit

Permalink
[onnx] Add IDF and TFIDF modes to TFIDF Vectorizer (#3726)
Browse files Browse the repository at this point in the history
  • Loading branch information
samutamm authored Oct 2, 2024
1 parent 617c1c7 commit a2bfe47
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
25 changes: 25 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,31 @@ struct OpBinder {
return failure();
}

ParseResult f32FloatArrayAttr(llvm::SmallVector<float> &values,
StringRef nameSuffix,
ArrayRef<float> defaults) {
SmallString<64> name("torch.onnx.");
name.append(nameSuffix);
auto attr = op->getAttr(name);
if (!attr) {
values.append(defaults.begin(), defaults.end());
return success();
}
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
for (auto element : arrayAttr) {
auto floatAttr = dyn_cast<FloatAttr>(element);
if (!floatAttr)
return failure();
FloatType t = cast<FloatType>(floatAttr.getType());
if (t.getWidth() != 32)
return failure();
values.push_back(floatAttr.getValue().convertToFloat());
}
return success();
}
return failure();
}

ParseResult stringArrayAttr(llvm::SmallVector<std::string> &values,
StringRef nameSuffix) {
SmallString<64> name("torch.onnx.");
Expand Down
37 changes: 33 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4339,6 +4339,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
llvm::SmallVector<int64_t> ngram_counts;
llvm::SmallVector<int64_t> ngram_indexes;
llvm::SmallVector<int64_t> pool_int64s;
llvm::SmallVector<float> weights;
std::string mode;
int64_t min_gram_length;
int64_t max_gram_length;
Expand All @@ -4356,9 +4357,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.tensorOperand(input) || binder.tensorResultType(resultType))
return failure();

if (mode != "TF")
return rewriter.notifyMatchFailure(binder.op,
"TF mode supported only");
llvm::SmallVector<float> defaultWeights(ngram_indexes.size(), 1.0f);
if (binder.f32FloatArrayAttr(weights, "weights", defaultWeights))
return failure();

if (pool_int64s.size() == 0)
return rewriter.notifyMatchFailure(
binder.op, "pool_int64s empty, only integers supported");
Expand Down Expand Up @@ -4584,9 +4586,36 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.getLoc(), loopConditionTrue, ValueRange({count}));
}
count = skipLoop.getResult(0);
// insert count "tf" into output
Value countFloat = rewriter.create<Torch::AtenFloatScalarOp>(
binder.getLoc(), count);
if (mode == "IDF" || mode == "TFIDF") {
// both IDF and TFIDF modes use weights
float weight = weights[ngram_i];
Value constWeight = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(weight));

// TFIDF
Value multiplier = countFloat;
if (mode == "IDF") {
// All the counts larger than 1 would be truncated to 1
// and the i-th element in weights would be used to scale
// (by multiplication) the count of the i-th n-gram in pool.

Value intCount = rewriter.create<Torch::AtenIntScalarOp>(
binder.getLoc(), count);
// compare intCount > 0
Value gtZeroCount = rewriter.create<Torch::AtenGtIntOp>(
binder.getLoc(), intCount, zero);
gtZeroCount = rewriter.create<Torch::AtenIntBoolOp>(
binder.getLoc(), gtZeroCount);
Value gtZeroCountFloat =
rewriter.create<Torch::AtenFloatScalarOp>(binder.getLoc(),
gtZeroCount);
multiplier = gtZeroCountFloat;
}
countFloat = rewriter.create<Torch::AtenMulFloatOp>(
binder.getLoc(), multiplier, constWeight);
}
Value dataList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
Expand Down

0 comments on commit a2bfe47

Please sign in to comment.