Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python_bindings/src/halide/halide_/PyExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ void define_expr(py::module &m) {
") cannot be converted to a bool. "
"If this error occurs using the 'and'/'or' keywords, "
"consider using the '&'/'|' operators instead.");
return false;
};

auto expr_class =
Expand Down Expand Up @@ -78,6 +77,12 @@ void define_expr(py::module &m) {
py::implicitly_convertible<RVar, Expr>();
py::implicitly_convertible<Var, Expr>();

auto eviction_key_class =
py::class_<EvictionKey>(m, "EvictionKey")
.def(py::init<Expr>());

py::implicitly_convertible<Expr, EvictionKey>();

auto range_class =
py::class_<Range>(m, "Range")
.def(py::init<>())
Expand Down
12 changes: 6 additions & 6 deletions python_bindings/src/halide/halide_/PyFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ void define_func(py::module &m) {
.def("async_", &Func::async)
.def("ring_buffer", &Func::ring_buffer)
.def("bound_storage", &Func::bound_storage)
.def("memoize", &Func::memoize)
.def("memoize", &Func::memoize, py::arg("eviction_key") = EvictionKey())
.def("compute_inline", &Func::compute_inline)
.def("compute_root", &Func::compute_root)
.def("store_root", &Func::store_root)
Expand Down Expand Up @@ -404,12 +404,12 @@ void define_func(py::module &m) {
},
py::arg("dst"), py::arg("target") = Target())

.def("in_", (Func(Func::*)(const Func &))&Func::in, py::arg("f"))
.def("in_", (Func(Func::*)(const std::vector<Func> &fs))&Func::in, py::arg("fs"))
.def("in_", (Func(Func::*)())&Func::in)
.def("in_", static_cast<Func (Func::*)(const Func &)>(&Func::in), py::arg("f"))
.def("in_", static_cast<Func (Func::*)(const std::vector<Func> &fs)>(&Func::in), py::arg("fs"))
.def("in_", static_cast<Func (Func::*)()>(&Func::in))

.def("clone_in", (Func(Func::*)(const Func &))&Func::clone_in, py::arg("f"))
.def("clone_in", (Func(Func::*)(const std::vector<Func> &fs))&Func::clone_in, py::arg("fs"))
.def("clone_in", static_cast<Func (Func::*)(const Func &)>(&Func::clone_in), py::arg("f"))
.def("clone_in", static_cast<Func (Func::*)(const std::vector<Func> &fs)>(&Func::clone_in), py::arg("fs"))

.def("copy_to_device", &Func::copy_to_device, py::arg("device_api") = DeviceAPI::Default_GPU)
.def("copy_to_host", &Func::copy_to_host)
Expand Down
1 change: 1 addition & 0 deletions python_bindings/test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ set(tests
extern.py
float_precision_test.py
iroperator.py
memoize.py
multi_method_module_test.py
multipass_constraints.py
pystub.py
Expand Down
24 changes: 24 additions & 0 deletions python_bindings/test/correctness/memoize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from halide import Func, Var


def test_memoize():
x = Var("x")

f = Func("f")
f[x] = 0.0
f[x] += 1
f.compute_root().memoize()

output = Func("output")
output[x] = f[x]

result = output.realize([3])
assert list(result) == [1., 1., 1.]


def main():
test_memoize()


if __name__ == "__main__":
main()
5 changes: 5 additions & 0 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -2259,6 +2259,11 @@ class Func {
* to remove memoized entries using this eviction key from the
* cache. Memoized computations that do not provide an eviction
* key will never be evicted by this mechanism.
*
* It is invalid to memoize the output of a Pipeline; attempting
* to do so will issue an error. To cache an entire pipeline,
* either implement a caching mechanism outside of Halide or
* explicitly copy out of the cache with another output Func.
*/
Func &memoize(const EvictionKey &eviction_key = EvictionKey());

Expand Down
5 changes: 5 additions & 0 deletions src/Memoization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@ namespace Internal {
namespace {

class FindParameterDependencies : public IRGraphVisitor {
std::set<Function, Function::Compare> visited_functions;

public:
FindParameterDependencies() = default;
~FindParameterDependencies() override = default;

void visit_function(const Function &function) {
if (const auto [_, inserted] = visited_functions.insert(function); !inserted) {
return;
}
function.accept(this);

if (function.has_extern_definition()) {
Expand Down
6 changes: 5 additions & 1 deletion src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,11 @@ Module Pipeline::compile_to_module(const vector<Argument> &args,

for (const Function &f : contents->outputs) {
user_assert(f.has_pure_definition() || f.has_extern_definition())
<< "Can't compile Pipeline with undefined output Func: " << f.name() << ".\n";
<< "Can't compile Pipeline with undefined output Func: " << f.name() << ".";
user_assert(!f.schedule().memoized())
<< "Can't compile Pipeline with memoized output Func: " << f.name() << ". "
<< "Memoization is valid only on intermediate Funcs because it takes "
<< "control of buffer allocation.";
}

string new_fn_name(fn_name);
Expand Down
1 change: 1 addition & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ tests(GROUPS error
lerp_mismatch.cpp
lerp_signed_weight.cpp
memoize_different_compute_store.cpp
memoize_output_invalid.cpp
memoize_redefine_eviction_key.cpp
metal_threads_too_large.cpp
metal_vector_too_large.cpp
Expand Down
14 changes: 14 additions & 0 deletions test/error/memoize_output_invalid.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#include <Halide.h>
using namespace Halide;

int main(int argc, char **argv) {
Var x{"x"};
Func f{"f"};
f(x) = 0.0f;
f(x) += 1;
f.memoize();

f.realize({3});

printf("Success!\n");
}
Loading