Skip to content

Commit

Permalink
Merge branch 'arm_sve_redux' of https://github.com/halide/Halide into…
Browse files Browse the repository at this point in the history
… arm_sve_redux
  • Loading branch information
Z Stern committed Mar 6, 2024
2 parents eaed2ef + a63439b commit f84c764
Show file tree
Hide file tree
Showing 40 changed files with 1,265 additions and 739 deletions.
3 changes: 0 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,6 @@ CXX_FLAGS += $(WEBASSEMBLY_CXX_FLAGS)
# On ubuntu, this requires packages flatbuffers-compiler and libflatbuffers-dev
ifneq (,$(shell which flatc))
CXX_FLAGS += -DWITH_SERIALIZATION -I $(BUILD_DIR) -I $(shell which flatc | sed 's/bin.flatc/include/')
# Note: if updating here, be sure to update in CMakeLists.txt as well
HALIDE_SERIALIZATION_VERSION_MINOR ?= 1
HALIDE_SERIALIZATION_VERSION_PATCH ?= 0
endif

# This is required on some hosts like powerpc64le-linux-gnu because we may build
Expand Down
80 changes: 68 additions & 12 deletions src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,11 +569,67 @@ class InitializeSemaphores : public IRMutator {
}
};

// A class to support stmt_uses_vars queries that repeatedly hit the same
// sub-stmts. Used to support TightenProducerConsumerNodes below.
class CachingStmtUsesVars : public IRMutator {
    // The set of names we are checking for uses of.
    const Scope<> &query;
    // Accumulator: set to true whenever a visited node uses a queried name.
    bool found_use = false;
    // Memoized per-Stmt answers, so repeated queries over the same sub-stmt
    // don't re-walk it.
    std::map<Stmt, bool> cache;

    using IRMutator::visit;
    Expr visit(const Variable *op) override {
        found_use |= query.contains(op->name);
        return op;
    }

    Expr visit(const Call *op) override {
        // A Call's name can itself refer to a queried Func, so check it in
        // addition to recursing into the args.
        found_use |= query.contains(op->name);
        IRMutator::visit(op);
        return op;
    }

    Stmt visit(const Provide *op) override {
        found_use |= query.contains(op->name);
        IRMutator::visit(op);
        return op;
    }

public:
    // explicit: a Scope<> should never silently convert to this checker.
    explicit CachingStmtUsesVars(const Scope<> &q)
        : query(q) {
    }

    using IRMutator::mutate;
    // Returns the Stmt unchanged (this mutator only observes); as a side
    // effect, ORs into found_use whether s uses any queried name, consulting
    // and updating the cache.
    Stmt mutate(const Stmt &s) override {
        auto it = cache.find(s);
        if (it != cache.end()) {
            found_use |= it->second;
        } else {
            // Save and clear the flag so we can isolate s's own answer,
            // cache it, then fold the saved state back in.
            bool old = found_use;
            found_use = false;
            (void)IRMutator::mutate(s);  // Traversal only; result discarded.
            cache.emplace(s, found_use);
            found_use |= old;
        }
        return s;
    }

    // True iff s uses any of the names in 'query'.
    bool check_stmt(const Stmt &s) {
        found_use = false;
        mutate(s);
        return found_use;
    }
};

// Tighten the scope of consume nodes as much as possible to avoid needless synchronization.
class TightenProducerConsumerNodes : public IRMutator {
using IRMutator::visit;

Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope<int> &scope) {
Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope<> &scope, CachingStmtUsesVars &uses_vars) {
if (const LetStmt *let = body.as<LetStmt>()) {
Stmt orig = body;
// 'orig' is only used to keep a reference to the let
Expand All @@ -595,7 +651,7 @@ class TightenProducerConsumerNodes : public IRMutator {
body = ProducerConsumer::make(name, is_producer, body);
} else {
// Recurse onto a non-let-node
body = make_producer_consumer(name, is_producer, body, scope);
body = make_producer_consumer(name, is_producer, body, scope, uses_vars);
}

for (auto it = containing_lets.rbegin(); it != containing_lets.rend(); it++) {
Expand All @@ -611,44 +667,44 @@ class TightenProducerConsumerNodes : public IRMutator {
vector<Stmt> sub_stmts;
Stmt rest;
do {
Stmt first = block->first;
sub_stmts.push_back(block->first);
rest = block->rest;
block = rest.as<Block>();
} while (block);
sub_stmts.push_back(rest);

for (Stmt &s : sub_stmts) {
if (stmt_uses_vars(s, scope)) {
s = make_producer_consumer(name, is_producer, s, scope);
if (uses_vars.check_stmt(s)) {
s = make_producer_consumer(name, is_producer, s, scope, uses_vars);
}
}

return Block::make(sub_stmts);
} else if (const ProducerConsumer *pc = body.as<ProducerConsumer>()) {
return ProducerConsumer::make(pc->name, pc->is_producer, make_producer_consumer(name, is_producer, pc->body, scope));
return ProducerConsumer::make(pc->name, pc->is_producer, make_producer_consumer(name, is_producer, pc->body, scope, uses_vars));
} else if (const Realize *r = body.as<Realize>()) {
return Realize::make(r->name, r->types, r->memory_type,
r->bounds, r->condition,
make_producer_consumer(name, is_producer, r->body, scope));
make_producer_consumer(name, is_producer, r->body, scope, uses_vars));
} else {
return ProducerConsumer::make(name, is_producer, body);
}
}

Stmt visit(const ProducerConsumer *op) override {
Stmt body = mutate(op->body);
Scope<int> scope;
scope.push(op->name, 0);
Scope<> scope;
scope.push(op->name);
Function f = env.find(op->name)->second;
if (f.outputs() == 1) {
scope.push(op->name + ".buffer", 0);
scope.push(op->name + ".buffer");
} else {
for (int i = 0; i < f.outputs(); i++) {
scope.push(op->name + "." + std::to_string(i) + ".buffer", 0);
scope.push(op->name + "." + std::to_string(i) + ".buffer");
}
}
return make_producer_consumer(op->name, op->is_producer, body, scope);
CachingStmtUsesVars uses_vars{scope};
return make_producer_consumer(op->name, op->is_producer, body, scope, uses_vars);
}

const map<string, Function> &env;
Expand Down
7 changes: 6 additions & 1 deletion src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1383,9 +1383,14 @@ Stmt bounds_inference(Stmt s,
fused_pairs_in_groups.push_back(pairs);
}

// Add a note in the IR for where the outermost dynamic-stage skipping
// checks should go. These are injected in a later pass.
Expr marker = Call::make(Int(32), Call::skip_stages_marker, {}, Call::Intrinsic);
s = Block::make(Evaluate::make(marker), s);

// Add a note in the IR for where assertions on input images
// should go. Those are handled by a later lowering pass.
Expr marker = Call::make(Int(32), Call::add_image_checks_marker, {}, Call::Intrinsic);
marker = Call::make(Int(32), Call::add_image_checks_marker, {}, Call::Intrinsic);
s = Block::make(Evaluate::make(marker), s);

// Add a synthetic outermost loop to act as 'root'.
Expand Down
43 changes: 23 additions & 20 deletions src/CanonicalizeGPUVars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,26 @@ namespace Halide {
namespace Internal {

using std::map;
using std::string;
using std::vector;

namespace {
string thread_names[] = {"__thread_id_x", "__thread_id_y", "__thread_id_z"};
string block_names[] = {"__block_id_x", "__block_id_y", "__block_id_z"};

string get_thread_name(int index) {
const std::string &gpu_thread_name(int index) {
static std::string gpu_thread_names[3] = {"." + unique_name("thread_id_x"),
"." + unique_name("thread_id_y"),
"." + unique_name("thread_id_z")};
internal_assert(index >= 0 && index < 3);
return thread_names[index];
return gpu_thread_names[index];
}

string get_block_name(int index) {
const std::string &gpu_block_name(int index) {
static std::string gpu_block_names[3] = {"." + unique_name("block_id_x"),
"." + unique_name("block_id_y"),
"." + unique_name("block_id_z")};
internal_assert(index >= 0 && index < 3);
return block_names[index];
return gpu_block_names[index];
}

namespace {

class CountGPUBlocksThreads : public IRVisitor {
using IRVisitor::visit;

Expand Down Expand Up @@ -73,12 +76,12 @@ class CountGPUBlocksThreads : public IRVisitor {
};

class CanonicalizeGPUVars : public IRMutator {
map<string, string> gpu_vars;
map<std::string, std::string> gpu_vars;

using IRMutator::visit;

string find_replacement(const string &suffix, const string &name) {
vector<string> v = split_string(name, suffix);
std::string find_replacement(const std::string &suffix, const std::string &name) {
vector<std::string> v = split_string(name, suffix);
internal_assert(v.size() == 2);
const auto &iter = gpu_vars.find(v[0]);
if (iter != gpu_vars.end()) {
Expand All @@ -87,7 +90,7 @@ class CanonicalizeGPUVars : public IRMutator {
return name;
}

string canonicalize_let(const string &name) {
std::string canonicalize_let(const std::string &name) {
if (ends_with(name, ".loop_max")) {
return find_replacement(".loop_max", name);
} else if (ends_with(name, ".loop_min")) {
Expand All @@ -100,7 +103,7 @@ class CanonicalizeGPUVars : public IRMutator {
}

Stmt visit(const For *op) override {
string name = op->name;
std::string name = op->name;
Expr min = mutate(op->min);
Expr extent = mutate(op->extent);
Stmt body = mutate(op->body);
Expand All @@ -113,13 +116,13 @@ class CanonicalizeGPUVars : public IRMutator {
op->body.accept(&counter);

if (op->for_type == ForType::GPUBlock) {
name += "." + get_block_name(counter.nblocks);
name += gpu_block_name(counter.nblocks);
debug(5) << "Replacing " << op->name << " with GPU block name " << name << "\n";
} else if (op->for_type == ForType::GPUThread) {
name += "." + get_thread_name(counter.nthreads);
name += gpu_thread_name(counter.nthreads);
debug(5) << "Replacing " << op->name << " with GPU thread name " << name << "\n";
} else if (op->for_type == ForType::GPULane) {
name += "." + get_thread_name(0);
name += gpu_thread_name(0);
}

if (name != op->name) {
Expand All @@ -143,7 +146,7 @@ class CanonicalizeGPUVars : public IRMutator {
}

Stmt visit(const LetStmt *op) override {
vector<std::pair<string, Expr>> lets;
vector<std::pair<std::string, Expr>> lets;
Stmt result;

do {
Expand All @@ -154,7 +157,7 @@ class CanonicalizeGPUVars : public IRMutator {
result = mutate(result);

for (auto it = lets.rbegin(); it != lets.rend(); it++) {
string name = canonicalize_let(it->first);
std::string name = canonicalize_let(it->first);
if (name != it->first) {
Expr new_var = Variable::make(Int(32), name);
result = substitute(it->first, new_var, result);
Expand All @@ -168,7 +171,7 @@ class CanonicalizeGPUVars : public IRMutator {
Stmt visit(const IfThenElse *op) override {
Expr condition = mutate(op->condition);

map<string, string> old_gpu_vars;
map<std::string, std::string> old_gpu_vars;
old_gpu_vars.swap(gpu_vars);
Stmt then_case = mutate(op->then_case);

Expand Down
7 changes: 7 additions & 0 deletions src/CanonicalizeGPUVars.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ namespace Internal {
* by the nesting order: innermost is assigned to x and so on. */
Stmt canonicalize_gpu_vars(Stmt s);

/** Names for the thread and block id variables. Includes the leading
* dot. Indexed from inside out, so 0 gives you the innermost loop. */
// @{
const std::string &gpu_thread_name(int index);
const std::string &gpu_block_name(int index);
// @}

} // namespace Internal
} // namespace Halide

Expand Down
36 changes: 12 additions & 24 deletions src/CodeGen_D3D12Compute_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <sstream>
#include <utility>

#include "CanonicalizeGPUVars.h"
#include "CodeGen_D3D12Compute_Dev.h"
#include "CodeGen_GPU_Dev.h"
#include "CodeGen_Internal.h"
Expand Down Expand Up @@ -221,22 +222,18 @@ string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_reinterpret(Type

namespace {
string simt_intrinsic(const string &name) {
if (ends_with(name, ".__thread_id_x")) {
if (ends_with(name, gpu_thread_name(0))) {
return "tid_in_tgroup.x";
} else if (ends_with(name, ".__thread_id_y")) {
} else if (ends_with(name, gpu_thread_name(1))) {
return "tid_in_tgroup.y";
} else if (ends_with(name, ".__thread_id_z")) {
} else if (ends_with(name, gpu_thread_name(2))) {
return "tid_in_tgroup.z";
} else if (ends_with(name, ".__thread_id_w")) {
user_error << "HLSL (SM5.1) does not support more than three dimensions for compute kernel threads.\n";
} else if (ends_with(name, ".__block_id_x")) {
} else if (ends_with(name, gpu_block_name(0))) {
return "tgroup_index.x";
} else if (ends_with(name, ".__block_id_y")) {
} else if (ends_with(name, gpu_block_name(1))) {
return "tgroup_index.y";
} else if (ends_with(name, ".__block_id_z")) {
} else if (ends_with(name, gpu_block_name(2))) {
return "tgroup_index.z";
} else if (ends_with(name, ".__block_id_w")) {
user_error << "HLSL (SM5.1) does not support more than three dimensions for compute dispatch groups.\n";
}
internal_error << "simt_intrinsic called on bad variable name: " << name << "\n";
return "";
Expand Down Expand Up @@ -300,15 +297,10 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const For *loop) {
user_assert(loop->for_type != ForType::GPULane)
<< "The D3D12Compute backend does not support the gpu_lanes() scheduling directive.";

if (!is_gpu_var(loop->name)) {
user_assert(loop->for_type != ForType::Parallel) << "Cannot use parallel loops inside D3D12Compute kernel\n";
if (!is_gpu(loop->for_type)) {
CodeGen_GPU_C::visit(loop);
return;
}

internal_assert((loop->for_type == ForType::GPUBlock) ||
(loop->for_type == ForType::GPUThread))
<< "kernel loop must be either gpu block or gpu thread\n";
internal_assert(is_const_zero(loop->min));

stream << get_indent() << print_type(Int(32)) << " " << print_name(loop->name)
Expand Down Expand Up @@ -1153,7 +1145,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
struct FindThreadGroupSize : public IRVisitor {
using IRVisitor::visit;
void visit(const For *loop) override {
if (!is_gpu_var(loop->name)) {
if (!is_gpu(loop->for_type)) {
return loop->body.accept(this);
}
if (loop->for_type != ForType::GPUThread) {
Expand All @@ -1175,13 +1167,9 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
loop->body.accept(this);
}
int thread_loop_workgroup_index(const string &name) {
string ids[] = {".__thread_id_x",
".__thread_id_y",
".__thread_id_z",
".__thread_id_w"};
for (auto &id : ids) {
if (ends_with(name, id)) {
return (&id - ids);
for (int i = 0; i < 3; i++) {
if (ends_with(name, gpu_thread_name(i))) {
return i;
}
}
return -1;
Expand Down
Loading

0 comments on commit f84c764

Please sign in to comment.