Skip to content

Commit

Permalink
Add arg generator for decimal type (#9634)
Browse files Browse the repository at this point in the history
Summary:
When creating a decimal expression in the fuzzer test, argument types can be
generated with ArgumentTypeFuzzer, but it cannot generate argument types that
meet the required constraints for a given result type. To solve this
limitation, an argument type generator is introduced, which generates all
possible decimal input types, computes result types and stores these in a map
keyed on result type. This mapping is then used to generate input types for a
given result type.
Inspired by #9358.

Pull Request resolved: #9634

Reviewed By: kgpai

Differential Revision: D56910439

Pulled By: bikramSingh91

fbshipit-source-id: 938d28fecc4d527b95571302a054a62018915805
  • Loading branch information
rui-mo authored and facebook-github-bot committed May 3, 2024
1 parent 38abde9 commit e38079f
Show file tree
Hide file tree
Showing 6 changed files with 341 additions and 3 deletions.
38 changes: 38 additions & 0 deletions velox/expression/fuzzer/ArgGenerator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/expression/FunctionSignature.h"
#include "velox/vector/fuzzer/Utils.h"

namespace facebook::velox::fuzzer {

/// Interface for generators that produce random, but valid input types for a
/// specified function signature and a desired return type.
class ArgGenerator {
public:
virtual ~ArgGenerator() = default;

/// Given a signature and a concrete return type, returns randomly selected
/// valid input types.
/// @param signature Function signature to generate argument types for.
/// @param returnType Concrete return type the argument types must produce.
/// @param rng Random number generator used to select the input types.
/// @return The selected input types, or an empty vector if no input types can
/// produce the specified result type.
virtual std::vector<TypePtr> generateArgs(
const exec::FunctionSignature& signature,
const TypePtr& returnType,
FuzzerGenerator& rng) = 0;
};

} // namespace facebook::velox::fuzzer
5 changes: 3 additions & 2 deletions velox/expression/fuzzer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ target_link_libraries(velox_expression_test_utility velox_type
velox_expression_functions gtest)

add_library(
velox_expression_fuzzer ArgumentTypeFuzzer.cpp ExpressionFuzzer.cpp
FuzzerRunner.cpp ExpressionFuzzerVerifier.cpp)
velox_expression_fuzzer
ArgumentTypeFuzzer.cpp DecimalArgGeneratorBase.cpp ExpressionFuzzer.cpp
FuzzerRunner.cpp ExpressionFuzzerVerifier.cpp)

target_link_libraries(
velox_expression_fuzzer
Expand Down
99 changes: 99 additions & 0 deletions velox/expression/fuzzer/DecimalArgGeneratorBase.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/expression/fuzzer/DecimalArgGeneratorBase.h"
#include <boost/random/uniform_int_distribution.hpp>

namespace facebook::velox::fuzzer {
namespace {

// Returns all the possible decimal types.
const std::vector<TypePtr>& getAllTypes() {
const auto generateAllTypes = []() {
std::vector<TypePtr> allTypes;
for (auto p = 1; p <= 38; ++p) {
for (auto s = 0; s <= p; ++s) {
allTypes.push_back(DECIMAL(p, s));
}
}
return allTypes;
};

static const std::vector<TypePtr> allTypes = generateAllTypes();
return allTypes;
}

uint32_t rand32(uint32_t max, FuzzerGenerator& rng) {
return boost::random::uniform_int_distribution<uint32_t>()(rng) % max;
}
} // namespace

// Returns randomly selected input types that produce 'returnType'. The
// signature is not consulted; the lookup is driven purely by the return type.
// Returns an empty vector when no input types are found or when any of the
// selected types is null.
std::vector<TypePtr> DecimalArgGeneratorBase::generateArgs(
    const exec::FunctionSignature& /*signature*/,
    const TypePtr& returnType,
    FuzzerGenerator& rng) {
  auto argTypes = findInputs(returnType, rng);
  for (auto it = argTypes.begin(); it != argTypes.end(); ++it) {
    // A single null type invalidates the whole argument list.
    if (*it == nullptr) {
      return {};
    }
  }
  return argTypes;
}

// Populates 'inputs_' by enumerating every decimal type (numArgs == 1) or
// every ordered pair of decimal types (numArgs == 2), computing the return
// type via the matching toReturnType overload, and appending the input types
// under the (precision, scale) key of that return type. Combinations for which
// toReturnType returns std::nullopt are skipped.
// @param numArgs Number of decimal arguments; only 1 and 2 are supported, any
// other value raises a not-yet-implemented error.
void DecimalArgGeneratorBase::initialize(uint32_t numArgs) {
  switch (numArgs) {
    case 1: {
      for (const auto& type : getAllTypes()) {
        const auto [p, s] = getDecimalPrecisionScale(*type);
        if (const auto returnType = toReturnType(p, s)) {
          inputs_[*returnType].push_back({type});
        }
      }
      break;
    }
    case 2: {
      for (const auto& a : getAllTypes()) {
        // The precision and scale of the first argument do not depend on the
        // second argument, so extract them once per outer iteration instead of
        // once per pair.
        const auto [p1, s1] = getDecimalPrecisionScale(*a);
        for (const auto& b : getAllTypes()) {
          const auto [p2, s2] = getDecimalPrecisionScale(*b);
          if (const auto returnType = toReturnType(p1, s1, p2, s2)) {
            inputs_[*returnType].push_back({a, b});
          }
        }
      }
      break;
    }
    default:
      VELOX_NYI(
          "Initialization with {} argument types is not supported.", numArgs);
  }
}

// Looks up 'inputs_' with the precision and scale of 'returnType' and returns
// one of the stored input type lists, chosen at random with 'rng'. Logs at
// verbose level 1 and returns an empty vector when the return type has no
// entry in the map.
std::vector<TypePtr> DecimalArgGeneratorBase::findInputs(
    const TypePtr& returnType,
    FuzzerGenerator& rng) const {
  const auto [precision, scale] = getDecimalPrecisionScale(*returnType);
  const auto entry = inputs_.find({precision, scale});
  if (entry != inputs_.end()) {
    const auto& candidates = entry->second;
    return candidates[rand32(candidates.size(), rng)];
  }

  VLOG(1) << "Cannot find input types for " << returnType->toString();
  return {};
}
} // namespace facebook::velox::fuzzer
72 changes: 72 additions & 0 deletions velox/expression/fuzzer/DecimalArgGeneratorBase.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/expression/fuzzer/ArgGenerator.h"

namespace facebook::velox::fuzzer {

/// Base class for argument type generators of decimal functions. It keeps a
/// map from the (precision, scale) of a return type to the lists of input
/// types that produce it. Argument types are generated by looking up the map
/// with the precision and scale of the requested return type and randomly
/// selecting one of the matching lists. Derived classes should call
/// 'initialize' from their constructor with the number of decimal arguments,
/// and override the toReturnType overload that takes the matching number of
/// precision and scale pairs.
class DecimalArgGeneratorBase : public ArgGenerator {
 public:
  std::vector<TypePtr> generateArgs(
      const exec::FunctionSignature& signature,
      const TypePtr& returnType,
      FuzzerGenerator& rng) override;

 protected:
  // Computes the result type for all possible combinations of decimal input
  // types and stores the inputs in 'inputs_', keyed by the precision and scale
  // of the result type.
  // @param numArgs the number of decimal argument types. Only one or two
  // argument types are supported.
  virtual void initialize(uint32_t numArgs);

  // Single-argument form: given the precision and scale of the input, returns
  // the precision and scale of the result, or std::nullopt if no valid return
  // type can be produced from that input. Must be overridden by generators
  // initialized with one argument; the default is not expected to be called.
  virtual std::optional<std::pair<int, int>> toReturnType(
      int /*p*/,
      int /*s*/) {
    VELOX_UNREACHABLE();
  }

  // Two-argument form: given the precisions and scales of both inputs,
  // returns the precision and scale of the result, or std::nullopt if no valid
  // return type can be produced. Must be overridden by generators initialized
  // with two arguments; the default is not expected to be called.
  virtual std::optional<std::pair<int, int>> toReturnType(
      int /*p1*/,
      int /*s1*/,
      int /*p2*/,
      int /*s2*/) {
    VELOX_UNREACHABLE();
  }

 private:
  // Returns a randomly selected list of input types (one entry per argument)
  // that produces the specified result type, or an empty vector if there is
  // none.
  std::vector<TypePtr> findInputs(
      const TypePtr& returnType,
      FuzzerGenerator& rng) const;

  // Maps the precision and scale of a return type to every list of input
  // types that produces it.
  std::unordered_map<std::pair<int, int>, std::vector<std::vector<TypePtr>>>
      inputs_;
};

} // namespace facebook::velox::fuzzer
2 changes: 1 addition & 1 deletion velox/expression/fuzzer/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

add_executable(velox_expression_fuzzer_unit_test ArgumentTypeFuzzerTest.cpp ExpressionFuzzerUnitTest.cpp)
add_executable(velox_expression_fuzzer_unit_test ArgumentTypeFuzzerTest.cpp DecimalArgGeneratorTest.cpp ExpressionFuzzerUnitTest.cpp)

target_link_libraries(
velox_expression_fuzzer_unit_test
Expand Down
128 changes: 128 additions & 0 deletions velox/expression/fuzzer/tests/DecimalArgGeneratorTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include "velox/expression/SignatureBinder.h"
#include "velox/expression/fuzzer/DecimalArgGeneratorBase.h"

namespace facebook::velox::fuzzer::test {

// Test fixture providing sample unary and binary decimal argument generators
// together with helpers that check the generated argument types against a
// function signature.
class DecimalArgGeneratorTest : public testing::Test {
 protected:
  // Generator for a one-argument function whose result type is
  // decimal(min(38, p + s + 1), min(s + 1, 18)).
  class UnaryArgGenerator : public DecimalArgGeneratorBase {
   public:
    UnaryArgGenerator() {
      initialize(1);
    }

   protected:
    std::optional<std::pair<int, int>> toReturnType(int p, int s) override {
      auto precision = std::min(38, p + s + 1);
      auto scale = std::min(s + 1, 18);
      return {{precision, scale}};
    }
  };

  // Generator for a two-argument function whose result scale is
  // max(s1, s2) and whose result precision is
  // min(38, max(p1 - s1, p2 - s2) + max(s1, s2) + 1).
  class BinaryArgGenerator : public DecimalArgGeneratorBase {
   public:
    BinaryArgGenerator() {
      initialize(2);
    }

   protected:
    std::optional<std::pair<int, int>>
    toReturnType(int p1, int s1, int p2, int s2) override {
      auto s = std::max(s1, s2);
      auto p = std::min(38, std::max(p1 - s1, p2 - s2) + s + 1);
      return {{p, s}};
    }
  };

  // Asserts the equivalence between the given return type and the actual type
  // resolved from the generated argument types. Fails if the argument types
  // cannot be bound to the signature or the return type cannot be resolved.
  void assertReturnType(
      const std::shared_ptr<DecimalArgGeneratorBase>& generator,
      const exec::FunctionSignature& signature,
      const TypePtr& returnType) {
    // Fixed seed keeps the selected argument types deterministic.
    std::mt19937 rng{0};
    const auto argTypes = generator->generateArgs(signature, returnType, rng);

    // Resolve return type from argument types for the given signature.
    TypePtr actualType;
    exec::SignatureBinder binder(signature, argTypes);
    if (binder.tryBind()) {
      actualType = binder.tryResolveReturnType();
    }
    // Covers both a failed bind and a bind whose return type could not be
    // resolved, so that the null type is never dereferenced below.
    if (actualType == nullptr) {
      VELOX_FAIL("Failed to resolve return type from argument types.");
    }
    EXPECT_TRUE(returnType->equivalent(*actualType))
        << "Expected type: " << returnType->toString()
        << ", actual type: " << actualType->toString();
  }

  // Asserts that no argument types can be generated for the given return type.
  void assertEmptyArgs(
      const std::shared_ptr<DecimalArgGeneratorBase>& generator,
      const exec::FunctionSignature& signature,
      const TypePtr& returnType) {
    std::mt19937 rng{0};
    const auto argTypes = generator->generateArgs(signature, returnType, rng);
    EXPECT_TRUE(argTypes.empty());
  }
};

// Checks the unary generator against a signature whose result type is
// decimal(min(38, precision + scale + 1), min(scale + 1, 18)).
TEST_F(DecimalArgGeneratorTest, unary) {
  const auto signature =
      exec::FunctionSignatureBuilder()
          .integerVariable("scale")
          .integerVariable("precision")
          .integerVariable("r_precision", "min(38, precision + scale + 1)")
          .integerVariable("r_scale", "min(scale + 1, 18)")
          .returnType("decimal(r_precision, r_scale)")
          .argumentType("decimal(precision, scale)")
          .build();
  const auto generator = std::make_shared<UnaryArgGenerator>();

  // Return types for which input types are expected to exist.
  const std::vector<TypePtr> producibleTypes = {
      DECIMAL(10, 2), DECIMAL(38, 18)};
  for (const auto& returnType : producibleTypes) {
    assertReturnType(generator, *signature, returnType);
  }

  // The result scale is capped at 18, so a scale of 20 has no inputs.
  assertEmptyArgs(generator, *signature, DECIMAL(38, 20));
}

// Checks the binary generator against a signature whose result scale is
// max(a_scale, b_scale) and whose result precision is
// min(38, max(a_precision - a_scale, b_precision - b_scale) +
// max(a_scale, b_scale) + 1).
TEST_F(DecimalArgGeneratorTest, binary) {
  const auto signature =
      exec::FunctionSignatureBuilder()
          .integerVariable("a_scale")
          .integerVariable("b_scale")
          .integerVariable("a_precision")
          .integerVariable("b_precision")
          .integerVariable(
              "r_precision",
              "min(38, max(a_precision - a_scale, b_precision - b_scale) + max(a_scale, b_scale) + 1)")
          .integerVariable("r_scale", "max(a_scale, b_scale)")
          .returnType("decimal(r_precision, r_scale)")
          .argumentType("decimal(a_precision, a_scale)")
          .argumentType("decimal(b_precision, b_scale)")
          .build();
  const auto generator = std::make_shared<BinaryArgGenerator>();

  // Return types for which input types are expected to exist.
  const std::vector<TypePtr> producibleTypes = {
      DECIMAL(10, 2), DECIMAL(38, 20), DECIMAL(38, 38), DECIMAL(38, 0)};
  for (const auto& returnType : producibleTypes) {
    assertReturnType(generator, *signature, returnType);
  }
}

} // namespace facebook::velox::fuzzer::test

0 comments on commit e38079f

Please sign in to comment.