Skip to content

Commit

Permalink
Add rand function for seed in int64 type (apache#78)
Browse files Browse the repository at this point in the history
* Initial commit

* Add int64 seed support

* Add unit test cases
  • Loading branch information
PHILO-HE authored and zhztheplayer committed Feb 28, 2022
1 parent 5f86d70 commit 8dfaa7a
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 3 deletions.
3 changes: 3 additions & 0 deletions cpp/src/gandiva/function_registry_math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ std::vector<NativeFunction> GetMathOpsFunctionRegistry() {
"gdv_fn_random", NativeFunction::kNeedsFunctionHolder),
NativeFunction("random", {"rand"}, DataTypeVector{int32()}, float64(),
kResultNullNever, "gdv_fn_random_with_seed",
NativeFunction::kNeedsFunctionHolder),
NativeFunction("random", {"rand"}, DataTypeVector{int64()}, float64(),
kResultNullNever, "gdv_fn_random_with_seed64",
NativeFunction::kNeedsFunctionHolder)};

return math_fn_registry_;
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,12 @@ double gdv_fn_random_with_seed(int64_t ptr, int32_t seed, bool seed_validity) {
return (*holder)();
}

double gdv_fn_random_with_seed64(int64_t ptr, int64_t seed, bool seed_validity) {
gandiva::RandomGeneratorHolder* holder =
reinterpret_cast<gandiva::RandomGeneratorHolder*>(ptr);
return (*holder)();
}

int64_t gdv_fn_to_date_utf8_utf8(int64_t context_ptr, int64_t holder_ptr,
const char* data, int data_len, bool in1_validity,
const char* pattern, int pattern_len, bool in2_validity,
Expand Down Expand Up @@ -1512,6 +1518,10 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {
engine->AddGlobalMappingForFunc("gdv_fn_random_with_seed", types->double_type(), args,
reinterpret_cast<void*>(gdv_fn_random_with_seed));

args = {types->i64_type(), types->i64_type(), types->i1_type()};
engine->AddGlobalMappingForFunc("gdv_fn_random_with_seed64", types->double_type(), args,
reinterpret_cast<void*>(gdv_fn_random_with_seed64));

args = {types->i64_type(), // int64_t context_ptr
types->i8_ptr_type(), // const char* data
types->i32_type()}; // int32_t lenr
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ bool in_expr_lookup_utf8(int64_t ptr, const char* data, int data_len, bool in_va
int gdv_fn_time_with_zone(int* time_fields, const char* zone, int zone_len,
int64_t* ret_time);

double gdv_fn_random(int64_t ptr);

double gdv_fn_random_with_seed(int64_t ptr, int32_t seed, bool seed_validity);

double gdv_fn_random_with_seed64(int64_t ptr, int64_t seed, bool seed_validity);

GANDIVA_EXPORT
const char* gdv_fn_base64_encode_binary(int64_t context, const char* in, int32_t in_len,
int32_t* out_len);
Expand Down
11 changes: 8 additions & 3 deletions cpp/src/gandiva/random_generator_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,16 @@ Status RandomGeneratorHolder::Make(const FunctionNode& node,

auto literal_type = literal->return_type()->id();
ARROW_RETURN_IF(
literal_type != arrow::Type::INT32,
Status::Invalid("'random' function requires an int32 literal as parameter"));
literal_type != arrow::Type::INT32 && literal_type != arrow::Type::INT64,
Status::Invalid("'random' function requires an int32/int64 literal as parameter"));

*holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder(
if (literal_type == arrow::Type::INT32) {
*holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder(
literal->is_null() ? 0 : arrow::util::get<int32_t>(literal->holder())));
} else {
*holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder(
literal->is_null() ? 0 : arrow::util::get<int64_t>(literal->holder())));
}
return Status::OK();
}
} // namespace gandiva
5 changes: 5 additions & 0 deletions cpp/src/gandiva/random_generator_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class GANDIVA_EXPORT RandomGeneratorHolder : public FunctionHolder {
generator_.seed(static_cast<uint64_t>(seed64));
}

explicit RandomGeneratorHolder(int64_t seed64) : distribution_(0, 1) {
seed64 = (seed64 ^ 0x00000005DEECE66D) & 0x0000ffffffffffff;
generator_.seed(static_cast<uint64_t>(seed64));
}

RandomGeneratorHolder() : distribution_(0, 1) {
generator_.seed(::arrow::internal::GetRandomSeed());
}
Expand Down
21 changes: 21 additions & 0 deletions cpp/src/gandiva/random_generator_holder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,27 @@ TEST_F(TestRandGenHolder, WithValidSeeds) {
EXPECT_NE(random_1(), random_2());
}

TEST_F(TestRandGenHolder, WithValidSeedsInLongType) {
std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_1;
std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_2;
std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_3;
FunctionNode rand_func_1 = BuildRandWithSeedFunc(100L, false);
FunctionNode rand_func_2 = BuildRandWithSeedFunc(1000L, false);
FunctionNode rand_func_3 = BuildRandWithSeedFunc(100000L, false);
auto status = RandomGeneratorHolder::Make(rand_func_1, &rand_gen_holder_1);
EXPECT_EQ(status.ok(), true) << status.message();
status = RandomGeneratorHolder::Make(rand_func_2, &rand_gen_holder_2);
EXPECT_EQ(status.ok(), true) << status.message();
status = RandomGeneratorHolder::Make(rand_func_3, &rand_gen_holder_3);
EXPECT_EQ(status.ok(), true) << status.message();

auto& random_1 = *rand_gen_holder_1;
auto& random_2 = *rand_gen_holder_2;
auto& random_3 = *rand_gen_holder_3;
EXPECT_NE(random_2(), random_3());
EXPECT_NE(random_1(), random_2());
}

TEST_F(TestRandGenHolder, WithInValidSeed) {
std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_1;
std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_2;
Expand Down

0 comments on commit 8dfaa7a

Please sign in to comment.