Skip to content

Commit

Permalink
Disallow async nestings that violate read after write dependencies (h…
Browse files Browse the repository at this point in the history
…alide#7868)

* Disallow async nestings that violate read after write dependencies

Fixes halide#7867

* Add test

* Add another failure case, and improve error message

* Add some more tests

* Update test

* Add new test to cmakelists

* Fix for llvm trunk

* Always acquire the folding semaphore, even if unused

* Skip async_order test under wasm

* trigger buildbots

---------

Co-authored-by: Volodymyr Kysenko <vksnk@google.com>
Co-authored-by: Steven Johnson <srj@google.com>
  • Loading branch information
3 people committed Dec 1, 2023
1 parent 4fc2a7d commit 674e6cc
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 12 deletions.
51 changes: 51 additions & 0 deletions src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,23 +109,74 @@ class NoOpCollapsingMutator : public IRMutator {
class GenerateProducerBody : public NoOpCollapsingMutator {
const string &func;
vector<Expr> sema;
std::set<string> producers_dropped;
bool found_producer = false;

using NoOpCollapsingMutator::visit;

void bad_producer_nesting_error(const string &producer, const string &async_consumer) {
user_error
<< "The Func " << producer << " is consumed by async Func " << async_consumer
<< " and has a compute_at location in between the store_at "
<< "location and the compute_at location of " << async_consumer
<< ". This is only legal when " << producer
<< " is both async and has a store_at location outside the store_at location of the consumer.";
}

// Preserve produce nodes and add synchronization
Stmt visit(const ProducerConsumer *op) override {
if (op->name == func && op->is_producer) {
found_producer = true;

// Add post-synchronization
internal_assert(!sema.empty()) << "Duplicate produce node: " << op->name << "\n";
Stmt body = op->body;

// We don't currently support waiting on producers to the producer
// half of the fork node. Or rather, if you want to do that you have
// to schedule those Funcs as async too. Check for any consume nodes
// where the producer has gone to the consumer side of the fork
// node.
class FindBadConsumeNodes : public IRVisitor {
const std::set<string> &producers_dropped;
using IRVisitor::visit;

void visit(const ProducerConsumer *op) override {
if (!op->is_producer && producers_dropped.count(op->name)) {
found = op->name;
}
}

public:
string found;
FindBadConsumeNodes(const std::set<string> &p)
: producers_dropped(p) {
}
} finder(producers_dropped);
body.accept(&finder);
if (!finder.found.empty()) {
bad_producer_nesting_error(finder.found, func);
}

while (!sema.empty()) {
Expr release = Call::make(Int(32), "halide_semaphore_release", {sema.back(), 1}, Call::Extern);
body = Block::make(body, Evaluate::make(release));
sema.pop_back();
}
return ProducerConsumer::make_produce(op->name, body);
} else {
if (op->is_producer) {
producers_dropped.insert(op->name);
}
bool found_producer_before = found_producer;
Stmt body = mutate(op->body);
if (!op->is_producer && producers_dropped.count(op->name) &&
found_producer && !found_producer_before) {
// We've found a consume node wrapping our async producer where
// the corresponding producer node was dropped from this half of
// the fork.
bad_producer_nesting_error(op->name, func);
}
if (is_no_op(body) || op->is_producer) {
return body;
} else {
Expand Down
5 changes: 0 additions & 5 deletions src/StorageFolding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,11 +825,6 @@ class AttemptStorageFoldingOfFunction : public IRMutator {
to_release = max_required - max_required_next; // This is the last time we use these entries
}

if (provided.used.defined()) {
to_acquire = select(provided.used, to_acquire, 0);
}
// We should always release the required region, even if we don't use it.

// On the first iteration, we need to acquire the extent of the region shared
// between the producer and consumer, and we need to release it on the last
// iteration.
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ tests(GROUPS correctness
align_bounds.cpp
argmax.cpp
async_device_copy.cpp
async_order.cpp
autodiff.cpp
bad_likely.cpp
bit_counting.cpp
Expand Down
94 changes: 94 additions & 0 deletions test/correctness/async_order.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {
if (get_jit_target_from_environment().arch == Target::WebAssembly) {
printf("[SKIP] WebAssembly does not support async() yet.\n");
return 0;
}

{
Func producer1, producer2, consumer;
Var x, y;

producer1(x, y) = x + y;
producer2(x, y) = producer1(x, y);
consumer(x, y) = producer1(x, y - 1) + producer2(x, y + 1);

consumer.compute_root();

producer1.compute_at(consumer, y);
producer2.compute_at(consumer, y).async();

consumer.bound(x, 0, 16).bound(y, 0, 16);

Buffer<int> out = consumer.realize({16, 16});

out.for_each_element([&](int x, int y) {
int correct = 2 * (x + y);
if (out(x, y) != correct) {
printf("out(%d, %d) = %d instead of %d\n",
x, y, out(x, y), correct);
exit(-1);
}
});
}
{
Func producer1, producer2, consumer;
Var x, y;

producer1(x, y) = x + y;
producer2(x, y) = producer1(x, y);
consumer(x, y) = producer1(x, y - 1) + producer2(x, y + 1);

consumer.compute_root();

producer1.compute_root();
producer2.store_root().compute_at(consumer, y).async();

consumer.bound(x, 0, 16).bound(y, 0, 16);

Buffer<int> out = consumer.realize({16, 16});

out.for_each_element([&](int x, int y) {
int correct = 2 * (x + y);
if (out(x, y) != correct) {
printf("out(%d, %d) = %d instead of %d\n",
x, y, out(x, y), correct);
exit(-1);
}
});
}

{
Func producer1, producer2, consumer;
Var x, y;

producer1(x, y) = x + y;
producer2(x, y) = producer1(x, y);
consumer(x, y) = producer1(x, y - 1) + producer2(x, y + 1);

consumer.compute_root();

producer1.store_root().compute_at(consumer, y).async();
producer2.store_root().compute_at(consumer, y).async();

consumer.bound(x, 0, 16).bound(y, 0, 16);

Buffer<int> out = consumer.realize({16, 16});

out.for_each_element([&](int x, int y) {
int correct = 2 * (x + y);
if (out(x, y) != correct) {
printf("out(%d, %d) = %d instead of %d\n",
x, y, out(x, y), correct);
exit(-1);
}
});
}

printf("Success!\n");
return 0;
}
2 changes: 2 additions & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ tests(GROUPS error
auto_schedule_no_parallel.cpp
auto_schedule_no_reorder.cpp
autodiff_unbounded.cpp
bad_async_producer.cpp
bad_async_producer_2.cpp
bad_bound.cpp
bad_bound_storage.cpp
bad_compute_at.cpp
Expand Down
31 changes: 31 additions & 0 deletions test/error/bad_async_producer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@

#include "Halide.h"

using namespace Halide;

int main(int argc, char **argv) {

Func f{"f"}, g{"g"}, h{"h"};
Var x;

f(x) = cast<uint8_t>(x + 7);
g(x) = f(x);
h(x) = g(x);

// The schedule below is an error. It should really be:
// f.store_root().compute_at(g, Var::outermost());
// So that it's nested inside the consumer h.
f.store_root().compute_at(h, x);
g.store_root().compute_at(h, x).async();

Buffer<uint8_t> buf = h.realize({32});
for (int i = 0; i < buf.dim(0).extent(); i++) {
uint8_t correct = i + 7;
if (buf(i) != correct) {
printf("buf(%d) = %d instead of %d\n", i, buf(i), correct);
return 1;
}
}

return 0;
}
23 changes: 23 additions & 0 deletions test/error/bad_async_producer_2.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "Halide.h"

using namespace Halide;

// From https://github.com/halide/Halide/issues/5201
int main(int argc, char **argv) {
Func producer1, producer2, consumer;
Var x, y;

producer1(x, y) = x + y;
producer2(x, y) = producer1(x, y);
consumer(x, y) = producer2(x, y - 1) + producer2(x, y + 1);

consumer.compute_root();

producer1.compute_at(consumer, y).async();
producer2.store_root().compute_at(consumer, y).async();

consumer.bound(x, 0, 16).bound(y, 0, 16);

Buffer<int> out = consumer.realize({16, 16});
return 0;
}
33 changes: 26 additions & 7 deletions test/performance/async_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Expr expensive(Expr x, int c) {
if (c <= 0) {
return x;
} else {
return expensive(fast_pow(x, x + 1), c - 1);
return expensive(x * (x + 1), c - 1);
}
}

Expand All @@ -31,11 +31,12 @@ int main(int argc, char **argv) {
}

double times[2];
uint32_t correct = 0;
for (int use_async = 0; use_async < 2; use_async++) {
Var x, y, t, xi, yi;

ImageParam in(Float(32), 3);
Func cpu, gpu;
ImageParam in(UInt(32), 3);
Func cpu("cpu"), gpu("gpu");

// We have a two-stage pipeline that processes frames. We want
// to run the first stage on the GPU and the second stage on
Expand All @@ -50,26 +51,44 @@ int main(int argc, char **argv) {

// Assume GPU memory is limited, and compute the GPU stage one
// frame at a time. Hoist the allocation to the top level.
gpu.compute_at(cpu, t).store_root().gpu_tile(x, y, xi, yi, 8, 8);
gpu.compute_at(gpu.in(), Var::outermost()).store_root().gpu_tile(x, y, xi, yi, 8, 8);

// Stage the copy-back of the GPU result into a host-side
// double-buffer.
gpu.in().copy_to_host().compute_at(cpu, t).store_root().fold_storage(t, 2);

if (use_async) {
// gpu.async();
gpu.in().async();
gpu.async();
}

in.set(Buffer<float>(800, 800, 16));
Buffer<float> out(800, 800, 16);
Buffer<uint32_t> in_buf(800, 800, 16);
in_buf.fill(17);
in.set(in_buf);
Buffer<uint32_t> out(800, 800, 16);

cpu.compile_jit();

times[use_async] = benchmark(10, 1, [&]() {
cpu.realize(out);
});

if (!use_async) {
correct = out(0, 0, 0);
} else {
for (int t = 0; t < out.dim(2).extent(); t++) {
for (int y = 0; y < out.dim(1).extent(); y++) {
for (int x = 0; x < out.dim(0).extent(); x++) {
if (out(x, y, t) != correct) {
printf("Async output at (%d, %d, %d) is %u instead of %u\n",
x, y, t, out(x, y, t), correct);
return 1;
}
}
}
}
}

printf("%s: %f\n",
use_async ? "with async" : "without async",
times[use_async]);
Expand Down

0 comments on commit 674e6cc

Please sign in to comment.