Skip to content

Commit

Permalink
[Relay] Support i16, f16 scalars in Relay text (apache#11224)
Browse files Browse the repository at this point in the history
While testing fp16 models for Collage discovered the Relay text
format did not support f16. While adding that cleaned up scalar handling
in general. However I left two inlined tests for 'is simple const'
in place (fuse_ops.cc and memory_alloc.cc) since it's not clear whether
they should remain specific to just {i,f}{32,64} or whether they can
be replaced with the support::IsSimpleScalar central predicate.
  • Loading branch information
mbs-octoml authored May 19, 2022
1 parent ffc0443 commit 534c38b
Show file tree
Hide file tree
Showing 10 changed files with 505 additions and 148 deletions.
45 changes: 6 additions & 39 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@

#include <fstream>

#include "../support/scalars.h"
#include "./meta_ref.h"
#include "./op_table.h"
#include "./span_check.h"
#include "./tokenizer.h"
#include "tvm/runtime/builtin_fp16.h"

namespace tvm {
namespace parser {
Expand Down Expand Up @@ -534,49 +536,15 @@ class Parser {
/*! \brief Convert a numeric token to an NDArray for embedding into the Relay program. */
NDArray NumberToNDArray(const Token& token) {
if (token->token_type == TokenType::kInteger) {
DLDevice dev = {DLDeviceType::kDLCPU, 0};
int64_t i = Downcast<tvm::Integer>(token->data);
if (i > std::numeric_limits<int32_t>::max()) {
auto dtype = String2DLDataType("int64");
auto data = NDArray::Empty({}, dtype, dev);
auto array = reinterpret_cast<int64_t*>(data->data);
// revisit this, literal node issue.
array[0] = i;
return data;
} else {
auto dtype = String2DLDataType("int32");
auto data = NDArray::Empty({}, dtype, dev);
auto array = reinterpret_cast<int32_t*>(data->data);
// revisit this, literal node issue.
array[0] = i;
return data;
}
return support::IntImmToNDArray(Downcast<tvm::IntImm>(token->data));
} else if (token->token_type == TokenType::kFloat) {
DLDevice dev = {DLDeviceType::kDLCPU, 0};
auto float_imm = Downcast<tvm::FloatImm>(token->data);
auto data = NDArray::Empty({}, float_imm->dtype, dev);
auto array = reinterpret_cast<float*>(data->data);
// revisit this, literal node issue.
// TODO(@jroesch): bounds checking
float value = float_imm->value;
array[0] = value;
return data;
return support::FloatImmToNDArray(Downcast<tvm::FloatImm>(token->data));
} else {
LOG(FATAL) << "internal error: should only call this function on numeric tokens";
return NDArray();
return {};
}
}

/*! \brief Convert a boolean value to an NDArray for embedding into the Relay program. */
NDArray BooleanToNDarray(bool value) {
DLDevice dev = {DLDeviceType::kDLCPU, 0};
auto dtype = String2DLDataType("bool");
auto data = NDArray::Empty({}, dtype, dev);
auto array = reinterpret_cast<bool*>(data->data);
array[0] = value;
return data;
}

[[noreturn]] void ParseError(const Token& token, const std::string& msg) {
throw std::runtime_error(msg);
}
Expand Down Expand Up @@ -1573,8 +1541,7 @@ class Parser {
case TokenType::kBoolean: {
Consume(TokenType::kBoolean);
int64_t value = Downcast<tvm::Integer>(next->data);
auto boolean = BooleanToNDarray(value);
Expr e = Constant(boolean, next->span);
Expr e = Constant(support::BoolToNDArray(value), next->span);
ICHECK(e->span.defined()) << "constant spans must be defined";
return e;
}
Expand Down
104 changes: 69 additions & 35 deletions src/parser/tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <utility>
#include <vector>

#include "../support/scalars.h"
#include "./meta_ref.h"
#include "./token.h"

Expand Down Expand Up @@ -174,35 +175,16 @@ struct Tokenizer {
Token ParseNumber(bool is_pos, bool is_float, std::string number) {
ICHECK(number.size() > 0) << "an empty string is an invalid number";

if (!is_float) {
auto token = NewToken(TokenType::kInteger);
size_t index = 0;
int64_t value = 0;
try {
value = std::stoll(number, &index);
} catch (const std::invalid_argument& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`");
} catch (const std::out_of_range& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`");
}
if (number.size() <= index) {
value = is_pos ? value : -value;
if (value > std::numeric_limits<int32_t>::max()) {
token->data = tvm::IntImm(DataType::Int(64), value);
} else {
token->data = tvm::IntImm(DataType::Int(32), value);
}
return token;
}
Token token = NewToken(is_float ? TokenType::kFloat : TokenType::kInteger);
size_t suffix_pos = number.rfind(is_float ? 'f' : 'i');
if (suffix_pos == std::string::npos) {
suffix_pos = number.size();
}
std::string literal_text = number.substr(0, suffix_pos);
std::string suffix;
if (suffix_pos < number.size()) {
suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos);
}
auto token = NewToken(TokenType::kFloat);

auto suffix_pos = number.rfind("f");

auto literal_text = number.substr(0, suffix_pos);

auto suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos);

int width = 32;

if (suffix.size()) {
Expand All @@ -217,9 +199,62 @@ struct Tokenizer {
}
}

double value = stod(literal_text);
value = is_pos ? value : -value;
token->data = tvm::FloatImm(DataType::Float(width), value);
if (is_float) {
double value = 0.0;
size_t index = 0;
try {
value = stod(literal_text, &index);
} catch (const std::invalid_argument& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid floating point number `" << literal_text << "`");
} catch (const std::out_of_range& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid floating point number `" << literal_text << "`");
}
if (index < literal_text.size()) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid floating point number `" << literal_text << "`");
}
value = is_pos ? value : -value;
token->data = support::ValueToFloatImm(value, width);
if (!token->data.defined()) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "floating point number `" << literal_text
<< "` unrepresentable in width " << width);
token->data = support::ValueToFloatImm(0.0, width);
}
} else {
int64_t value = 0;
size_t index = 0;
try {
value = std::stoll(literal_text, &index);
} catch (const std::invalid_argument& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid integer number `" << literal_text << "`");
} catch (const std::out_of_range& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid integer number `" << literal_text << "`");
}
if (index < literal_text.size()) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid integer number `" << literal_text << "`");
}
value = is_pos ? value : -value;
token->data = support::ValueToIntImm(value, width);
if (!token->data.defined() && suffix.empty()) {
// Without any i suffix the legacy behavior was to default to int64 if out of range
// for int32.
width = 64;
token->data = support::ValueToIntImm(value, width);
}
if (!token->data.defined()) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "integer number `" << literal_text << "` unrepresentable in width "
<< width);
token->data = support::ValueToIntImm(0, width);
}
}

return token;
}

Expand All @@ -230,14 +265,13 @@ struct Tokenizer {
}

bool is_float = false;

// Remove trailing floating point prefix.
if (More() && Peek() == 'f') {
if (More() && (Peek() == 'f' || Peek() == 'i')) {
is_float = Peek() == 'f';
// Capture trailing width suffix
ss << Next();
while (More() && IsNumeric(Peek())) {
ss << Next();
}
is_float = true;
}
return ParseNumber(is_pos, is_float, ss.str());
}
Expand Down
7 changes: 1 addition & 6 deletions src/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,7 @@ TVM_REGISTER_OBJECT_TYPE(DocTextNode);

class DocText : public DocAtom {
public:
explicit DocText(std::string str) {
if (str.find_first_of("\t\n") != str.npos) {
LOG(WARNING) << "text node: '" << str << "' should not have tab or newline.";
}
data_ = runtime::make_object<DocTextNode>(str);
}
explicit DocText(std::string str) { data_ = runtime::make_object<DocTextNode>(str); }

TVM_DEFINE_OBJECT_REF_METHODS(DocText, DocAtom, DocTextNode);
};
Expand Down
80 changes: 32 additions & 48 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@
#include "../ir/attr_functor.h"
#include "../parser/meta_ref.h"
#include "../relay/analysis/dependency_graph.h"
#include "../support/scalars.h"
#include "doc.h"
#include "meta_data.h"
#include "text_printer.h"
#include "tvm/runtime/builtin_fp16.h"

namespace tvm {
namespace relay {
Expand All @@ -61,8 +63,17 @@ Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) {
}
// default annotations
if (annotate_ == nullptr) {
if ((expr.as<ConstantNode>() || expr.as<CallNode>()) && expr->checked_type_.defined()) {
doc << " /* ty=" << Print(expr->checked_type()) << " */";
if ((expr.as<ConstantNode>() || expr.as<CallNode>() || expr.as<VarNode>() ||
expr.as<FunctionNode>() || expr.as<TupleNode>() || expr.as<TupleGetItemNode>()) &&
(expr->checked_type_.defined() || expr->span.defined())) {
doc << " /*";
if (expr->checked_type_.defined()) {
doc << " ty=" << Print(expr->checked_type());
}
if (expr->span.defined()) {
doc << " span=" << PrintSpan(expr->span);
}
doc << " */";
}
} else {
std::string annotated_expr = annotate_(expr);
Expand Down Expand Up @@ -219,7 +230,7 @@ Doc RelayTextPrinter::AllocVar(const Var& var) {
name = "v" + name;
}
Doc val = GetUniqueName("%" + name);
memo_[var] = val;
memo_[var] = val; // Referential occurrences will not include the following.
if (!var->virtual_device()->IsFullyUnconstrained()) {
val << " {" << kVirtualDevice << "=" << PrintAttributeValue(var->virtual_device()) << "}";
}
Expand Down Expand Up @@ -335,51 +346,17 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo
// first time.
Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return AllocVar(GetRef<Var>(op)); }

/*!
* \brief special method to print out const scalar
* \param dtype The data type
* \param value The value to be printed.
*/
template <typename T>
Doc RelayTextPrinter::ScalarLiteral(DataType dtype, const T& value) {
std::ostringstream os;
if (dtype == DataType::Int(32)) {
os << value;
} else if (dtype == DataType::Float(32)) {
os << value << 'f';
} else if (dtype == DataType::Float(64)) {
os << value << "f64";
} else if (dtype == DataType::Bool()) {
return Doc::PyBoolLiteral(value != 0);
} else {
os << value;
}
return Doc::Text(os.str());
}

Doc RelayTextPrinter::VisitExpr_(const ConstantNode* op) {
// Print out simple scalars directly.
if (op->is_scalar()) {
std::ostringstream os;
DataType dtype = DataType(op->data->dtype);
ICHECK_EQ(op->data->device.device_type, kDLCPU);
if (dtype == DataType::Int(32)) {
return ScalarLiteral(dtype, static_cast<const int32_t*>(op->data->data)[0]);
} else if (dtype == DataType::Int(64)) {
return ScalarLiteral(dtype, static_cast<const int64_t*>(op->data->data)[0]);
} else if (dtype == DataType::Float(32)) {
return ScalarLiteral(dtype, static_cast<const float*>(op->data->data)[0]);
} else if (dtype == DataType::Float(64)) {
return ScalarLiteral(dtype, static_cast<const double*>(op->data->data)[0]);
} else if (dtype == DataType::Bool()) {
return ScalarLiteral(dtype, static_cast<const uint8_t*>(op->data->data)[0]);
}
if (support::IsSimpleScalar(op)) {
return Doc::Text(support::NDArrayScalarToString(op->data));
}
// default fall-back, record it as meta node.
// Fallbock: record it as a meta node.
Doc doc;
// Don't append optional_info. Because the entry function is Print,
// and it will append the optional_info afterwards.
return doc << PrintExpr(GetRef<Expr>(op), true, false, false);
return doc << PrintExpr(GetRef<Expr>(op), /*meta=*/true, /*try_inline=*/false,
/*optional_info=*/false);
}

Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) {
Expand Down Expand Up @@ -540,9 +517,6 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) {
return doc;
} else {
doc << "(" << Doc::Concat(args) << ")";
if (op->span.defined()) {
doc << " /* " << PrintSpan(op->span) << " */";
}
return doc;
}
}
Expand Down Expand Up @@ -799,11 +773,21 @@ Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) {
}

Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) {
return ScalarLiteral(op->dtype, op->value);
if (support::IsSimpleScalarDtype(op->dtype)) {
return Doc::Text(support::IntImmToString(GetRef<IntImm>(op)));
} else {
// Fallback: Print int64_t without width suffix.
return Doc::Text(std::to_string(op->value));
}
}

Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) {
return ScalarLiteral(op->dtype, op->value);
if (support::IsSimpleScalarDtype(op->dtype)) {
return Doc::Text(support::FloatImmToString(GetRef<FloatImm>(op)));
} else {
// Fallbock: Print double without width suffix.
return Doc::Text(std::to_string(op->value));
}
}

Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) {
Expand Down Expand Up @@ -977,7 +961,7 @@ Doc RelayTextPrinter::PrintSpan(const Span& span) {
Doc doc;
const auto* span_node = span.as<SpanNode>();
ICHECK(span_node);
doc << span_node->source_name->name;
doc << span_node->source_name->name << ":" << span_node->line << ":" << span_node->column;
return doc;
}

Expand Down
7 changes: 0 additions & 7 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,6 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
// Should only be triggered when op is a free variable being visited for the
// first time.
Doc VisitExpr_(const VarNode* op) final;
/*!
* \brief special method to print out const scalar
* \param dtype The data type
* \param value The value to be printed.
*/
template <typename T>
static Doc ScalarLiteral(DataType dtype, const T& value);
Doc VisitExpr_(const ConstantNode* op) final;
Doc VisitExpr_(const TupleNode* op) final;
Doc VisitExpr_(const TupleGetItemNode* op) final;
Expand Down
Loading

0 comments on commit 534c38b

Please sign in to comment.