Skip to content

Commit acb0e74

Browse files
committed
Respond to review
Change-Id: I0569534397a2d0db9587db6424b1674846a76079
1 parent b1ab879 commit acb0e74

File tree

5 files changed

+21
-13
lines changed

5 files changed

+21
-13
lines changed

src/arith/analyzer.cc

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,17 +235,14 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) {
235235
// SVE, we can make some assumptions about the value of vscale and iterate over a
236236
// space of pre-defined values to attempt to prove the expression.
237237
if (tir::CheckContains::ExprContains(expr, IsVScaleCall)) {
238-
Target curr_target = tvm::Target::Current();
239-
if (curr_target.defined() && curr_target->features.defined() &&
240-
(curr_target->features.find("has_sve") != curr_target->features.end()) &&
241-
curr_target->GetFeature<Bool>("has_sve").value_or(Bool(false)).operator bool()) {
238+
if (TargetHasSVE()) {
242239
return CanProveVscaleExpressionFromKnownValues(this, simplified, kAArch64VScaleValues);
243240
}
244241
LOG(WARNING)
245242
<< "The expression contains scalable values. An attempt to prove by substituting "
246243
"with known values of vscale was not performed. This proof currently only supports "
247244
"AArch64 SVE targets, but the target was "
248-
<< curr_target;
245+
<< Target::Current();
249246
}
250247
return false;
251248
}

src/arith/scalable_expression.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,5 +88,14 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr
8888
return can_prove_expr;
8989
}
9090

91+
bool TargetHasSVE() {
92+
Target current_target = Target::Current();
93+
bool has_sve{false};
94+
if (current_target.defined()) {
95+
has_sve = current_target->GetFeature<Bool>("has_sve").value_or(Bool(false));
96+
}
97+
return has_sve;
98+
}
99+
91100
} // namespace arith
92101
} // namespace tvm

src/arith/scalable_expression.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes);
7171
bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const PrimExpr& expr,
7272
const std::vector<unsigned int>& vscale_values);
7373

74+
/*!
75+
* \brief Check whether the compilation target supports SVE
76+
* \return Whether SVE is supported
77+
*/
78+
bool TargetHasSVE();
79+
7480
} // namespace arith
7581
} // namespace tvm
7682

src/tir/transforms/vectorize_loop.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -731,15 +731,12 @@ class LoopVectorizer : public StmtMutator {
731731
Stmt VisitStmt_(const ForNode* op) final {
732732
if (op->kind == ForKind::kVectorized) {
733733
auto* extent_as_int = op->extent.as<IntImmNode>();
734+
734735
if (!extent_as_int || extent_as_int->value < 1) {
735-
Target current_target = Target::Current();
736-
bool has_sve{false};
737-
if (current_target.defined()) {
738-
has_sve = current_target->GetFeature<Bool>("has_sve").value_or(Bool(false));
739-
}
740736
bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall);
741-
ICHECK(is_scalable_expr && has_sve) << "Failed to vectorize loop with extent " << op->extent
742-
<< " for target " << current_target;
737+
ICHECK(is_scalable_expr && arith::TargetHasSVE())
738+
<< "Failed to vectorize loop with extent " << op->extent << " for target "
739+
<< Target::Current();
743740
}
744741
ICHECK(is_zero(op->min));
745742
return Vectorizer(op->loop_var, op->extent)(op->body);

tests/python/tir-transform/test_tir_transform_vectorize.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def test_vectorize_vector_scalable_error4():
114114
class Module:
115115
@T.prim_func(private=True)
116116
def main(A: T.Buffer((25,), "float32")):
117-
T.func_attr({"target": sve_target})
118117
for j in T.vectorized(T.vscale() * 4):
119118
A[j * T.vscale() * 4 : j * T.vscale() * 4 + T.vscale() * 4] = T.Broadcast(
120119
T.float32(1), T.vscale() * 4

0 commit comments

Comments
 (0)