Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add_requirement() maintenance #7045

Merged
merged 3 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python_bindings/src/halide/_generator_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,10 @@ def using_autoscheduler(self):
def natural_vector_size(self, type: Type) -> int:
return self.target().natural_vector_size(type)

def add_requirement(self, condition: Expr, *args) -> None:
assert self._stage < _Stage.pipeline_built
self._pipeline_requirements.append((condition, [*args]))

@classmethod
def call(cls, *args, **kwargs):
generator = cls()
Expand Down Expand Up @@ -475,6 +479,7 @@ def __init__(self, generator_params: dict = {}):
self._requirements = {}
self._replacements = {}
self._in_configure = 0
self._pipeline_requirements = []

self._advance_to_gp_created()
if generator_params:
Expand Down Expand Up @@ -699,6 +704,8 @@ def _build_pipeline(self) -> Pipeline:
funcs.append(f)

self._pipeline = Pipeline(funcs)
for condition, error_args in self._pipeline_requirements:
self._pipeline.add_requirement(condition, *error_args)
self._stage = _Stage.pipeline_built
return self._pipeline

Expand Down
16 changes: 16 additions & 0 deletions python_bindings/src/halide/halide_/PyHalide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,21 @@ Expr double_to_expr_check(double v) {
return Expr(f);
}

std::vector<Expr> collect_print_args(const py::args &args) {
std::vector<Expr> v;
v.reserve(args.size());
for (size_t i = 0; i < args.size(); ++i) {
// No way to see if a cast will work: just have to try
// and fail. Normally we don't want string to be convertible
// to Expr, but in this unusual case we do.
try {
v.emplace_back(args[i].cast<std::string>());
} catch (...) {
v.push_back(args[i].cast<Expr>());
}
}
return v;
}

} // namespace PythonBindings
} // namespace Halide
1 change: 1 addition & 0 deletions python_bindings/src/halide/halide_/PyHalide.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ std::vector<T> args_to_vector(const py::args &args, size_t start_offset = 0, siz
return v;
}

std::vector<Expr> collect_print_args(const py::args &args);
Expr double_to_expr_check(double v);

} // namespace PythonBindings
Expand Down
28 changes: 2 additions & 26 deletions python_bindings/src/halide/halide_/PyIROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,6 @@
namespace Halide {
namespace PythonBindings {

namespace {

// TODO: clever template usage could generalize this to list-of-types-to-try.
std::vector<Expr> args_to_vector_for_print(const py::args &args, size_t start_offset = 0) {
if (args.size() < start_offset) {
throw py::value_error("Not enough arguments");
}
std::vector<Expr> v;
v.reserve(args.size() - (start_offset));
for (size_t i = start_offset; i < args.size(); ++i) {
// No way to see if a cast will work: just have to try
// and fail. Normally we don't want string to be convertible
// to Expr, but in this unusual case we do.
try {
v.emplace_back(args[i].cast<std::string>());
} catch (...) {
v.push_back(args[i].cast<Expr>());
}
}
return v;
}

} // namespace

void define_operators(py::module &m) {
m.def("max", [](const py::args &args) -> Expr {
if (args.size() < 2) {
Expand Down Expand Up @@ -149,11 +125,11 @@ void define_operators(py::module &m) {
m.def("reinterpret", (Expr(*)(Type, Expr)) & reinterpret);
m.def("cast", (Expr(*)(Type, Expr)) & cast);
m.def("print", [](const py::args &args) -> Expr {
return print(args_to_vector_for_print(args));
return print(collect_print_args(args));
});
m.def(
"print_when", [](const Expr &condition, const py::args &args) -> Expr {
return print_when(condition, args_to_vector_for_print(args));
return print_when(condition, collect_print_args(args));
},
py::arg("condition"));
m.def(
Expand Down
7 changes: 7 additions & 0 deletions python_bindings/src/halide/halide_/PyPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ void define_pipeline(py::module &m) {
.def("defined", &Pipeline::defined)
.def("invalidate_cache", &Pipeline::invalidate_cache)

.def(
"add_requirement", [](Pipeline &p, const Expr &condition, const py::args &error_args) -> void {
auto v = collect_print_args(error_args);
p.add_requirement(condition, v);
},
py::arg("condition"))

.def("__repr__", [](const Pipeline &p) -> std::string {
std::ostringstream o;
o << "<halide.Pipeline [";
Expand Down
44 changes: 43 additions & 1 deletion python_bindings/test/correctness/addconstant_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test(addconstant_impl_func, offset):
scalar_u64 = 5724968371
scalar_i8 = -7
scalar_i16 = -30712
scalar_i32 = -98901
scalar_i32 = 98901
scalar_i64 = -8163465847
scalar_float = 3.14159
scalar_double = 1.61803
Expand Down Expand Up @@ -93,6 +93,48 @@ def test(addconstant_impl_func, offset):
for z in range(input_3d.shape[2]):
assert output_3d[x, y, z] == input_3d[x, y, z] + scalar_i8 + offset

try:
# Expected requirement failure #1
scalar_i32 = 0
addconstant_impl_func(
scalar_u1,
scalar_u8, scalar_u16, scalar_u32, scalar_u64,
scalar_i8, scalar_i16, scalar_i32, scalar_i64,
scalar_float, scalar_double,
input_u8, input_u16, input_u32, input_u64,
input_i8, input_i16, input_i32, input_i64,
input_float, input_double, input_2d, input_3d,
output_u8, output_u16, output_u32, output_u64,
output_i8, output_i16, output_i32, output_i64,
output_float, output_double, output_2d, output_3d,
)
except RuntimeError as e:
assert str(e) == "Halide Runtime Error: -27", e
else:
assert False, 'Did not see expected exception!'

try:
# Expected requirement failure #2 -- note that for AOT-compiled
# code in Python, the error message is stricly numeric (the text
# of the error isn't currently propagated int he exception).
scalar_i32 = -1
addconstant_impl_func(
scalar_u1,
scalar_u8, scalar_u16, scalar_u32, scalar_u64,
scalar_i8, scalar_i16, scalar_i32, scalar_i64,
scalar_float, scalar_double,
input_u8, input_u16, input_u32, input_u64,
input_i8, input_i16, input_i32, input_i64,
input_float, input_double, input_2d, input_3d,
output_u8, output_u16, output_u32, output_u64,
output_i8, output_i16, output_i32, output_i64,
output_float, output_double, output_2d, output_3d,
)
except RuntimeError as e:
assert str(e) == "Halide Runtime Error: -27", e
else:
assert False, 'Did not see expected exception!'


if __name__ == "__main__":
for t, o in TESTS_AND_OFFSETS:
Expand Down
31 changes: 31 additions & 0 deletions python_bindings/test/correctness/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,36 @@ def test_typed_funcs():
assert False, 'Did not see expected exception!'


def test_requirements():
delta = hl.Param(hl.Int(32), 'delta')
x = hl.Var('x')
f = hl.Func('f_requirements')
f[x] = x + delta

# Add a requirement
p = hl.Pipeline([f])
p.add_requirement(delta != 0) # error_args omitted
p.add_requirement(delta > 0, "negative values are bad", delta)

delta.set(1)
p.realize([10])

try:
delta.set(0)
p.realize([10])
except hl.HalideError as e:
assert 'Requirement Failed: (false)' in str(e)
else:
assert False, 'Did not see expected exception!'

try:
delta.set(-1)
p.realize([10])
except hl.HalideError as e:
assert 'Requirement Failed: (false) negative values are bad -1' in str(e)
else:
assert False, 'Did not see expected exception!'

if __name__ == "__main__":
test_compiletime_error()
test_runtime_error()
Expand All @@ -402,3 +432,4 @@ def test_typed_funcs():
test_basics5()
test_scalar_funcs()
test_bool_conversion()
test_requirements()
3 changes: 3 additions & 0 deletions python_bindings/test/generators/addconstantcpp_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class AddConstantGenerator : public Halide::Generator<AddConstantGenerator> {
Var x, y, z;

void generate() {
add_requirement(scalar_int32 != 0); // error_args omitted for this case
add_requirement(scalar_int32 > 0, "negative values are bad", scalar_int32);

output_uint8(x) = input_uint8(x) + scalar_uint8;
output_uint16(x) = input_uint16(x) + scalar_uint16;
output_uint32(x) = input_uint32(x) + scalar_uint32;
Expand Down
3 changes: 3 additions & 0 deletions python_bindings/test/generators/addconstantpy_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class AddConstantGenerator:

def generate(self):
g = self
g.add_requirement(g.scalar_int32 != 0) # error_args omitted for this case
g.add_requirement(g.scalar_int32 > 0, "negative values are bad", g.scalar_int32)

g.output_uint8[x] = g.input_uint8[x] + g.scalar_uint8
g.output_uint16[x] = g.input_uint16[x] + g.scalar_uint16
g.output_uint32[x] = g.input_uint32[x] + g.scalar_uint32
Expand Down
8 changes: 8 additions & 0 deletions src/Generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1554,6 +1554,11 @@ void GeneratorBase::pre_schedule() {
void GeneratorBase::post_schedule() {
}

void GeneratorBase::add_requirement(const Expr &condition, const std::vector<Expr> &error_args) {
internal_assert(!pipeline.defined());
requirements.push_back({condition, error_args});
}

Pipeline GeneratorBase::get_pipeline() {
check_min_phase(GenerateCalled);
if (!pipeline.defined()) {
Expand Down Expand Up @@ -1584,6 +1589,9 @@ Pipeline GeneratorBase::get_pipeline() {
}
}
pipeline = Pipeline(funcs);
for (const auto &r : requirements) {
pipeline.add_requirement(r.condition, r.error_args);
}
}
return pipeline;
}
Expand Down
17 changes: 14 additions & 3 deletions src/Generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3444,9 +3444,14 @@ class GeneratorBase : public NamesInterface, public AbstractGenerator {
return p;
}

template<typename... Args>
HALIDE_NO_USER_CODE_INLINE void add_requirement(Expr condition, Args &&...args) {
get_pipeline().add_requirement(condition, std::forward<Args>(args)...);
void add_requirement(const Expr &condition, const std::vector<Expr> &error_args);

template<typename... Args,
typename = typename std::enable_if<Internal::all_are_printable_args<Args...>::value>::type>
inline HALIDE_NO_USER_CODE_INLINE void add_requirement(const Expr &condition, Args &&...error_args) {
std::vector<Expr> collected_args;
Internal::collect_print_args(collected_args, std::forward<Args>(error_args)...);
add_requirement(condition, collected_args);
}

void trace_pipeline() {
Expand Down Expand Up @@ -3636,6 +3641,12 @@ class GeneratorBase : public NamesInterface, public AbstractGenerator {
std::string generator_registered_name, generator_stub_name;
Pipeline pipeline;

struct Requirement {
Expr condition;
std::vector<Expr> error_args;
};
std::vector<Requirement> requirements;

// Return our GeneratorParamInfo.
GeneratorParamInfo &param_info();

Expand Down
9 changes: 9 additions & 0 deletions src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,15 @@ Stmt remove_promises(const Stmt &s);
* the tagged expression. If not, returns the expression. */
Expr unwrap_tags(const Expr &e);

template<typename T>
struct is_printable_arg {
static constexpr bool value = std::is_convertible<T, const char *>::value ||
std::is_convertible<T, Halide::Expr>::value;
};

template<typename... Args>
struct all_are_printable_args : meta_and<is_printable_arg<Args>...> {};

// Secondary args to print can be Exprs or const char *
inline HALIDE_NO_USER_CODE_INLINE void collect_print_args(std::vector<Expr> &args) {
}
Expand Down
2 changes: 1 addition & 1 deletion src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ Realization Pipeline::realize(JITUserContext *context,
return r;
}

void Pipeline::add_requirement(const Expr &condition, std::vector<Expr> &error_args) {
void Pipeline::add_requirement(const Expr &condition, const std::vector<Expr> &error_args) {
user_assert(defined()) << "Pipeline is undefined\n";

// It is an error for a requirement to reference a Func or a Var
Expand Down
17 changes: 10 additions & 7 deletions src/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -547,17 +547,20 @@ class Pipeline {
* with the remaining arguments, and return
* halide_error_code_requirement_failed. Requirements are checked
* in the order added. */
void add_requirement(const Expr &condition, std::vector<Expr> &error);

/** Generate begin_pipeline and end_pipeline tracing calls for this pipeline. */
void trace_pipeline();
// @{
void add_requirement(const Expr &condition, const std::vector<Expr> &error_args);

template<typename... Args>
inline HALIDE_NO_USER_CODE_INLINE void add_requirement(const Expr &condition, Args &&...args) {
template<typename... Args,
typename = typename std::enable_if<Internal::all_are_printable_args<Args...>::value>::type>
inline HALIDE_NO_USER_CODE_INLINE void add_requirement(const Expr &condition, Args &&...error_args) {
std::vector<Expr> collected_args;
Internal::collect_print_args(collected_args, std::forward<Args>(args)...);
Internal::collect_print_args(collected_args, std::forward<Args>(error_args)...);
add_requirement(condition, collected_args);
}
// @}

/** Generate begin_pipeline and end_pipeline tracing calls for this pipeline. */
void trace_pipeline();

private:
std::string generate_function_name() const;
Expand Down