Skip to content

Commit 9b91c53

Browse files
committed
[SVE] Check for SVE target in func_attr from VectorizeLoop
Check that we are compiling for an SVE enabled target when the extent of a loop marked for vectorizing has a vscale dependent extent. 1. Add call to BindTarget in Graph Executor pipeline (to enable running tvmc flow) 2. Check in LoopVectorize that the extent is either a positive integer or an vscale dependent expression, in which case we'd expect an SVE enabled target to be present in func_attr
1 parent d4056ca commit 9b91c53

File tree

3 files changed

+106
-32
lines changed

3 files changed

+106
-32
lines changed

src/driver/driver_api.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
161161
.value();
162162

163163
bool instrument_lwp = pass_ctx->GetConfig<Bool>("tir.instrument_lwp", Bool(false)).value();
164+
Target current_target = Target::Current();
164165

165166
Array<transform::Pass> user_lower_phase0 = Array<transform::Pass>();
166167
Array<transform::Pass> user_lower_phase1 = Array<transform::Pass>();
@@ -196,6 +197,9 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
196197
Array<tvm::transform::Pass> pass_list = user_lower_phase0;
197198

198199
// PHASE 1
200+
if (current_target.defined()) {
201+
pass_list.push_back(tir::transform::BindTarget(current_target));
202+
}
199203
pass_list.push_back(tir::transform::InjectPrefetch());
200204
pass_list.push_back(tir::transform::TextureFlatten());
201205
pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));

src/tir/transforms/vectorize_loop.cc

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
#include <unordered_map>
3535
#include <vector>
3636

37+
#include "../../src/arith/scalable_expression.h"
38+
#include "../../tir/analysis/check_contains.h"
39+
3740
namespace tvm {
3841
namespace tir {
3942

@@ -725,17 +728,33 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
725728

726729
class LoopVectorizer : public StmtMutator {
727730
public:
731+
LoopVectorizer(PrimFunc f) {
732+
auto target = f->attrs.GetAttr<tvm::Target>(tvm::attr::kTarget);
733+
if (target.defined()) {
734+
target_ = Downcast<Target>(target);
735+
has_sve_ = target_->GetFeature<Bool>("has_sve").value_or(Bool(false));
736+
}
737+
}
738+
728739
Stmt VisitStmt_(const ForNode* op) final {
729740
if (op->kind == ForKind::kVectorized) {
741+
auto* extent_as_int = op->extent.as<IntImmNode>();
742+
if (!extent_as_int || extent_as_int->value < 1) {
743+
bool is_scalable_expr = CheckContains::ExprContains(op->extent, arith::IsVScaleCall);
744+
ICHECK(is_scalable_expr && has_sve_)
745+
<< "Failed to vectorize loop with extent " << op->extent << " for target " << target_;
746+
}
730747
ICHECK(is_zero(op->min));
731748
return Vectorizer(op->loop_var, op->extent)(op->body);
732749
} else {
733750
return StmtMutator::VisitStmt_(op);
734751
}
735752
}
736-
};
737753

738-
Stmt VectorizeLoop(Stmt stmt) { return LoopVectorizer()(std::move(stmt)); }
754+
private:
755+
bool has_sve_{false};
756+
Target target_{};
757+
};
739758

740759
class VectorizeSkipper : public StmtMutator {
741760
public:
@@ -759,7 +778,7 @@ Pass VectorizeLoop(bool enable_vectorize) {
759778
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
760779
auto* n = f.CopyOnWrite();
761780
if (enable_vectorize) {
762-
n->body = LoopVectorizer()(std::move(n->body));
781+
n->body = LoopVectorizer(f)(std::move(n->body));
763782
} else {
764783
n->body = VectorizeSkipper()(std::move(n->body));
765784
}

0 commit comments

Comments
 (0)