Skip to content

Commit

Permalink
Merge branch 'arm_sve_redux' of https://github.com/halide/Halide into…
Browse files Browse the repository at this point in the history
… arm_sve_redux
  • Loading branch information
Z Stern committed Mar 5, 2024
2 parents 4a269bd + f8952c2 commit eaed2ef
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 8 deletions.
6 changes: 4 additions & 2 deletions src/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,10 @@ ExternFuncArgument deep_copy_extern_func_argument_helper(const ExternFuncArgumen
} // namespace

void Function::deep_copy(const FunctionPtr &copy, DeepCopyMap &copied_map) const {
internal_assert(copy.defined() && contents.defined())
<< "Cannot deep-copy undefined Function\n";
internal_assert(copy.defined())
<< "Cannot deep-copy to undefined Function\n";
internal_assert(contents.defined())
<< "Cannot deep-copy from undefined Function\n";

// Add reference to this Function's deep-copy to the map in case of
// self-reference, e.g. self-reference in an Definition.
Expand Down
54 changes: 48 additions & 6 deletions test/correctness/simd_op_check_sve2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
private:
void check_arm_integer() {
// clang-format off
vector<tuple<int, ImageParam, ImageParam, ImageParam, ImageParam, ImageParam,
vector<tuple<int, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy,
CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy,
CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy>> test_params{
{8, in_i8, in_u8, in_f16, in_i16, in_u16, i8, i8_sat, i16, i8, i8_sat, u8, u8_sat, u16, u8, u8_sat},
Expand Down Expand Up @@ -542,7 +542,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
}

void check_arm_float() {
vector<tuple<int, ImageParam, ImageParam, ImageParam, CastFuncTy>> test_params{
vector<tuple<int, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy>> test_params{
{16, in_f16, in_u16, in_i16, f16},
{32, in_f32, in_u32, in_i32, f32},
{64, in_f64, in_u64, in_i64, f64},
Expand Down Expand Up @@ -674,7 +674,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
}

void check_arm_load_store() {
vector<tuple<Type, ImageParam>> test_params = {
vector<tuple<Type, CastFuncTy>> test_params = {
{Int(8), in_i8}, {Int(16), in_i16}, {Int(32), in_i32}, {Int(64), in_i64}, {UInt(8), in_u8}, {UInt(16), in_u16}, {UInt(32), in_u32}, {UInt(64), in_u64}, {Float(16), in_f16}, {Float(32), in_f32}, {Float(64), in_f64}};

for (const auto &[elt, in_im] : test_params) {
Expand Down Expand Up @@ -866,7 +866,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {

// Tests for integer type
{
vector<tuple<int, ImageParam, ImageParam, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy>> test_params{
vector<tuple<int, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy, CastFuncTy>> test_params{
{8, in_i8, in_u8, i16, i32, u16, u32},
{16, in_i16, in_u16, i32, i64, u32, u64},
{32, in_i32, in_u32, i64, i64, u64, u64},
Expand Down Expand Up @@ -974,7 +974,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
// Tests for Float type
{
// clang-format off
vector<tuple<int, ImageParam>> test_params{
vector<tuple<int, CastFuncTy>> test_params{
{16, in_f16},
{32, in_f32},
{64, in_f64},
Expand Down Expand Up @@ -1230,6 +1230,48 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
// settings.
if (!parent.wildcard_match(parent.filter, decorated_op_name)) return;

// Create a deep copy of the expr and all Funcs referenced by it, so
// that no IR is shared between tests. This is required by the base
// class, and is why we can parallelize.
{
using namespace Halide::Internal;
class FindOutputs : public IRVisitor {
using IRVisitor::visit;
void visit(const Call *op) override {
if (op->func.defined()) {
outputs.insert(op->func);
}
IRVisitor::visit(op);
}

public:
std::set<FunctionPtr> outputs;
} finder;
e.accept(&finder);
std::vector<Function> outputs(finder.outputs.begin(), finder.outputs.end());
auto env = deep_copy(outputs, build_environment(outputs)).second;
class DeepCopy : public IRMutator {
std::map<FunctionPtr, FunctionPtr> copied;
using IRMutator::visit;
Expr visit(const Call *op) override {
if (op->func.defined()) {
auto it = env.find(op->name);
if (it != env.end()) {
return Func(it->second)(mutate(op->args));
}
}
return IRMutator::visit(op);
}
const std::map<std::string, Function> &env;

public:
DeepCopy(const std::map<std::string, Function> &env)
: env(env) {
}
} copier(env);
e = copier.mutate(e);
}

// Create Task and register
parent.tasks.emplace_back(Task{decorated_op_name, unique_name, vec_factor, e});
parent.arm_tasks.emplace(unique_name, ArmTask{std::move(instr_patterns)});
Expand All @@ -1242,7 +1284,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
bool is_enabled;
};

void compile_and_check(Func error, const string &op, const string &name, int vector_width, ostringstream &error_msg) override {
void compile_and_check(Func error, const string &op, const string &name, int vector_width, const std::vector<Argument> &arg_types, ostringstream &error_msg) override {
// This is necessary as LLVM validation errors, crashes, etc. don't tell which op crashed.
cout << "Starting op " << op << "\n";
string fn_name = "test_" + name;
Expand Down

0 comments on commit eaed2ef

Please sign in to comment.