Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion cpp/src/gandiva/decimal_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,52 @@ Status DecimalIR::BuildAdd() {
return Status::OK();
}

Status DecimalIR::BuildSubtract() {
// Create fn prototype :
// int128_t
// subtract_decimal128_decimal128(int128_t x_value, int32_t x_precision, int32_t
// x_scale,
// int128_t y_value, int32_t y_precision, int32_t y_scale
// int32_t out_precision, int32_t out_scale)
auto i32 = types()->i32_type();
auto i128 = types()->i128_type();
auto function = BuildFunction("subtract_decimal128_decimal128", i128,
{
{"x_value", i128},
{"x_precision", i32},
{"x_scale", i32},
{"y_value", i128},
{"y_precision", i32},
{"y_scale", i32},
{"out_precision", i32},
{"out_scale", i32},
});

auto entry = llvm::BasicBlock::Create(*context(), "entry", function);
ir_builder()->SetInsertPoint(entry);

// reuse add function after negating y_value. i.e
// add(x_value, x_precision, x_scale, -y_value, y_precision, y_scale,
// out_precision, out_scale)
std::vector<llvm::Value*> args;
int i = 0;
for (auto& in_arg : function->args()) {
if (i == 3) {
auto y_neg_value = ir_builder()->CreateNeg(&in_arg);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering if this could overflow, but the decimal type is limited to 38 decimal digits, which is too small for 2**127.

As a side note, it seems we don't validate inputs to the Decimal128 type and happily accept larger numbers:

>>> ty = pa.decimal128(39, 0)                                                                            
>>> arr = pa.array([(2**127) - 1], type=ty)                                                              
>>> arr                                                                                                  
<pyarrow.lib.Decimal128Array object at 0x7f9b89444e08>
[
  170141183460469231731687303715884105727
]
>>> arr = pa.array([(2**127)], type=ty)                                                                  
>>> arr                                                                                                  
<pyarrow.lib.Decimal128Array object at 0x7f9b89444138>
[
  -170141183460469231731687303715884105728
]
>>> arr = pa.array([-(2**127)], type=ty)                                                                 
>>> arr                                                                                                  
<pyarrow.lib.Decimal128Array object at 0x7f9b89432688>
[
  -170141183460469231731687303715884105728
]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll create a jira to ensure gandiva fails when the type has a precision/scale is > 38.

For the actual values itself, I can add a IsValid() api to BasicDecimal (that checks against a min/max value). but, checking this each time will be a perf overhead. maybe, we should have a conf variable to do overflow checks.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intuitively, I'd say people who use decimals value correctness. Though it's clear our C++ API currently doesn't check...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've created ARROW-4569 and ARROW-4570 to follow up on these.

args.push_back(y_neg_value);
} else {
args.push_back(&in_arg);
}
++i;
}
auto value =
ir_builder()->CreateCall(module()->getFunction("add_decimal128_decimal128"), args);

// store result to out
ir_builder()->CreateRet(value);
return Status::OK();
}

Status DecimalIR::AddFunctions(Engine* engine) {
auto decimal_ir = std::make_shared<DecimalIR>(engine);

Expand All @@ -317,7 +363,10 @@ Status DecimalIR::AddFunctions(Engine* engine) {
decimal_ir->InitializeIntrinsics();

// build "add"
return decimal_ir->BuildAdd();
ARROW_RETURN_NOT_OK(decimal_ir->BuildAdd());

// build "subtract"
return decimal_ir->BuildSubtract();
}

// Do an bitwise-or of all the overflow bits.
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/gandiva/decimal_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ class DecimalIR : public FunctionIRBuilder {
// Build the function for adding decimals.
Status BuildAdd();

// Build the function for decimal subtraction.
Status BuildSubtract();

// Add a trace in IR code.
void AddTrace(const std::string& fmt, std::vector<llvm::Value*> args);

Expand Down
1 change: 1 addition & 0 deletions cpp/src/gandiva/function_registry_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, int64, int64, int64),

BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(add, decimal128),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(subtract, decimal128),

BINARY_RELATIONAL_BOOL_FN(equal),
BINARY_RELATIONAL_BOOL_FN(not_equal),
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/gandiva/precompiled/decimal_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,5 +221,10 @@ BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128&
}
}

BasicDecimal128 Subtract(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y,
int32_t out_precision, int32_t out_scale) {
return Add(x, {-y.value(), y.precision(), y.scale()}, out_precision, out_scale);
}

} // namespace decimalops
} // namespace gandiva
5 changes: 5 additions & 0 deletions cpp/src/gandiva/precompiled/decimal_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,10 @@ namespace decimalops {
arrow::BasicDecimal128 Add(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y,
int32_t out_precision, int32_t out_scale);

/// Subtract 'y' from 'x', and return the result.
arrow::BasicDecimal128 Subtract(const BasicDecimalScalar128& x,
const BasicDecimalScalar128& y, int32_t out_precision,
int32_t out_scale);

} // namespace decimalops
} // namespace gandiva
59 changes: 53 additions & 6 deletions cpp/src/gandiva/precompiled/decimal_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,47 @@ namespace gandiva {

class TestDecimalSql : public ::testing::Test {
protected:
static void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
const DecimalScalar128& expected);
static void Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x,
const DecimalScalar128& y, const DecimalScalar128& expected);

void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
const DecimalScalar128& expected) {
return Verify(DecimalTypeUtil::kOpAdd, x, y, expected);
}

void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
const DecimalScalar128& expected) {
return Verify(DecimalTypeUtil::kOpSubtract, x, y, expected);
}
};

#define EXPECT_DECIMAL_EQ(x, y, expected, actual) \
EXPECT_EQ(expected, actual) << (x).ToString() << " + " << (y).ToString() \
<< " expected : " << expected.ToString() << " actual " \
<< actual.ToString()

void TestDecimalSql::AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
const DecimalScalar128& expected) {
void TestDecimalSql::Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x,
const DecimalScalar128& y, const DecimalScalar128& expected) {
auto t1 = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
auto t2 = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());

Decimal128TypePtr out_type;
EXPECT_OK(DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, {t1, t2}, &out_type));
EXPECT_OK(DecimalTypeUtil::GetResultType(op, {t1, t2}, &out_type));

arrow::BasicDecimal128 out_value;
switch (op) {
case DecimalTypeUtil::kOpAdd:
out_value = decimalops::Add(x, y, out_type->precision(), out_type->scale());
break;

case DecimalTypeUtil::kOpSubtract:
out_value = decimalops::Subtract(x, y, out_type->precision(), out_type->scale());
break;

auto out_value = decimalops::Add(x, y, out_type->precision(), out_type->scale());
default:
// not implemented.
ASSERT_FALSE(true);
}
EXPECT_DECIMAL_EQ(
x, y, expected,
DecimalScalar128(out_value, out_type->precision(), out_type->scale()));
Expand Down Expand Up @@ -74,4 +97,28 @@ TEST_F(TestDecimalSql, Add) {
DecimalScalar128{"-99999999999999999999999999999990000010", 38, 6});
}

TEST_F(TestDecimalSql, Subtract) {
// fast-path
SubtractAndVerify(DecimalScalar128{"201", 30, 3}, // x
DecimalScalar128{"301", 30, 3}, // y
DecimalScalar128{"-100", 31, 3}); // expected

// max precision
SubtractAndVerify(
DecimalScalar128{"09999999999999999999999999999999000000", 38, 5}, // x
DecimalScalar128{"100", 38, 7}, // y
DecimalScalar128{"99999999999999999999999999999989999990", 38, 6});

// Both -ve
SubtractAndVerify(DecimalScalar128{"-201", 30, 3}, // x
DecimalScalar128{"-301", 30, 2}, // y
DecimalScalar128{"2809", 32, 3}); // expected

// -ve and max precision
SubtractAndVerify(
DecimalScalar128{"-09999999999999999999999999999999000000", 38, 5}, // x
DecimalScalar128{"-100", 38, 7}, // y
DecimalScalar128{"-99999999999999999999999999999989999990", 38, 6});
}

} // namespace gandiva
54 changes: 43 additions & 11 deletions cpp/src/gandiva/tests/decimal_single_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ using arrow::Decimal128;

namespace gandiva {

#define EXPECT_DECIMAL_SUM_EQUALS(x, y, expected, actual) \
EXPECT_EQ(expected, actual) << (x).ToString() << " + " << (y).ToString() \
<< " expected : " << (expected).ToString() \
#define EXPECT_DECIMAL_RESULT(op, x, y, expected, actual) \
EXPECT_EQ(expected, actual) << op << " (" << (x).ToString() << "),(" << (y).ToString() \
<< ")" \
<< " expected : " << (expected).ToString() \
<< " actual : " << (actual).ToString();

DecimalScalar128 decimal_literal(const char* value, int precision, int scale) {
Expand All @@ -46,8 +47,19 @@ class TestDecimalOps : public ::testing::Test {
void SetUp() { pool_ = arrow::default_memory_pool(); }

ArrayPtr MakeDecimalVector(const DecimalScalar128& in);

void Verify(DecimalTypeUtil::Op, const std::string& function, const DecimalScalar128& x,
const DecimalScalar128& y, const DecimalScalar128& expected);

void AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
const DecimalScalar128& expected);
const DecimalScalar128& expected) {
Verify(DecimalTypeUtil::kOpAdd, "add", x, y, expected);
}

void SubtractAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
const DecimalScalar128& expected) {
Verify(DecimalTypeUtil::kOpSubtract, "subtract", x, y, expected);
}

protected:
arrow::MemoryPool* pool_;
Expand All @@ -62,24 +74,24 @@ ArrayPtr TestDecimalOps::MakeDecimalVector(const DecimalScalar128& in) {
return MakeArrowArrayDecimal(decimal_type, {decimal_value}, {true});
}

void TestDecimalOps::AddAndVerify(const DecimalScalar128& x, const DecimalScalar128& y,
const DecimalScalar128& expected) {
void TestDecimalOps::Verify(DecimalTypeUtil::Op op, const std::string& function,
const DecimalScalar128& x, const DecimalScalar128& y,
const DecimalScalar128& expected) {
auto x_type = std::make_shared<arrow::Decimal128Type>(x.precision(), x.scale());
auto y_type = std::make_shared<arrow::Decimal128Type>(y.precision(), y.scale());
auto field_x = field("x", x_type);
auto field_y = field("y", y_type);
auto schema = arrow::schema({field_x, field_y});

Decimal128TypePtr output_type;
auto status = DecimalTypeUtil::GetResultType(DecimalTypeUtil::kOpAdd, {x_type, y_type},
&output_type);
auto status = DecimalTypeUtil::GetResultType(op, {x_type, y_type}, &output_type);
EXPECT_OK(status);

// output fields
auto res = field("res", output_type);

// build expression : x + y
auto expr = TreeExprBuilder::MakeExpression("add", {field_x, field_y}, res);
// build expression : x op y
auto expr = TreeExprBuilder::MakeExpression(function, {field_x, field_y}, res);

// Build a projector for the expression.
std::shared_ptr<Projector> projector;
Expand All @@ -106,7 +118,7 @@ void TestDecimalOps::AddAndVerify(const DecimalScalar128& x, const DecimalScalar
std::string value_string = out_value.ToString(0);
DecimalScalar128 actual{value_string, dtype->precision(), dtype->scale()};

EXPECT_DECIMAL_SUM_EQUALS(x, y, expected, actual);
EXPECT_DECIMAL_RESULT(function, x, y, expected, actual);
}

TEST_F(TestDecimalOps, TestAdd) {
Expand Down Expand Up @@ -221,4 +233,24 @@ TEST_F(TestDecimalOps, TestAdd) {
decimal_literal("-10000992", 38, 7), // y
decimal_literal("-2001098", 38, 6));
}

// subtract is a wrapper over add. so, minimal tests are sufficient.
TEST_F(TestDecimalOps, TestSubtract) {
// fast-path
SubtractAndVerify(decimal_literal("201", 30, 3), // x
decimal_literal("301", 30, 3), // y
decimal_literal("-100", 31, 3)); // expected

// max precision
SubtractAndVerify(
decimal_literal("09999999999999999999999999999999000000", 38, 5), // x
decimal_literal("100", 38, 7), // y
decimal_literal("99999999999999999999999999999989999990", 38, 6));

// Mix of +ve and -ve
SubtractAndVerify(decimal_literal("-201", 30, 3), // x
decimal_literal("301", 30, 2), // y
decimal_literal("-3211", 32, 3)); // expected
}

} // namespace gandiva