diff --git a/src/Function.cpp b/src/Function.cpp index 795d18136843..88f5b851e986 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -487,8 +487,10 @@ ExternFuncArgument deep_copy_extern_func_argument_helper(const ExternFuncArgumen } // namespace void Function::deep_copy(const FunctionPtr ©, DeepCopyMap &copied_map) const { - internal_assert(copy.defined() && contents.defined()) - << "Cannot deep-copy undefined Function\n"; + internal_assert(copy.defined()) + << "Cannot deep-copy to undefined Function\n"; + internal_assert(contents.defined()) + << "Cannot deep-copy from undefined Function\n"; // Add reference to this Function's deep-copy to the map in case of // self-reference, e.g. self-reference in an Definition. diff --git a/test/correctness/simd_op_check_sve2.cpp b/test/correctness/simd_op_check_sve2.cpp index 00ed36a35a9a..9186d0e959f6 100644 --- a/test/correctness/simd_op_check_sve2.cpp +++ b/test/correctness/simd_op_check_sve2.cpp @@ -70,7 +70,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { private: void check_arm_integer() { // clang-format off - vector> test_params{ {8, in_i8, in_u8, in_f16, in_i16, in_u16, i8, i8_sat, i16, i8, i8_sat, u8, u8_sat, u16, u8, u8_sat}, @@ -542,7 +542,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { } void check_arm_float() { - vector> test_params{ + vector> test_params{ {16, in_f16, in_u16, in_i16, f16}, {32, in_f32, in_u32, in_i32, f32}, {64, in_f64, in_u64, in_i64, f64}, @@ -674,7 +674,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { } void check_arm_load_store() { - vector> test_params = { + vector> test_params = { {Int(8), in_i8}, {Int(16), in_i16}, {Int(32), in_i32}, {Int(64), in_i64}, {UInt(8), in_u8}, {UInt(16), in_u16}, {UInt(32), in_u32}, {UInt(64), in_u64}, {Float(16), in_f16}, {Float(32), in_f32}, {Float(64), in_f64}}; for (const auto &[elt, in_im] : test_params) { @@ -866,7 +866,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { // Tests for integer type { - vector> test_params{ + vector> test_params{ {8, in_i8, in_u8, i16, i32, u16, u32}, {16, in_i16, in_u16, i32, i64, u32, u64}, {32, in_i32, in_u32, i64, i64, u64, u64}, @@ -974,7 +974,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { // Tests for Float type { // clang-format off - vector> test_params{ + vector> test_params{ {16, in_f16}, {32, in_f32}, {64, in_f64}, @@ -1230,6 +1230,48 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { // settings. if (!parent.wildcard_match(parent.filter, decorated_op_name)) return; + // Create a deep copy of the expr and all Funcs referenced by it, so + // that no IR is shared between tests. This is required by the base + // class, and is why we can parallelize. + { + using namespace Halide::Internal; + class FindOutputs : public IRVisitor { + using IRVisitor::visit; + void visit(const Call *op) override { + if (op->func.defined()) { + outputs.insert(op->func); + } + IRVisitor::visit(op); + } + + public: + std::set outputs; + } finder; + e.accept(&finder); + std::vector outputs(finder.outputs.begin(), finder.outputs.end()); + auto env = deep_copy(outputs, build_environment(outputs)).second; + class DeepCopy : public IRMutator { + std::map copied; + using IRMutator::visit; + Expr visit(const Call *op) override { + if (op->func.defined()) { + auto it = env.find(op->name); + if (it != env.end()) { + return Func(it->second)(mutate(op->args)); + } + } + return IRMutator::visit(op); + } + const std::map &env; + + public: + DeepCopy(const std::map &env) + : env(env) { + } + } copier(env); + e = copier.mutate(e); + } + // Create Task and register parent.tasks.emplace_back(Task{decorated_op_name, unique_name, vec_factor, e}); parent.arm_tasks.emplace(unique_name, ArmTask{std::move(instr_patterns)}); @@ -1242,7 +1284,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { bool is_enabled; }; - void compile_and_check(Func error, const string &op, const string &name, int vector_width, ostringstream &error_msg) override { + void compile_and_check(Func error, const string &op, const string &name, int vector_width, const std::vector &arg_types, ostringstream &error_msg) override { // This is necessary as LLVM validation errors, crashes, etc. don't tell which op crashed. cout << "Starting op " << op << "\n"; string fn_name = "test_" + name;