Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions src/relay/transforms/to_mixed_precision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include <utility>

#include "../../support/scalars.h"
#include "pattern_utils.h"

namespace tvm {
Expand Down Expand Up @@ -110,6 +111,39 @@ class MixedPrecisionPass : public MixedModeMutator {
std::vector<DataType> original_dtype_;
bool keep_orig_output_dtype_;

/*! \brief If some of the constant attributes are out of mixed_precision_type_ bounds, then
* computation cannot be performed in mixed precision. */
bool IsMixedPrecisionApplicableToAttrs(const Attrs& attrs) const {
if (attrs.get() != nullptr) {
double min_bound;
double max_bound;
if (mixed_precision_type_.is_float16()) {
min_bound = -support::kMaxFloat16;
max_bound = support::kMaxFloat16;
} else if (mixed_precision_type_.is_bfloat16()) {
min_bound = -support::kMaxBFloat16;
max_bound = support::kMaxBFloat16;
} else if (mixed_precision_type_.is_float8()) {
double bound = (mixed_precision_type_.code() == DataType::kE4M3Float) ? support::kMaxE4M3
: support::kMaxE5M2;
min_bound = -bound;
max_bound = bound;
} else if (mixed_precision_type_.is_float()) {
min_bound = std::numeric_limits<float>::lowest();
max_bound = std::numeric_limits<float>::max();
} else {
return true;
}

if (auto cur_attrs = attrs.as<ClipAttrs>()) {
if (cur_attrs->a_min < min_bound || cur_attrs->a_max > max_bound) {
return false;
}
}
}
return true;
}

Attrs GetNewAttrs(const CallNode* call, const DataType& accumulation_dtype) const {
/* If the accumulation dtype is in the attributes make a copy and mutate the field. */
Attrs cur_attrs = call->attrs;
Expand Down Expand Up @@ -382,9 +416,12 @@ class MixedPrecisionPass : public MixedModeMutator {
all_args_mixed_type_compatible ? MIXED_PRECISION_ALWAYS : MIXED_PRECISION_NEVER;
}

bool is_mixed_precision_applicable =
static_cast<bool>(final_category == MIXED_PRECISION_ALWAYS &&
IsMixedPrecisionApplicableToAttrs(pre_call_node->attrs));
// Create the new arguments to the call.
DataType wanted_arg_dtypes =
final_category == MIXED_PRECISION_ALWAYS ? mixed_precision_type_ : DataType::Float(32);
is_mixed_precision_applicable ? mixed_precision_type_ : DataType::Float(32);
auto call_args_and_types = CastAllArgs(post_call_node->args, cur_arg_types, wanted_arg_dtypes);
Array<Expr> new_args = call_args_and_types.first;
Array<Type> new_arg_types;
Expand All @@ -397,7 +434,7 @@ class MixedPrecisionPass : public MixedModeMutator {
}

// Finally create the new attributes.
if (final_category == MIXED_PRECISION_ALWAYS) {
if (is_mixed_precision_applicable) {
Attrs new_attrs = GetNewAttrs(pre_call_node, accumulation_dtype);
Expr output = Call(cur_op, new_args, new_attrs, new_arg_types, pre_call_node->span);
if (accumulation_dtype != output_dtype) {
Expand Down
49 changes: 49 additions & 0 deletions tests/python/relay/test_to_mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,5 +537,54 @@ def test_convert_follow_node_with_integer_arguments(target_precision):
assert tvm.ir.structural_equal(expected_mod, output_mod)


def test_clip(target_precision):
data = relay.var("data", shape=[1, 10], dtype="float32")
res = relay.clip(data, a_min=-128000, a_max=128000)

mod = tvm.IRModule.from_expr(res)

mod_params = {
"data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"),
}
output_mod = verify_mixed_precision_output_close(
mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01
)

# Create expected module
if target_precision == "bfloat16":
data = relay.cast(relay.var("data", shape=[1, 10]), target_precision)
res = relay.clip(data, a_min=-128000, a_max=128000)
expected_mod = tvm.IRModule.from_expr(res)
expected_mod = InferType()(expected_mod)
assert tvm.ir.structural_equal(expected_mod, output_mod)


def test_clip_with_pre_op(target_precision):
data = relay.var("data", shape=[1, 10], dtype="float32")
const = relay.const(5, "float32")
res = relay.divide(data, const)
res = relay.clip(res, a_min=-128000, a_max=128000)

mod = tvm.IRModule.from_expr(res)

mod_params = {
"data": np.random.uniform(-1, 1, size=[1, 10]).astype("float32"),
}
output_mod = verify_mixed_precision_output_close(
mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01
)

# Create expected module
data = relay.cast(relay.var("data", shape=[1, 10]), target_precision)
const = relay.cast(relay.const(5, "float32"), target_precision)
res = relay.divide(data, const)
if target_precision == "float16":
res = relay.cast(res, "float32")
res = relay.clip(res, a_min=-128000, a_max=128000)
expected_mod = tvm.IRModule.from_expr(res)
expected_mod = InferType()(expected_mod)
assert tvm.ir.structural_equal(expected_mod, output_mod)


if __name__ == "__main__":
tvm.testing.main()