Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More typed-Func work #6735

Merged
merged 4 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions python_bindings/correctness/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,50 @@ def test_typed_funcs():
x = hl.Var('x')
y = hl.Var('y')

f = hl.Func('f')
assert not f.defined()
try:
assert f.output_type() == Int(32)
except RuntimeError as e:
assert 'it is undefined' in str(e)
else:
assert False, 'Did not see expected exception!'

try:
assert f.outputs() == 0
except RuntimeError as e:
assert 'it is undefined' in str(e)
else:
assert False, 'Did not see expected exception!'

try:
assert f.dimensions() == 0
except RuntimeError as e:
assert 'it is undefined' in str(e)
else:
assert False, 'Did not see expected exception!'


f = hl.Func(hl.Int(32), 2, 'f')
assert not f.defined()
assert f.output_type() == hl.Int(32)
assert f.output_types() == [hl.Int(32)]
assert f.outputs() == 1
assert f.dimensions() == 2

f = hl.Func([hl.Int(32), hl.Float(64)], 3, 'f')
assert not f.defined()
try:
assert f.output_type() == hl.Int(32)
except RuntimeError as e:
assert 'it returns a Tuple' in str(e)
else:
assert False, 'Did not see expected exception!'

assert f.output_types() == [hl.Int(32), hl.Float(64)]
assert f.outputs() == 2
assert f.dimensions() == 3

f = hl.Func(hl.Int(32), 1, 'f')
try:
f[x, y] = hl.i32(0);
Expand Down
69 changes: 41 additions & 28 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,23 +197,32 @@ void Func::define_extern(const std::string &function_name,

/** Get the types of the buffers returned by an extern definition. */
const Type &Func::output_type() const {
user_assert(defined())
<< "Can't access output buffer of undefined Func.\n";
user_assert(func.output_types().size() == 1)
<< "Can't call Func::output_type on Func \"" << name()
<< "\" because it returns a Tuple.\n";
return func.output_types()[0];
const auto &types = defined() ? func.output_types() : func.required_types();
if (types.empty()) {
user_error << "Can't call Func::output_type on Func \"" << name()
<< "\" because it is undefined or has no type requirements.\n";
} else if (types.size() > 1) {
user_error << "Can't call Func::output_type on Func \"" << name()
<< "\" because it returns a Tuple.\n";
}
return types[0];
}

const std::vector<Type> &Func::output_types() const {
user_assert(defined())
<< "Can't access output buffer of undefined Func.\n";
return func.output_types();
const auto &types = defined() ? func.output_types() : func.required_types();
user_assert(!types.empty())
<< "Can't call Func::output_type on Func \"" << name()
<< "\" because it is undefined or has no type requirements.\n";
return types;
}

/** Get the number of outputs this function has. */
int Func::outputs() const {
return func.outputs();
const auto &types = defined() ? func.output_types() : func.required_types();
user_assert(!types.empty())
<< "Can't call Func::outputs on Func \"" << name()
<< "\" because it is undefined or has no type requirements.\n";
return (int)types.size();
}

/** Get the name of the extern function called for an extern
Expand All @@ -223,10 +232,11 @@ const std::string &Func::extern_function_name() const {
}

int Func::dimensions() const {
if (!defined()) {
return 0;
}
return func.dimensions();
const int dims = defined() ? func.dimensions() : func.required_dimensions();
user_assert(dims != AnyDims)
<< "Can't call Func::dimensions on Func \"" << name()
<< "\" because it is undefined or has no dimension requirements.\n";
return dims;
}

FuncRef Func::operator()(vector<Var> args) const {
Expand All @@ -251,17 +261,19 @@ std::pair<int, int> Func::add_implicit_vars(vector<Var> &args) const {
placeholder_pos = (int)(iter - args.begin());
int i = 0;
iter = args.erase(iter);
while ((int)args.size() < dimensions()) {
// It's important to use func.dimensions() here, *not* this->dimensions(),
// since the latter can return the Func's required dimensions rather than its actual dimensions.
while ((int)args.size() < func.dimensions()) {
Internal::debug(2) << "Adding implicit var " << i << " to call to " << name() << "\n";
iter = args.insert(iter, Var::implicit(i++));
iter++;
count++;
}
}

if (defined() && args.size() != (size_t)dimensions()) {
if (defined() && args.size() != (size_t)func.dimensions()) {
user_error << "Func \"" << name() << "\" was called with "
<< args.size() << " arguments, but was defined with " << dimensions() << "\n";
<< args.size() << " arguments, but was defined with " << func.dimensions() << "\n";
}

return {placeholder_pos, count};
Expand All @@ -282,17 +294,19 @@ std::pair<int, int> Func::add_implicit_vars(vector<Expr> &args) const {
placeholder_pos = (int)(iter - args.begin());
int i = 0;
iter = args.erase(iter);
while ((int)args.size() < dimensions()) {
// It's important to use func.dimensions() here, *not* this->dimensions(),
// since the latter can return the Func's required dimensions rather than its actual dimensions.
while ((int)args.size() < func.dimensions()) {
Internal::debug(2) << "Adding implicit var " << i << " to call to " << name() << "\n";
iter = args.insert(iter, Var::implicit(i++));
iter++;
count++;
}
}

if (defined() && args.size() != (size_t)dimensions()) {
if (defined() && args.size() != (size_t)func.dimensions()) {
user_error << "Func \"" << name() << "\" was called with "
<< args.size() << " arguments, but was defined with " << dimensions() << "\n";
<< args.size() << " arguments, but was defined with " << func.dimensions() << "\n";
}

return {placeholder_pos, count};
Expand Down Expand Up @@ -3188,21 +3202,20 @@ void Func::infer_input_bounds(JITUserContext *context,
}

OutputImageParam Func::output_buffer() const {
user_assert(defined())
<< "Can't access output buffer of undefined Func.\n";
user_assert(func.output_buffers().size() == 1)
const auto &ob = func.output_buffers();

user_assert(ob.size() == 1)
<< "Can't call Func::output_buffer on Func \"" << name()
<< "\" because it returns a Tuple.\n";
return OutputImageParam(func.output_buffers()[0], Argument::OutputBuffer, *this);
return OutputImageParam(ob[0], Argument::OutputBuffer, *this);
}

vector<OutputImageParam> Func::output_buffers() const {
user_assert(defined())
<< "Can't access output buffers of undefined Func.\n";
const auto &ob = func.output_buffers();

vector<OutputImageParam> bufs(func.output_buffers().size());
vector<OutputImageParam> bufs(ob.size());
for (size_t i = 0; i < bufs.size(); i++) {
bufs[i] = OutputImageParam(func.output_buffers()[i], Argument::OutputBuffer, *this);
bufs[i] = OutputImageParam(ob[i], Argument::OutputBuffer, *this);
}
return bufs;
}
Expand Down
13 changes: 9 additions & 4 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1203,22 +1203,27 @@ class Func {
DeviceAPI device_api = DeviceAPI::Host);
// @}

/** Get the types of the outputs of this Func. */
/** Get the type(s) of the outputs of this Func.
* If the Func isn't yet defined, but was specified with required types,
* the requirements will be returned. */
// @{
const Type &output_type() const;
const std::vector<Type> &output_types() const;
// @}

/** Get the number of outputs of this Func. Corresponds to the
* size of the Tuple this Func was defined to return. */
* size of the Tuple this Func was defined to return.
* If the Func isn't yet defined, but was specified with required types,
* the number of outputs specified in the requirements will be returned. */
int outputs() const;

/** Get the name of the extern function called for an extern
* definition. */
const std::string &extern_function_name() const;

/** The dimensionality (number of arguments) of this
* function. Zero if the function is not yet defined. */
/** The dimensionality (number of arguments) of this function.
* If the Func isn't yet defined, but was specified with required dimensionality,
* the dimensionality specified in the requirements will be returned. */
int dimensions() const;

/** Construct either the left-hand-side of a definition, or a call
Expand Down
45 changes: 41 additions & 4 deletions src/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -602,13 +602,26 @@ void Function::define(const vector<string> &args, vector<Expr> values) {
// Just a reality check; mismatches here really should have been caught earlier
internal_assert(contents->required_types == contents->output_types);
}
if (contents->required_dims != AnyDims) {
// Just a reality check; mismatches here really should have been caught earlier
internal_assert(contents->required_dims == (int)args.size());
}

for (size_t i = 0; i < values.size(); i++) {
if (contents->output_buffers.empty()) {
create_output_buffers(contents->output_types, (int)args.size());
}
}

void Function::create_output_buffers(const std::vector<Type> &types, int dims) const {
internal_assert(contents->output_buffers.empty());
internal_assert(!types.empty() && dims != AnyDims);

for (size_t i = 0; i < types.size(); i++) {
string buffer_name = name();
if (values.size() > 1) {
if (types.size() > 1) {
buffer_name += '.' + std::to_string((int)i);
}
Parameter output(values[i].type(), true, args.size(), buffer_name);
Parameter output(types[i], true, dims, buffer_name);
contents->output_buffers.push_back(output);
}
}
Expand Down Expand Up @@ -908,13 +921,25 @@ bool Function::is_pure_arg(const std::string &name) const {
}

int Function::dimensions() const {
return args().size();
return (int)args().size();
}

int Function::outputs() const {
return (int)output_types().size();
}

const std::vector<Type> &Function::output_types() const {
return contents->output_types;
}

const std::vector<Type> &Function::required_types() const {
return contents->required_types;
}

int Function::required_dimensions() const {
return contents->required_dims;
}

const std::vector<Expr> &Function::values() const {
static const std::vector<Expr> empty;
if (has_pure_definition()) {
Expand All @@ -933,6 +958,18 @@ const FuncSchedule &Function::schedule() const {
}

const std::vector<Parameter> &Function::output_buffers() const {
if (!contents->output_buffers.empty()) {
return contents->output_buffers;
}

// If types and dims are already specified, we can go ahead and create
// the output buffer(s) even if the Function has no pure definition yet.
if (!contents->required_types.empty() && contents->required_dims != AnyDims) {
create_output_buffers(contents->required_types, contents->required_dims);
return contents->output_buffers;
}

user_error << "Can't access output buffer(s) of undefined Func \"" << name() << "\".\n";
return contents->output_buffers;
}

Expand Down
14 changes: 11 additions & 3 deletions src/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,17 @@ class Function {
int dimensions() const;

/** Get the number of outputs. */
int outputs() const {
return (int)output_types().size();
}
int outputs() const;

/** Get the types of the outputs. */
const std::vector<Type> &output_types() const;

/** Get the type constaints on the outputs (if any). */
const std::vector<Type> &required_types() const;

/** Get the dimensionality constaints on the outputs (if any). */
int required_dimensions() const;

/** Get the right-hand-side of the pure definition. Returns an
* empty vector if there is no pure definition. */
const std::vector<Expr> &values() const;
Expand Down Expand Up @@ -312,6 +316,10 @@ class Function {
/** If the Function has dimension requirements, check that the given argument
* is compatible with them. If not, assert-fail. (If there are no dimension requirements, do nothing.) */
void check_dims(int dims) const;

/** Define the output buffers. If the Function has types specified, this can be called at
* any time. If not, it can only be called for a Function with a pure definition. */
void create_output_buffers(const std::vector<Type> &types, int dims) const;
};

/** Deep copy an entire Function DAG. */
Expand Down
6 changes: 4 additions & 2 deletions src/Generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ Argument to_argument(const Internal::Parameter &param) {

Func make_param_func(const Parameter &p, const std::string &name) {
internal_assert(p.is_buffer());
Func f(name + "_im");
Func f(p.type(), p.dimensions(), name + "_im");
auto b = p.buffer();
if (b.defined()) {
// If the Parameter has an explicit BufferPtr set, bind directly to it
Expand Down Expand Up @@ -2134,8 +2134,10 @@ void GeneratorOutputBase::init_internals() {
exprs_.clear();
funcs_.clear();
if (array_size_defined()) {
const auto t = types_defined() ? types() : std::vector<Type>{};
const int d = dims_defined() ? dims() : -1;
for (size_t i = 0; i < array_size(); ++i) {
funcs_.emplace_back(array_name(i));
funcs_.emplace_back(t, d, array_name(i));
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/ImageParam.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Func ImageParam::create_func() const {
// Discourage future Funcs from having the same name
Internal::unique_name(name());
}
Func f(name() + "_im");
Func f(param.type(), param.dimensions(), name() + "_im");
f(args) = Internal::Call::make(param, args_expr);
return f;
}
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ tests(GROUPS correctness
tuple_update_ops.cpp
tuple_vector_reduce.cpp
two_vector_args.cpp
typed_func.cpp
undef.cpp
uninitialized_read.cpp
unique_func_image.cpp
Expand Down
Loading