This is an automated email from the ASF dual-hosted git repository. yangzhg pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push: new 246ac4e [fix] fix a bug of encryption function with iv may return wrong result (#8277) 246ac4e is described below commit 246ac4e37aa4da6836b7850cb990f02d1c3725a3 Author: Zhengguo Yang <yangz...@gmail.com> AuthorDate: Wed Mar 2 17:26:44 2022 +0800 [fix] fix a bug of encryption function with iv may return wrong result (#8277) --- be/src/exprs/encryption_functions.cpp | 93 +++++++++++++---------------------- be/src/exprs/minmax_predicate.h | 4 +- be/src/udf/udf.cpp | 3 ++ be/src/udf/udf.h | 3 ++ be/src/util/encryption_util.cpp | 37 +++++++++++--- be/src/util/encryption_util.h | 4 +- be/test/util/encryption_util_test.cpp | 91 +++++++++++++++++++++++++++++++++- 7 files changed, 164 insertions(+), 71 deletions(-) diff --git a/be/src/exprs/encryption_functions.cpp b/be/src/exprs/encryption_functions.cpp index 5d919a3..19ec1a7 100644 --- a/be/src/exprs/encryption_functions.cpp +++ b/be/src/exprs/encryption_functions.cpp @@ -55,32 +55,23 @@ StringVal encrypt(FunctionContext* ctx, const StringVal& src, const StringVal& k if (src.len == 0 || src.is_null) { return StringVal::null(); } + /* + * Buffer for ciphertext. Ensure the buffer is long enough for the + * ciphertext which may be longer than the plaintext, depending on the + * algorithm and mode. + */ + int cipher_len = src.len + 16; - std::unique_ptr<char[]> p; - p.reset(new char[cipher_len]); - int ret_code = 0; - if (mode != AES_128_ECB && mode != AES_192_ECB && mode != AES_256_ECB && mode != AES_256_ECB && - mode != SM4_128_ECB) { - if (iv.len == 0 || iv.is_null) { - return StringVal::null(); - } - int iv_len = 32; // max key length 256 / 8 - std::unique_ptr<char[]> init_vec; - init_vec.reset(new char[iv_len]); - std::memset(init_vec.get(), 0, iv.len + 1); - memcpy(init_vec.get(), iv.ptr, iv.len); - ret_code = EncryptionUtil::encrypt( - mode, (unsigned char*)src.ptr, src.len, (unsigned char*)key.ptr, key.len, - (unsigned char*)init_vec.get(), true, (unsigned char*)p.get()); - } else { - ret_code = EncryptionUtil::encrypt(mode, (unsigned char*)src.ptr, src.len, - (unsigned char*)key.ptr, key.len, nullptr, true, - (unsigned char*)p.get()); - } - if (ret_code < 0) { + std::unique_ptr<char[]> cipher_text; + cipher_text.reset(new char[cipher_len]); + int cipher_text_len = 0; + cipher_text_len = EncryptionUtil::encrypt(mode, (unsigned char*)src.ptr, src.len, + (unsigned char*)key.ptr, key.len, (char*)iv.ptr, true, + (unsigned char*)cipher_text.get()); + if (cipher_text_len < 0) { return StringVal::null(); } - return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code); + return AnyValUtil::from_buffer_temp(ctx, cipher_text.get(), cipher_text_len); } StringVal decrypt(FunctionContext* ctx, const StringVal& src, const StringVal& key, @@ -89,31 +80,16 @@ StringVal decrypt(FunctionContext* ctx, const StringVal& src, const StringVal& k return StringVal::null(); } int cipher_len = src.len; - std::unique_ptr<char[]> p; - p.reset(new char[cipher_len]); - int ret_code = 0; - if (mode != AES_128_ECB && mode != AES_192_ECB && mode != AES_256_ECB && mode != AES_256_ECB && - mode != SM4_128_ECB) { - if (iv.len == 0 || iv.is_null) { - return StringVal::null(); - } - int iv_len = 32; // max key length 256 / 8 - std::unique_ptr<char[]> init_vec; - init_vec.reset(new char[iv_len]); - std::memset(init_vec.get(), 0, iv.len + 1); - memcpy(init_vec.get(), iv.ptr, iv.len); - ret_code = EncryptionUtil::decrypt( - mode, (unsigned char*)src.ptr, src.len, (unsigned char*)key.ptr, key.len, - (unsigned char*)init_vec.get(), true, (unsigned char*)p.get()); - } else { - ret_code = EncryptionUtil::decrypt(mode, (unsigned char*)src.ptr, src.len, - (unsigned char*)key.ptr, key.len, nullptr, true, - (unsigned char*)p.get()); - } - if (ret_code < 0) { + std::unique_ptr<char[]> plain_text; + plain_text.reset(new char[cipher_len]); + int plain_text_len = 0; + plain_text_len = + EncryptionUtil::decrypt(mode, (unsigned char*)src.ptr, src.len, (unsigned char*)key.ptr, + key.len, (char*)iv.ptr, true, (unsigned char*)plain_text.get()); + if (plain_text_len < 0) { return StringVal::null(); } - return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code); + return AnyValUtil::from_buffer_temp(ctx, plain_text.get(), plain_text_len); } StringVal EncryptionFunctions::aes_encrypt(FunctionContext* ctx, const StringVal& src, @@ -197,15 +173,15 @@ StringVal EncryptionFunctions::from_base64(FunctionContext* ctx, const StringVal return StringVal::null(); } - int cipher_len = src.len; - std::unique_ptr<char[]> p; - p.reset(new char[cipher_len]); + int encoded_len = src.len; + std::unique_ptr<char[]> plain_text; + plain_text.reset(new char[encoded_len]); - int ret_code = base64_decode((const char*)src.ptr, src.len, p.get()); - if (ret_code < 0) { + int plain_text_len = base64_decode((const char*)src.ptr, src.len, plain_text.get()); + if (plain_text_len < 0) { return StringVal::null(); } - return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code); + return AnyValUtil::from_buffer_temp(ctx, plain_text.get(), plain_text_len); } StringVal EncryptionFunctions::to_base64(FunctionContext* ctx, const StringVal& src) { @@ -213,15 +189,16 @@ StringVal EncryptionFunctions::to_base64(FunctionContext* ctx, const StringVal& return StringVal::null(); } - int cipher_len = (size_t)(4.0 * ceil((double)src.len / 3.0)); - std::unique_ptr<char[]> p; - p.reset(new char[cipher_len]); + int encoded_len = (size_t)(4.0 * ceil((double)src.len / 3.0)); + std::unique_ptr<char[]> encoded_text; + encoded_text.reset(new char[encoded_len]); - int ret_code = base64_encode((unsigned char*)src.ptr, src.len, (unsigned char*)p.get()); - if (ret_code < 0) { + int encoded_text_len = + base64_encode((unsigned char*)src.ptr, src.len, (unsigned char*)encoded_text.get()); + if (encoded_text_len < 0) { return StringVal::null(); } - return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code); + return AnyValUtil::from_buffer_temp(ctx, encoded_text.get(), encoded_text_len); } StringVal EncryptionFunctions::md5sum(FunctionContext* ctx, int num_args, const StringVal* args) { diff --git a/be/src/exprs/minmax_predicate.h b/be/src/exprs/minmax_predicate.h index 3a9ff5b..2c8140d 100644 --- a/be/src/exprs/minmax_predicate.h +++ b/be/src/exprs/minmax_predicate.h @@ -25,7 +25,6 @@ namespace doris { // only used in Runtime Filter class MinMaxFuncBase { public: - virtual ~MinMaxFuncBase() = default; virtual void insert(const void* data) = 0; virtual bool find(void* data) = 0; virtual bool is_empty() = 0; @@ -35,6 +34,7 @@ public: virtual Status assign(void* min_data, void* max_data) = 0; // merge from other minmax_func virtual Status merge(MinMaxFuncBase* minmax_func, ObjectPool* pool) = 0; + virtual ~MinMaxFuncBase() = default; }; template <class T> @@ -114,4 +114,4 @@ private: bool _empty = true; }; -} // namespace doris \ No newline at end of file +} // namespace doris diff --git a/be/src/udf/udf.cpp b/be/src/udf/udf.cpp index eae0bf1..612aabb 100644 --- a/be/src/udf/udf.cpp +++ b/be/src/udf/udf.cpp @@ -562,4 +562,7 @@ void* FunctionContext::get_function_state(FunctionStateScope scope) const { return nullptr; } } +std::ostream& operator<<(std::ostream& os, const StringVal& string_val) { + return os << string_val.to_string(); +} } // namespace doris_udf diff --git a/be/src/udf/udf.h b/be/src/udf/udf.h index c262c8c..481f19a 100644 --- a/be/src/udf/udf.h +++ b/be/src/udf/udf.h @@ -21,6 +21,7 @@ #include <string.h> #include <cstdint> +#include <iostream> #include <vector> // This is the only Doris header required to develop UDFs and UDAs. This header @@ -656,7 +657,9 @@ struct StringVal : public AnyVal { void append(FunctionContext* ctx, const uint8_t* buf, int64_t len); void append(FunctionContext* ctx, const uint8_t* buf, int64_t len, const uint8_t* buf2, int64_t buf2_len); + std::string to_string() const { return std::string((char*)ptr, len); } }; +std::ostream& operator<<(std::ostream& os, const StringVal& string_val); struct DecimalV2Val : public AnyVal { __int128 val; diff --git a/be/src/util/encryption_util.cpp b/be/src/util/encryption_util.cpp index eb95cf5..b9396e9 100644 --- a/be/src/util/encryption_util.cpp +++ b/be/src/util/encryption_util.cpp @@ -21,7 +21,9 @@ #include <openssl/evp.h> #include <openssl/ossl_typ.h> #include <sys/types.h> + #include <cstring> +#include <string> namespace doris { @@ -171,20 +173,29 @@ static int do_encrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher, int EncryptionUtil::encrypt(EncryptionMode mode, const unsigned char* source, uint32_t source_length, const unsigned char* key, uint32_t key_length, - const unsigned char* iv, bool padding, unsigned char* encrypt) { + const char* iv_str, bool padding, unsigned char* encrypt) { const EVP_CIPHER* cipher = get_evp_type(mode); /* The encrypt key to be used for encryption */ unsigned char encrypt_key[ENCRYPTION_MAX_KEY_LENGTH / 8]; create_key(key, key_length, encrypt_key, mode); - if (cipher == nullptr || (EVP_CIPHER_iv_length(cipher) > 0 && !iv)) { + int iv_length = EVP_CIPHER_iv_length(cipher); + if (cipher == nullptr || (iv_length > 0 && !iv_str)) { return AES_BAD_DATA; } + char* init_vec = nullptr; + std::string iv_default("DORISDORISDORIS_"); + + if (iv_str) { + init_vec = &iv_default[0]; + memcpy(init_vec, iv_str, strnlen(iv_str, EVP_MAX_IV_LENGTH)); + init_vec[iv_length] = '\0'; + } EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new(); EVP_CIPHER_CTX_reset(cipher_ctx); int length = 0; - int ret = do_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key, iv, padding, - encrypt, &length); + int ret = do_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key, + reinterpret_cast<unsigned char*>(init_vec), padding, encrypt, &length); EVP_CIPHER_CTX_free(cipher_ctx); if (ret == 0) { ERR_clear_error(); @@ -219,21 +230,31 @@ static int do_decrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher, int EncryptionUtil::decrypt(EncryptionMode mode, const unsigned char* encrypt, uint32_t encrypt_length, const unsigned char* key, uint32_t key_length, - const unsigned char* iv, bool padding, unsigned char* decrypt_content) { + const char* iv_str, bool padding, unsigned char* decrypt_content) { const EVP_CIPHER* cipher = get_evp_type(mode); /* The encrypt key to be used for decryption */ unsigned char encrypt_key[ENCRYPTION_MAX_KEY_LENGTH / 8]; create_key(key, key_length, encrypt_key, mode); - if (cipher == nullptr || (EVP_CIPHER_iv_length(cipher) > 0 && !iv)) { + int iv_length = EVP_CIPHER_iv_length(cipher); + if (cipher == nullptr || (iv_length > 0 && !iv_str)) { return AES_BAD_DATA; } + char* init_vec = nullptr; + std::string iv_default("DORISDORISDORIS_"); + + if (iv_str) { + init_vec = &iv_default[0]; + memcpy(init_vec, iv_str, strnlen(iv_str, EVP_MAX_IV_LENGTH)); + init_vec[iv_length] = '\0'; + } EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new(); EVP_CIPHER_CTX_reset(cipher_ctx); int length = 0; - int ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key, iv, padding, - decrypt_content, &length); + int ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key, + reinterpret_cast<unsigned char*>(init_vec), padding, decrypt_content, + &length); EVP_CIPHER_CTX_free(cipher_ctx); if (ret > 0) { return length; diff --git a/be/src/util/encryption_util.h b/be/src/util/encryption_util.h index e051d28..711a817 100644 --- a/be/src/util/encryption_util.h +++ b/be/src/util/encryption_util.h @@ -58,11 +58,11 @@ enum EncryptionState { AES_SUCCESS = 0, AES_BAD_DATA = -1 }; class EncryptionUtil { public: static int encrypt(EncryptionMode mode, const unsigned char* source, uint32_t source_length, - const unsigned char* key, uint32_t key_length, const unsigned char* iv, + const unsigned char* key, uint32_t key_length, const char* iv_str, bool padding, unsigned char* encrypt); static int decrypt(EncryptionMode mode, const unsigned char* encrypt, uint32_t encrypt_length, - const unsigned char* key, uint32_t key_length, const unsigned char* iv, + const unsigned char* key, uint32_t key_length, const char* iv_str, bool padding, unsigned char* decrypt_content); }; diff --git a/be/test/util/encryption_util_test.cpp b/be/test/util/encryption_util_test.cpp index 30c9752..2f30ade 100644 --- a/be/test/util/encryption_util_test.cpp +++ b/be/test/util/encryption_util_test.cpp @@ -117,7 +117,6 @@ TEST_F(EncryptionUtilTest, sm4_test_by_case) { std::unique_ptr<char[]> encrypt_1(new char[case_1.length()]); int length_1 = base64_decode(case_1.c_str(), case_1.length(), encrypt_1.get()); - std::cout << encrypt_1.get(); std::unique_ptr<char[]> decrypted_1(new char[case_1.length()]); int ret_code = EncryptionUtil::decrypt(SM4_128_ECB, (unsigned char*)encrypt_1.get(), length_1, (unsigned char*)_aes_key.c_str(), _aes_key.length(), @@ -137,6 +136,96 @@ TEST_F(EncryptionUtilTest, sm4_test_by_case) { ASSERT_EQ(source_2, decrypted_content_2); } +TEST_F(EncryptionUtilTest, aes_with_iv_test_by_case) { + std::string case_1 = "XbJgw1AxBNwZZPpvzPtWyg=="; // base64 for encrypted "hello, doris" + std::string source_1 = "hello, doris"; + std::string case_2 = "gpKcO/iwgeRCIWBQdkpAkQ=="; // base64 for encrypted "doris test" + std::string source_2 = "doris test"; + std::string iv = "doris"; + + std::unique_ptr<char[]> encrypt_1(new char[case_1.length()]); + int length_1 = base64_decode(case_1.c_str(), case_1.length(), encrypt_1.get()); + std::unique_ptr<char[]> decrypted_1(new char[case_1.length()]); + int ret_code = EncryptionUtil::decrypt(AES_128_CBC, (unsigned char*)encrypt_1.get(), length_1, + (unsigned char*)_aes_key.c_str(), _aes_key.length(), + iv.c_str(), true, (unsigned char*)decrypted_1.get()); + ASSERT_TRUE(ret_code > 0); + std::string decrypted_content_1(decrypted_1.get(), ret_code); + ASSERT_EQ(source_1, decrypted_content_1); + std::unique_ptr<char[]> decrypted_11(new char[case_1.length()]); + + ret_code = EncryptionUtil::decrypt(AES_128_CBC, (unsigned char*)encrypt_1.get(), length_1, + (unsigned char*)_aes_key.c_str(), _aes_key.length(), + iv.c_str(), true, (unsigned char*)decrypted_11.get()); + ASSERT_TRUE(ret_code > 0); + std::string decrypted_content_11(decrypted_11.get(), ret_code); + ASSERT_EQ(source_1, decrypted_content_11); + + std::unique_ptr<char[]> encrypt_2(new char[case_2.length()]); + int length_2 = base64_decode(case_2.c_str(), case_2.length(), encrypt_2.get()); + std::unique_ptr<char[]> decrypted_2(new char[case_2.length()]); + ret_code = EncryptionUtil::decrypt(AES_128_CBC, (unsigned char*)encrypt_2.get(), length_2, + (unsigned char*)_aes_key.c_str(), _aes_key.length(), + iv.c_str(), true, (unsigned char*)decrypted_2.get()); + ASSERT_TRUE(ret_code > 0); + std::string decrypted_content_2(decrypted_2.get(), ret_code); + ASSERT_EQ(source_2, decrypted_content_2); + + std::unique_ptr<char[]> decrypted_21(new char[case_2.length()]); + ret_code = EncryptionUtil::decrypt(AES_128_CBC, (unsigned char*)encrypt_2.get(), length_2, + (unsigned char*)_aes_key.c_str(), _aes_key.length(), + iv.c_str(), true, (unsigned char*)decrypted_21.get()); + ASSERT_TRUE(ret_code > 0); + std::string decrypted_content_21(decrypted_21.get(), ret_code); + ASSERT_EQ(source_2, decrypted_content_21); +} + +TEST_F(EncryptionUtilTest, sm4_with_iv_test_by_case) { + std::string case_1 = "9FFlX59+3EbIC7rqylMNwg=="; // base64 for encrypted "hello, doris" + std::string source_1 = "hello, doris"; + std::string case_2 = "RIJVVUUmMT/4CVNYdxVvXA=="; // base64 for encrypted "doris test" + std::string source_2 = "doris test"; + std::string iv = "doris"; + + std::unique_ptr<char[]> encrypt_1(new char[case_1.length()]); + int length_1 = base64_decode(case_1.c_str(), case_1.length(), encrypt_1.get()); + std::unique_ptr<char[]> decrypted_1(new char[case_1.length()]); + std::unique_ptr<char[]> decrypted_11(new char[case_1.length()]); + + int ret_code = EncryptionUtil::decrypt(SM4_128_CBC, (unsigned char*)encrypt_1.get(), length_1, + (unsigned char*)_aes_key.c_str(), _aes_key.length(), + iv.c_str(), true, (unsigned char*)decrypted_1.get()); + ASSERT_TRUE(ret_code > 0); + std::string decrypted_content_1(decrypted_1.get(), ret_code); + ASSERT_EQ(source_1, decrypted_content_1); + + std::unique_ptr<char[]> encrypt_2(new char[case_2.length()]); + int length_2 = base64_decode(case_2.c_str(), case_2.length(), encrypt_2.get()); + std::unique_ptr<char[]> decrypted_2(new char[case_2.length()]); + std::unique_ptr<char[]> decrypted_21(new char[case_2.length()]); + + ret_code = EncryptionUtil::decrypt(SM4_128_CBC, (unsigned char*)encrypt_2.get(), length_2, + (unsigned char*)_aes_key.c_str(), _aes_key.length(), + iv.c_str(), true, (unsigned char*)decrypted_2.get()); + ASSERT_TRUE(ret_code > 0); + std::string decrypted_content_2(decrypted_2.get(), ret_code); + ASSERT_EQ(source_2, decrypted_content_2); + + ret_code = EncryptionUtil::decrypt(SM4_128_CBC, (unsigned char*)encrypt_1.get(), length_1, + (unsigned char*)_aes_key.c_str(), _aes_key.length(), + iv.c_str(), true, (unsigned char*)decrypted_11.get()); + ASSERT_TRUE(ret_code > 0); + std::string decrypted_content_11(decrypted_11.get(), ret_code); + ASSERT_EQ(source_1, decrypted_content_11); + + ret_code = EncryptionUtil::decrypt(SM4_128_CBC, (unsigned char*)encrypt_2.get(), length_2, + (unsigned char*)_aes_key.c_str(), _aes_key.length(), + iv.c_str(), true, (unsigned char*)decrypted_21.get()); + ASSERT_TRUE(ret_code > 0); + std::string decrypted_content_21(decrypted_21.get(), ret_code); + ASSERT_EQ(source_2, decrypted_content_21); +} + } // namespace doris int main(int argc, char** argv) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org