From f8952c2cf124d6ad9b6f02153cd0ace7c78df75f Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 23 Feb 2024 14:51:59 -0800 Subject: [PATCH] Make a deep copy of each piece of test IR so that we can parallelize --- src/Function.cpp | 6 ++-- test/correctness/simd_op_check_sve2.cpp | 46 ++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/src/Function.cpp b/src/Function.cpp index 795d18136843..88f5b851e986 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -487,8 +487,10 @@ ExternFuncArgument deep_copy_extern_func_argument_helper(const ExternFuncArgumen } // namespace void Function::deep_copy(const FunctionPtr &copy, DeepCopyMap &copied_map) const { - internal_assert(copy.defined() && contents.defined()) - << "Cannot deep-copy undefined Function\n"; + internal_assert(copy.defined()) + << "Cannot deep-copy to undefined Function\n"; + internal_assert(contents.defined()) + << "Cannot deep-copy from undefined Function\n"; // Add reference to this Function's deep-copy to the map in case of // self-reference, e.g. self-reference in an Definition. diff --git a/test/correctness/simd_op_check_sve2.cpp b/test/correctness/simd_op_check_sve2.cpp index 072663a52588..9186d0e959f6 100644 --- a/test/correctness/simd_op_check_sve2.cpp +++ b/test/correctness/simd_op_check_sve2.cpp @@ -60,10 +60,6 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { return can_run_the_code; } - bool use_multiple_threads() const override { - return false; - } - void add_tests() override { check_arm_integer(); check_arm_float(); @@ -1234,6 +1230,48 @@ class SimdOpCheckArmSve : public SimdOpCheckTest { // settings. if (!parent.wildcard_match(parent.filter, decorated_op_name)) return; + // Create a deep copy of the expr and all Funcs referenced by it, so + // that no IR is shared between tests. This is required by the base + // class, and is why we can parallelize. 
+ { + using namespace Halide::Internal; + class FindOutputs : public IRVisitor { + using IRVisitor::visit; + void visit(const Call *op) override { + if (op->func.defined()) { + outputs.insert(op->func); + } + IRVisitor::visit(op); + } + + public: + std::set<FunctionPtr> outputs; + } finder; + e.accept(&finder); + std::vector<Function> outputs(finder.outputs.begin(), finder.outputs.end()); + auto env = deep_copy(outputs, build_environment(outputs)).second; + class DeepCopy : public IRMutator { + std::map<FunctionPtr, FunctionPtr> copied; + using IRMutator::visit; + Expr visit(const Call *op) override { + if (op->func.defined()) { + auto it = env.find(op->name); + if (it != env.end()) { + return Func(it->second)(mutate(op->args)); + } + } + return IRMutator::visit(op); + } + const std::map<std::string, Function> &env; + + public: + DeepCopy(const std::map<std::string, Function> &env) + : env(env) { + } + } copier(env); + e = copier.mutate(e); + } + // Create Task and register parent.tasks.emplace_back(Task{decorated_op_name, unique_name, vec_factor, e}); parent.arm_tasks.emplace(unique_name, ArmTask{std::move(instr_patterns)});