Skip to content

Commit

Permalink
[fix](function) fix AES/SM3/SM4 encrypt/ decrypt algorithm initializa…
Browse files Browse the repository at this point in the history
…tion vector bug (apache#17420)

ECB algorithm, block_encryption_mode does not take effect, it only takes effect when init vector is provided.
Solved: 192/256 supports calculation without init vector

For other algorithms, an error should be reported when there is no init vector

Initialization Vector. The default value for the block_encryption_mode system variable is aes-128-ecb, or ECB mode, which does not require an initialization vector. The alternative permitted block encryption modes CBC, CFB1, CFB8, CFB128, and OFB all require an initialization vector.

Reference: https://dev.mysql.com/doc/refman/8.0/en/encryption-functions.html#function_aes-decrypt

Note: This fix does not support smooth upgrades. during upgrade process, query may report error: funciton not found
  • Loading branch information
xinyiZzz authored Mar 9, 2023
1 parent 8a6a4b8 commit 397cc01
Show file tree
Hide file tree
Showing 16 changed files with 232 additions and 195 deletions.
14 changes: 13 additions & 1 deletion be/src/vec/functions/function_encryption.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ static void exectue_result(std::vector<const ColumnString::Offsets*>& offsets_li
template <typename Impl, EncryptionMode mode, bool is_encrypt>
struct EncryptionAndDecryptTwoImpl {
static DataTypes get_variadic_argument_types_impl() {
return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>()};
return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(),
std::make_shared<DataTypeString>()};
}

static Status vector_vector(std::vector<const ColumnString::Offsets*>& offsets_list,
Expand All @@ -167,6 +168,17 @@ struct EncryptionAndDecryptTwoImpl {
continue;
}
EncryptionMode encryption_mode = mode;
int mode_size = (*offsets_list[2])[i] - (*offsets_list[2])[i - 1];
const auto mode_raw =
reinterpret_cast<const char*>(&(*chars_list[2])[(*offsets_list[2])[i - 1]]);
if (mode_size != 0) {
std::string mode_str(mode_raw, mode_size);
if (aes_mode_map.count(mode_str) == 0) {
StringOP::push_null_string(i, result_data, result_offset, null_map);
continue;
}
encryption_mode = aes_mode_map.at(mode_str);
}
exectue_result<Impl, is_encrypt>(offsets_list, chars_list, i, encryption_mode, nullptr,
0, result_data, result_offset, null_map);
}
Expand Down
97 changes: 59 additions & 38 deletions be/test/vec/function/function_string_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,8 +676,9 @@ TEST(function_string_test, function_sm3sum_test) {
TEST(function_string_test, function_aes_encrypt_test) {
std::string func_name = "aes_encrypt";
{
InputTypeSet input_types = {TypeIndex::String, TypeIndex::String};
InputTypeSet input_types = {TypeIndex::String, TypeIndex::String, TypeIndex::String};

const char* mode = "AES_128_ECB";
const char* key = "doris";
const char* src[6] = {"aaaaaa", "bbbbbb", "cccccc", "dddddd", "eeeeee", ""};
std::string r[5];
Expand All @@ -692,13 +693,13 @@ TEST(function_string_test, function_aes_encrypt_test) {
r[i] = std::string(p, outlen);
}

DataSet data_set = {{{std::string(src[0]), std::string(key)}, r[0]},
{{std::string(src[1]), std::string(key)}, r[1]},
{{std::string(src[2]), std::string(key)}, r[2]},
{{std::string(src[3]), std::string(key)}, r[3]},
{{std::string(src[4]), std::string(key)}, r[4]},
{{std::string(src[5]), std::string(key)}, Null()},
{{Null(), std::string(key)}, Null()}};
DataSet data_set = {{{std::string(src[0]), std::string(key), std::string(mode)}, r[0]},
{{std::string(src[1]), std::string(key), std::string(mode)}, r[1]},
{{std::string(src[2]), std::string(key), std::string(mode)}, r[2]},
{{std::string(src[3]), std::string(key), std::string(mode)}, r[3]},
{{std::string(src[4]), std::string(key), std::string(mode)}, r[4]},
{{std::string(src[5]), std::string(key), std::string(mode)}, Null()},
{{Null(), std::string(key), std::string(mode)}, Null()}};

check_function<DataTypeString, true>(func_name, input_types, data_set);
}
Expand Down Expand Up @@ -743,8 +744,9 @@ TEST(function_string_test, function_aes_encrypt_test) {
TEST(function_string_test, function_aes_decrypt_test) {
std::string func_name = "aes_decrypt";
{
InputTypeSet input_types = {TypeIndex::String, TypeIndex::String};
InputTypeSet input_types = {TypeIndex::String, TypeIndex::String, TypeIndex::String};

const char* mode = "AES_128_ECB";
const char* key = "doris";
const char* src[5] = {"aaaaaa", "bbbbbb", "cccccc", "dddddd", "eeeeee"};
std::string r[5];
Expand All @@ -759,12 +761,12 @@ TEST(function_string_test, function_aes_decrypt_test) {
r[i] = std::string(p, outlen);
}

DataSet data_set = {{{r[0], std::string(key)}, std::string(src[0])},
{{r[1], std::string(key)}, std::string(src[1])},
{{r[2], std::string(key)}, std::string(src[2])},
{{r[3], std::string(key)}, std::string(src[3])},
{{r[4], std::string(key)}, std::string(src[4])},
{{Null(), std::string(key)}, Null()}};
DataSet data_set = {{{r[0], std::string(key), std::string(mode)}, std::string(src[0])},
{{r[1], std::string(key), std::string(mode)}, std::string(src[1])},
{{r[2], std::string(key), std::string(mode)}, std::string(src[2])},
{{r[3], std::string(key), std::string(mode)}, std::string(src[3])},
{{r[4], std::string(key), std::string(mode)}, std::string(src[4])},
{{Null(), std::string(key), std::string(mode)}, Null()}};

check_function<DataTypeString, true>(func_name, input_types, data_set);
}
Expand Down Expand Up @@ -806,29 +808,39 @@ TEST(function_string_test, function_aes_decrypt_test) {
TEST(function_string_test, function_sm4_encrypt_test) {
std::string func_name = "sm4_encrypt";
{
InputTypeSet input_types = {TypeIndex::String, TypeIndex::String};
InputTypeSet input_types = {TypeIndex::String, TypeIndex::String, TypeIndex::String,
TypeIndex::String};

const char* key = "doris";
const char* iv = "0123456789abcdef";
const char* mode = "SM4_128_ECB";
const char* src[6] = {"aaaaaa", "bbbbbb", "cccccc", "dddddd", "eeeeee", ""};
std::string r[5];

for (int i = 0; i < 5; i++) {
int cipher_len = strlen(src[i]) + 16;
char p[cipher_len];

int outlen = EncryptionUtil::encrypt(
EncryptionMode::SM4_128_ECB, (unsigned char*)src[i], strlen(src[i]),
(unsigned char*)key, strlen(key), nullptr, 0, true, (unsigned char*)p);
int iv_len = 32;
std::unique_ptr<char[]> init_vec;
init_vec.reset(new char[iv_len]);
std::memset(init_vec.get(), 0, strlen(iv) + 1);
memcpy(init_vec.get(), iv, strlen(iv));
int outlen =
EncryptionUtil::encrypt(EncryptionMode::SM4_128_ECB, (unsigned char*)src[i],
strlen(src[i]), (unsigned char*)key, strlen(key),
init_vec.get(), strlen(iv), true, (unsigned char*)p);
r[i] = std::string(p, outlen);
}

DataSet data_set = {{{std::string(src[0]), std::string(key)}, r[0]},
{{std::string(src[1]), std::string(key)}, r[1]},
{{std::string(src[2]), std::string(key)}, r[2]},
{{std::string(src[3]), std::string(key)}, r[3]},
{{std::string(src[4]), std::string(key)}, r[4]},
{{std::string(src[5]), std::string(key)}, Null()},
{{Null(), std::string(key)}, Null()}};
DataSet data_set = {
{{std::string(src[0]), std::string(key), std::string(iv), std::string(mode)}, r[0]},
{{std::string(src[1]), std::string(key), std::string(iv), std::string(mode)}, r[1]},
{{std::string(src[2]), std::string(key), std::string(iv), std::string(mode)}, r[2]},
{{std::string(src[3]), std::string(key), std::string(iv), std::string(mode)}, r[3]},
{{std::string(src[4]), std::string(key), std::string(iv), std::string(mode)}, r[4]},
{{std::string(src[5]), std::string(key), std::string(iv), std::string(mode)},
Null()},
{{Null(), std::string(key), std::string(iv), std::string(mode)}, Null()}};

check_function<DataTypeString, true>(func_name, input_types, data_set);
}
Expand Down Expand Up @@ -875,28 +887,37 @@ TEST(function_string_test, function_sm4_encrypt_test) {
TEST(function_string_test, function_sm4_decrypt_test) {
std::string func_name = "sm4_decrypt";
{
InputTypeSet input_types = {TypeIndex::String, TypeIndex::String};
InputTypeSet input_types = {TypeIndex::String, TypeIndex::String, TypeIndex::String,
TypeIndex::String};

const char* key = "doris";
const char* iv = "0123456789abcdef";
const char* mode = "SM4_128_ECB";
const char* src[5] = {"aaaaaa", "bbbbbb", "cccccc", "dddddd", "eeeeee"};
std::string r[5];

for (int i = 0; i < 5; i++) {
int cipher_len = strlen(src[i]) + 16;
char p[cipher_len];

int outlen = EncryptionUtil::encrypt(
EncryptionMode::SM4_128_ECB, (unsigned char*)src[i], strlen(src[i]),
(unsigned char*)key, strlen(key), nullptr, 0, true, (unsigned char*)p);
int iv_len = 32;
std::unique_ptr<char[]> init_vec;
init_vec.reset(new char[iv_len]);
std::memset(init_vec.get(), 0, strlen(iv) + 1);
memcpy(init_vec.get(), iv, strlen(iv));
int outlen =
EncryptionUtil::encrypt(EncryptionMode::SM4_128_ECB, (unsigned char*)src[i],
strlen(src[i]), (unsigned char*)key, strlen(key),
init_vec.get(), strlen(iv), true, (unsigned char*)p);
r[i] = std::string(p, outlen);
}

DataSet data_set = {{{r[0], std::string(key)}, std::string(src[0])},
{{r[1], std::string(key)}, std::string(src[1])},
{{r[2], std::string(key)}, std::string(src[2])},
{{r[3], std::string(key)}, std::string(src[3])},
{{r[4], std::string(key)}, std::string(src[4])},
{{Null(), std::string(key)}, Null()}};
DataSet data_set = {
{{r[0], std::string(key), std::string(iv), std::string(mode)}, std::string(src[0])},
{{r[1], std::string(key), std::string(iv), std::string(mode)}, std::string(src[1])},
{{r[2], std::string(key), std::string(iv), std::string(mode)}, std::string(src[2])},
{{r[3], std::string(key), std::string(iv), std::string(mode)}, std::string(src[3])},
{{r[4], std::string(key), std::string(iv), std::string(mode)}, std::string(src[4])},
{{Null(), std::string(key), std::string(iv), std::string(mode)}, Null()}};

check_function<DataTypeString, true>(func_name, input_types, data_set);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ AES_ENCRYPT
### description

Encryption of data using the OpenSSL. This function is consistent with the `AES_ENCRYPT` function in MySQL. Using AES_128_ECB algorithm by default, and the padding mode is PKCS7.
Reference: https://dev.mysql.com/doc/refman/8.0/en/encryption-functions.html#function_aes-decrypt

#### Syntax

Expand All @@ -42,7 +43,7 @@ AES_ENCRYPT(str,key_str[,init_vector])

- `str`: Content to be encrypted
- `key_str`: Secret key
- `init_vector`: Initialization Vector
- `init_vector`: Initialization Vector. The default value for the block_encryption_mode system variable is aes ecb mode, which does not require an initialization vector. The alternative permitted block encryption modes CBC, CFB1, CFB8, CFB128, and OFB all require an initialization vector.

#### Return Type

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ AES_ENCRYPT
### Description

Aes 加密函数。该函数与 MySQL 中的 `AES_ENCRYPT` 函数行为一致。默认采用 AES_128_ECB 算法,padding 模式为 PKCS7。底层使用 OpenSSL 库进行加密。
Reference: https://dev.mysql.com/doc/refman/8.0/en/encryption-functions.html#function_aes-decrypt

#### Syntax

Expand All @@ -42,7 +43,7 @@ AES_ENCRYPT(str,key_str[,init_vector])

- `str`: 待加密的内容
- `key_str`: 密钥
- `init_vector`: 初始向量
- `init_vector`: 初始向量。block_encryption_mode 默认值为 aes-128-ecb,它不需要初始向量,可选的块加密模式 CBC、CFB1、CFB8、CFB128 和 OFB 都需要一个初始向量。

#### Return Type

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,7 @@ private void analyzeBuiltinAggFunction(Analyzer analyzer) throws AnalysisExcepti
|| fnName.getFunction().equalsIgnoreCase("aes_encrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt"))
&& children.size() == 3) {
&& (children.size() == 2 || children.size() == 3)) {
String blockEncryptionMode = "";
Set<String> aesModes = new HashSet<>(Arrays.asList(
"AES_128_ECB",
Expand Down Expand Up @@ -985,6 +985,12 @@ private void analyzeBuiltinAggFunction(Analyzer analyzer) throws AnalysisExcepti
throw new AnalysisException("session variable block_encryption_mode is invalid with aes");

}
if (children.size() == 2 && !blockEncryptionMode.toUpperCase().equals("AES_128_ECB")
&& !blockEncryptionMode.toUpperCase().equals("AES_192_ECB")
&& !blockEncryptionMode.toUpperCase().equals("AES_256_ECB")) {
throw new AnalysisException("Incorrect parameter count in the call to native function "
+ "'aes_encrypt' or 'aes_decrypt'");
}
}
if (fnName.getFunction().equalsIgnoreCase("sm4_decrypt")
|| fnName.getFunction().equalsIgnoreCase("sm4_encrypt")) {
Expand All @@ -995,6 +1001,10 @@ private void analyzeBuiltinAggFunction(Analyzer analyzer) throws AnalysisExcepti
throw new AnalysisException("session variable block_encryption_mode is invalid with sm4");

}
if (children.size() == 2) {
throw new AnalysisException("Incorrect parameter count in the call to native function "
+ "'sm4_encrypt' or 'sm4_decrypt'");
}
}
}
children.add(new StringLiteral(blockEncryptionMode));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public BoundFunction build(String name, List<? extends Object> arguments) {
})
.collect(Collectors.joining(", ", "(", ")"));
throw new IllegalStateException("Can not build function: '" + name
+ "', expression: " + name + argString, t);
+ "', expression: " + name + argString + ", " + t.getCause().getMessage(), t);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,18 @@ public class AesDecrypt extends AesCryptoFunction {
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE)
);

/**
* Some javadoc for checkstyle...
*/
public AesDecrypt(Expression arg0, Expression arg1) {
super("aes_decrypt", arg0, arg1);
super("aes_decrypt", arg0, arg1, getDefaultBlockEncryptionMode());
String blockEncryptionMode = String.valueOf(getDefaultBlockEncryptionMode());
if (!blockEncryptionMode.toUpperCase().equals("'AES_128_ECB'")
&& !blockEncryptionMode.toUpperCase().equals("'AES_192_ECB'")
&& !blockEncryptionMode.toUpperCase().equals("'AES_256_ECB'")) {
throw new AnalysisException("Incorrect parameter count in the call to native function "
+ "'aes_encrypt' or 'aes_decrypt'");
}
}

public AesDecrypt(Expression arg0, Expression arg1, Expression arg2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,18 @@ public class AesEncrypt extends AesCryptoFunction {
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE)
);

/**
* Some javadoc for checkstyle...
*/
public AesEncrypt(Expression arg0, Expression arg1) {
super("aes_encrypt", arg0, arg1);
super("aes_encrypt", arg0, arg1, getDefaultBlockEncryptionMode());
String blockEncryptionMode = String.valueOf(getDefaultBlockEncryptionMode());
if (!blockEncryptionMode.toUpperCase().equals("'AES_128_ECB'")
&& !blockEncryptionMode.toUpperCase().equals("'AES_192_ECB'")
&& !blockEncryptionMode.toUpperCase().equals("'AES_256_ECB'")) {
throw new AnalysisException("Incorrect parameter count in the call to native function "
+ "'aes_encrypt' or 'aes_decrypt'");
}
}

public AesEncrypt(Expression arg0, Expression arg1, Expression arg2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ public class Sm4Decrypt extends Sm4CryptoFunction {
*/
public Sm4Decrypt(Expression arg0, Expression arg1) {
super("sm4_decrypt", arg0, arg1);
throw new AnalysisException("Incorrect parameter count in the call to native function "
+ "'sm4_encrypt' or 'sm4_decrypt'");
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ public class Sm4Encrypt extends Sm4CryptoFunction {
*/
public Sm4Encrypt(Expression arg0, Expression arg1) {
super("sm4_encrypt", arg0, arg1);
throw new AnalysisException("Incorrect parameter count in the call to native function "
+ "'sm4_encrypt' or 'sm4_decrypt'");
}

/**
Expand Down
16 changes: 8 additions & 8 deletions gensrc/script/doris_builtins_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,21 +1517,21 @@
[['murmur_hash3_64'], 'BIGINT', ['STRING', '...'], ''],

# aes and base64 function
[['aes_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
[['aes_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
[['aes_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
[['aes_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
[['aes_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
[['aes_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
[['sm4_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
[['sm4_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
[['sm4_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
[['sm4_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
[['sm4_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
[['sm4_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
[['from_base64'], 'VARCHAR', ['VARCHAR'], 'ALWAYS_NULLABLE'],
[['aes_encrypt'], 'STRING', ['STRING', 'STRING'], 'ALWAYS_NULLABLE'],
[['aes_decrypt'], 'STRING', ['STRING', 'STRING'], 'ALWAYS_NULLABLE'],
[['aes_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
[['aes_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
[['aes_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
[['aes_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
[['sm4_encrypt'], 'STRING', ['STRING', 'STRING'], 'ALWAYS_NULLABLE'],
[['sm4_decrypt'], 'STRING', ['STRING', 'STRING'], 'ALWAYS_NULLABLE'],
[['sm4_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
[['sm4_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
[['sm4_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
[['sm4_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
[['from_base64'], 'STRING', ['STRING'], 'ALWAYS_NULLABLE'],
Expand Down
Loading

0 comments on commit 397cc01

Please sign in to comment.