Skip to content

Commit

Permalink
[FIX] Allow tokenizer to parse numbers greater than INT_MAX. (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tkonolige authored Jun 7, 2021
1 parent 364bc1b commit 2c67d71
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 36 deletions.
25 changes: 17 additions & 8 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -524,13 +524,22 @@ class Parser {
NDArray NumberToNDArray(const Token& token) {
if (token->token_type == TokenType::kInteger) {
DLDevice dev = {DLDeviceType::kDLCPU, 0};
auto dtype = String2DLDataType("int32");
auto data = NDArray::Empty({}, dtype, dev);
auto array = reinterpret_cast<int32_t*>(data->data);
// revisit this, literal node issue.
int64_t value = Downcast<tvm::Integer>(token->data);
array[0] = (int32_t)value;
return data;
int64_t i = Downcast<tvm::Integer>(token->data);
if (i > std::numeric_limits<int32_t>::max()) {
auto dtype = String2DLDataType("int64");
auto data = NDArray::Empty({}, dtype, dev);
auto array = reinterpret_cast<int64_t*>(data->data);
// revisit this, literal node issue.
array[0] = i;
return data;
} else {
auto dtype = String2DLDataType("int32");
auto data = NDArray::Empty({}, dtype, dev);
auto array = reinterpret_cast<int32_t*>(data->data);
// revisit this, literal node issue.
array[0] = i;
return data;
}
} else if (token->token_type == TokenType::kFloat) {
DLDevice dev = {DLDeviceType::kDLCPU, 0};
auto float_imm = Downcast<tvm::FloatImm>(token->data);
Expand Down Expand Up @@ -1516,7 +1525,7 @@ class Parser {
}
case TokenType::kBoolean: {
Consume(TokenType::kBoolean);
int value = Downcast<tvm::Integer>(next->data);
int64_t value = Downcast<tvm::Integer>(next->data);
auto boolean = BooleanToNDarray(value);
Expr e = Constant(boolean, next->span);
ICHECK(e->span.defined()) << "constant spans must be defined";
Expand Down
66 changes: 38 additions & 28 deletions src/parser/tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/runtime/object.h>

#include <fstream>
#include <limits>
#include <string>
#include <unordered_map>
#include <utility>
Expand Down Expand Up @@ -171,44 +172,53 @@ struct Tokenizer {
Token ParseNumber(bool is_pos, bool is_float, std::string number) {
ICHECK(number.size() > 0) << "an empty string is an invalid number";

try {
if (is_float) {
throw std::invalid_argument("is_float");
}
if (!is_float) {
auto token = NewToken(TokenType::kInteger);
size_t index = 0;
int value = std::stoi(number, &index);
if (number.size() > index) {
throw std::invalid_argument("floating point");
int64_t value = 0;
try {
value = std::stoll(number, &index);
} catch (const std::invalid_argument& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`");
} catch (const std::out_of_range& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span) << "invalid number `" << number << "`");
}
value = is_pos ? value : -value;
token->data = tvm::Integer(value);
return token;
} catch (const std::invalid_argument& ia) {
auto token = NewToken(TokenType::kFloat);
if (number.size() <= index) {
value = is_pos ? value : -value;
if (value > std::numeric_limits<int32_t>::max()) {
token->data = tvm::IntImm(DataType::Int(64), value);
} else {
token->data = tvm::IntImm(DataType::Int(32), value);
}
return token;
}
}
auto token = NewToken(TokenType::kFloat);

auto suffix_pos = number.rfind("f");
auto suffix_pos = number.rfind("f");

auto literal_text = number.substr(0, suffix_pos);
auto literal_text = number.substr(0, suffix_pos);

auto suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos);
auto suffix = number.substr(suffix_pos + 1, number.size() - suffix_pos);

int width = 32;
int width = 32;

if (suffix.size()) {
try {
width = std::stoi(suffix);
} catch (const std::invalid_argument& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid numeric suffix `" << suffix << "`");
}
if (suffix.size()) {
try {
width = std::stoi(suffix);
} catch (const std::invalid_argument& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid numeric suffix `" << suffix << "`");
} catch (const std::out_of_range& err) {
this->diag_ctx.Emit(Diagnostic::Error(token->span)
<< "invalid numeric suffix `" << suffix << "`");
}

double value = stod(literal_text);
value = is_pos ? value : -value;
token->data = tvm::FloatImm(DataType::Float(width), value);
return token;
}

double value = stod(literal_text);
value = is_pos ? value : -value;
token->data = tvm::FloatImm(DataType::Float(width), value);
return token;
}

Token ParseNumber(bool is_pos) {
Expand Down
1 change: 1 addition & 0 deletions tests/python/relay/test_ir_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def test_int_literal():
assert get_scalar(parse_text("0")) == 0
assert get_scalar(parse_text("-100")) == -100
assert get_scalar(parse_text("-05")) == -5
assert get_scalar(parse_text("9223372036854775807")) == 9223372036854775807


def test_float_literal():
Expand Down

0 comments on commit 2c67d71

Please sign in to comment.