From 733cee5f26b4f8c33b41257ccfc5b46633550ae3 Mon Sep 17 00:00:00 2001 From: Lukas Alt Date: Thu, 19 Dec 2024 17:27:47 +0000 Subject: [PATCH 1/2] JSON Protocol String & Base64 optimization for ARM --- LICENSE_BSD2 | 7 + thrift/lib/cpp2/protocol/BinaryProtocol-inl.h | 3 +- .../protocol/JSONProtocolCommon-ext-inl.h | 234 +++++++ .../cpp2/protocol/JSONProtocolCommon-inl.h | 571 +++++++++++++++++- thrift/lib/cpp2/protocol/JSONProtocolCommon.h | 13 +- .../cpp2/protocol/test/JSONProtocolTest.cpp | 153 ++++- 6 files changed, 933 insertions(+), 48 deletions(-) create mode 100644 LICENSE_BSD2 create mode 100644 thrift/lib/cpp2/protocol/JSONProtocolCommon-ext-inl.h diff --git a/LICENSE_BSD2 b/LICENSE_BSD2 new file mode 100644 index 00000000000..4763f1af5df --- /dev/null +++ b/LICENSE_BSD2 @@ -0,0 +1,7 @@ +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/thrift/lib/cpp2/protocol/BinaryProtocol-inl.h b/thrift/lib/cpp2/protocol/BinaryProtocol-inl.h index 3a08143eac9..e87180000e2 100644 --- a/thrift/lib/cpp2/protocol/BinaryProtocol-inl.h +++ b/thrift/lib/cpp2/protocol/BinaryProtocol-inl.h @@ -619,8 +619,7 @@ inline bool BinaryProtocolReader::advanceToNextField( if (in_.length() >= 3) { uint8_t type = *in_.data(); if (nextFieldType == type) { - int16_t fieldId = - folly::Endian::big(folly::loadUnaligned(in_.data() + 1)); + int16_t fieldId = folly::Endian::big(folly::loadUnaligned(in_.data() + 1)); in_.skipNoAdvance(3); if (nextFieldId == fieldId) { return true; diff --git a/thrift/lib/cpp2/protocol/JSONProtocolCommon-ext-inl.h b/thrift/lib/cpp2/protocol/JSONProtocolCommon-ext-inl.h new file mode 100644 index 00000000000..259f18537d5 --- /dev/null +++ b/thrift/lib/cpp2/protocol/JSONProtocolCommon-ext-inl.h @@ -0,0 +1,234 @@ +/* + * Copyright (c) 2015-2018, Wojciech Muła. All rights reserved. + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. 
Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS + IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED + TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A + PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED + TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#pragma once + +#include +#include +#include +#include + +#ifdef __ARM_NEON +#include +#endif /* __ARM_NEON */ + +namespace apache { +namespace thrift { +#ifdef __ARM_NEON + +#define packed_byte(x) vdup_n_u8(x) + +FOLLY_ALWAYS_INLINE uint8x8_t +lookup_pshufb_bitmask(const uint8x8_t input, uint8x8_t& error) { + const uint8x8_t higher_nibble = vshr_n_u8(input, 4); + const uint8x8_t lower_nibble = vand_u8(input, packed_byte(0x0f)); + + const uint8x8x2_t shiftLUT = { + 0, 0, 19, 4, uint8_t(-65), uint8_t(-65), uint8_t(-71), uint8_t(-71), + 0, 0, 0, 0, 0, 0, 0, 0}; + + const uint8x8x2_t maskLUT = { + /* 0 : 0b1010_1000*/ 0xa8, + /* 1 .. 9 : 0b1111_1000*/ 0xf8, + 0xf8, 0xf8, 0xf8, 0xf8, 0xf8, 0xf8, + 0xf8, 0xf8, + /* 10 : 0b1111_0000*/ 0xf0, + /* 11 : 0b0101_0100*/ 0x54, + /* 12 .. 
14 : 0b0101_0000*/ 0x50, + 0x50, + 0x50, + /* 15 : 0b0101_0100*/ 0x54}; + + const uint8x8x2_t bitposLUT = { + 0x01,0x02,0x04,0x08,0x10,0x20,0x40,0x80, + 0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00}; + + const uint8x8_t sh = vtbl2_u8(shiftLUT, higher_nibble); + const uint8x8_t eq_2f = vceq_u8(input, packed_byte(0x2f)); + const uint8x8_t shift = vbsl_u8(eq_2f, packed_byte(16), sh); + + const uint8x8_t M = vtbl2_u8(maskLUT, lower_nibble); + const uint8x8_t bit = vtbl2_u8(bitposLUT, higher_nibble); + + error = vceq_u8(vand_u8(M, bit), packed_byte(0)); + + const uint8x8_t result = vadd_u8(input, shift); + + return result; +} + +inline bool isBase64Terminating(const uint8x8_t field) { + static const uint8x8_t delimPattern = + vdup_n_u8(apache::thrift::detail::json::kJSONStringDelimiter); + static const uint8x8_t padPattern = vdup_n_u8('='); + + const uint8x8_t delimMatch = vceq_u8(field, delimPattern); + const uint8x8_t padMatch = vceq_u8(field, padPattern); + + return vget_lane_u64(vreinterpret_u64_u8(delimMatch), 0) != 0 || + vget_lane_u64(vreinterpret_u64_u8(padMatch), 0) != 0; +} + +static const uint8_t kBase64DecodeTable[256] = { + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x3e, 0xff, 0xff, 0xff, 0x3f, + 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, + 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, + 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, + 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, + 0x31, 0x32, 0x33, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 
0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+    0xff, 0xff, 0xff, 0xff,
+};
+
+// Decodes the base64 payload of a JSON string into s. Whole groups of four
+// characters are decoded straight from the peeked input (32 at a time with
+// NEON, then 4 at a time); the tail and any '=' padding are read bytewise.
+inline void JSONProtocolReaderCommon::readJSONBase64Neon(folly::io::QueueAppender& s) {
+  ensureCharNoWhitespace(apache::thrift::detail::json::kJSONStringDelimiter);
+  for (auto peek = in_.peekBytes(); !peek.empty(); peek = in_.peekBytes()) {
+    size_t skipped = 0;
+    size_t i = 0;
+    for (; i + 4 * 8 - 1 < peek.size(); i += 4 * 8) {
+      const uint8x8x4_t in = vld4_u8((const uint8_t*)&peek[0] + i);
+      s.ensure(3 * 8);
+
+      if (isBase64Terminating(in.val[0]) || isBase64Terminating(in.val[1]) ||
+          isBase64Terminating(in.val[2]) || isBase64Terminating(in.val[3])) {
+        if (skipped > 0)
+          in_.skip(skipped);
+        goto remainder;
+      }
+      uint8x8_t error_a;
+      uint8x8_t error_b;
+      uint8x8_t error_c;
+      uint8x8_t error_d;
+      uint8x8_t field_a = lookup_pshufb_bitmask(in.val[0], error_a);
+      uint8x8_t field_b = lookup_pshufb_bitmask(in.val[1], error_b);
+      uint8x8_t field_c = lookup_pshufb_bitmask(in.val[2], error_c);
+      uint8x8_t field_d = lookup_pshufb_bitmask(in.val[3], error_d);
+
+      const uint8x8_t error =
+          vorr_u8(vorr_u8(error_a, error_b), vorr_u8(error_c, error_d));
+
+      const uint64_t scalarError = vget_lane_u64(vreinterpret_u64_u8(error), 0);
+      if (scalarError) {
+        throw TProtocolException(
+            TProtocolException::INVALID_DATA, "Invalid base64 character");
+      }
+
+      uint8x8x3_t result;
+      result.val[0] = vorr_u8(vshr_n_u8(field_b, 4), vshl_n_u8(field_a, 2));
+      result.val[1] = vorr_u8(vshr_n_u8(field_c, 2), vshl_n_u8(field_b, 4));
+      result.val[2] = vorr_u8(field_d, vshl_n_u8(field_c, 6));
+
+      vst3_u8(s.writableData(), result);
+      s.append(3 * 8);
+      skipped += 4 * 8;
+    }
+    for (; i + 3 < peek.size(); i += 4) {
+      if (peek[i] == apache::thrift::detail::json::kJSONStringDelimiter ||
+          peek[i] == '=' ||
+          peek[i + 1] == apache::thrift::detail::json::kJSONStringDelimiter ||
+          peek[i + 1] == '=' ||
+          peek[i + 2] == apache::thrift::detail::json::kJSONStringDelimiter ||
+          peek[i + 2] == '=' ||
+          peek[i + 3] == apache::thrift::detail::json::kJSONStringDelimiter ||
+          peek[i + 3] == '=') {
+        if (skipped > 0) {
+          in_.skip(skipped);
+        }
+        goto remainder;
+      }
+      // NOTE(review): no invalid-character check here, unlike the vector loop.
+      s.ensure(3);
+      *s.writableData() = (kBase64DecodeTable[peek[i]] << 2) |
+          (kBase64DecodeTable[peek[i + 1]] >> 4);
+      *(s.writableData() + 1) =
+          ((kBase64DecodeTable[peek[i + 1]] << 4) & 0xf0) |
+          (kBase64DecodeTable[peek[i + 2]] >> 2);
+      *(s.writableData() + 2) =
+          ((kBase64DecodeTable[peek[i + 2]] << 6) & 0xc0) |
+          (kBase64DecodeTable[peek[i + 3]]);
+      s.append(3);
+      skipped += 4;
+    }
+    if (skipped == 0)
+      break;
+    in_.skip(skipped);
+  }
+
+remainder:
+  uint8_t input[4];
+  uint8_t inputBufSize = 0;
+  // remainder
+  bool paddingReached{false};
+  while (true) {
+    const auto c = in_.read<uint8_t>();
+    if (c == apache::thrift::detail::json::kJSONStringDelimiter) {
+      break;
+    }
+    if (paddingReached) {
+      continue;
+    }
+    if (c == '=') {
+      paddingReached = true;
+      continue;
+    }
+    input[inputBufSize++] = c;
+    if (inputBufSize == 4) {
+      base64_decode(input, 4);
+      input[3] = '\0';
+      s.push(input, 3);
+      inputBufSize = 0;
+    }
+  }
+
+  if (inputBufSize > 1) {
+    base64_decode(input, inputBufSize);
+    input[inputBufSize - 1] = '\0';
+    s.push(input, inputBufSize - 1);
+  }
+}
+#endif // __ARM_NEON
+} // namespace thrift
+} // namespace apache
\ No newline at end of file
diff --git
a/thrift/lib/cpp2/protocol/JSONProtocolCommon-inl.h b/thrift/lib/cpp2/protocol/JSONProtocolCommon-inl.h
index c5f26338803..4a74ffb7a7a 100644
--- a/thrift/lib/cpp2/protocol/JSONProtocolCommon-inl.h
+++ b/thrift/lib/cpp2/protocol/JSONProtocolCommon-inl.h
@@ -1,5 +1,6 @@
 /*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,7 +17,13 @@

 #pragma once

+#include
+#include
+#include
 #include
+#ifdef __ARM_NEON
+#include <arm_neon.h>
+#endif /* __ARM_NEON */

 namespace apache {
 namespace thrift {
@@ -261,8 +268,124 @@ inline uint32_t JSONProtocolWriterCommon::writeJSONEscapeChar(uint8_t ch) {
   return 6;
 }

+// Writes one character of a JSON string, escaping it if necessary, and
+// returns the number of bytes written. Two-byte escapes are emitted as one
+// uint16_t with the backslash in the low byte.
+inline uint32_t JSONProtocolWriterCommon::writeJSONChar(uint8_t ch) {
+  if (ch >= 32) {
+    // The only special characters >= 32 are '"' and '\'
+    if (ch == apache::thrift::detail::json::kJSONStringDelimiter) {
+      constexpr uint16_t res = apache::thrift::detail::json::kJSONBackslash |
+          ((uint16_t)apache::thrift::detail::json::kJSONStringDelimiter << 8);
+      out_.write(res);
+      return 2;
+    } else if (ch == apache::thrift::detail::json::kJSONBackslash) {
+      constexpr uint16_t res = apache::thrift::detail::json::kJSONBackslash |
+          ((uint16_t)apache::thrift::detail::json::kJSONBackslash << 8);
+      out_.write(res);
+      return 2;
+    } else {
+      out_.write(ch);
+      return 1;
+    }
+  } else {
+    uint8_t outCh = kJSONCharTable[ch];
+    // Check if regular character, backslash escaped, or JSON escaped
+    if (outCh != 0) {
+      uint16_t res{0};
+      res |= apache::thrift::detail::json::kJSONBackslash;
+      res |= ((uint16_t)outCh << 8);
+      out_.write(res);
+      return 2;
+    } else {
+      return writeJSONEscapeChar(ch);
+    }
+  }
+}
+
 inline uint32_t JSONProtocolWriterCommon::writeJSONString(
     folly::StringPiece str) {
+#ifdef __ARM_NEON
+  // Plain runs are copied 16 or 8 bytes at a time; only chunks flagged by the
+  // NEON compare go through the per-character escaping of writeJSONChar().
+  uint32_t ret = 2;
+  if (str.empty()) {
+    // for an empty string
+    constexpr uint16_t res =
+        apache::thrift::detail::json::kJSONStringDelimiter |
+        ((uint16_t)apache::thrift::detail::json::kJSONStringDelimiter << 8);
+    out_.write(res);
+    return ret;
+  }
+  out_.write(apache::thrift::detail::json::kJSONStringDelimiter);
+  size_t i = 0;
+
+  static const uint8x16_t backslashx16 = vdupq_n_u8('\\');
+  static const uint8x16_t specialCharsx16 = vdupq_n_u8(0x30);
+  static const uint8x8_t backslashx8 = vdup_n_u8('\\');
+  static const uint8x8_t specialCharsx8 = vdup_n_u8(0x30);
+  for (; i + 15 < str.size(); i += 16) {
+    // load 16 bytes per chunk, if available
+    uint8x16_t val = vld1q_u8((uint8_t*)&str[i]);
+
+    uint8x16_t lteMask = vcleq_u8(val, specialCharsx16); // non-zero for every char that is <= 0x30 ('0')
+    uint8x16_t backslashMask = vceqq_u8(val, backslashx16); // non-zero for every char that equals '\\'
+    uint8x16_t mask = vorrq_u8(lteMask, backslashMask); // non-zero for every char that requires escape check
+    uint64_t lowEscaped = vgetq_lane_u64(vreinterpretq_u64_u8(mask), 0); // if true, any char in the lower half (byte 0..7) needs escape check
+    uint64_t highEscaped = vgetq_lane_u64(vreinterpretq_u64_u8(mask), 1); // if true, any char in the upper half (byte 8..15) needs escape check
+    if (FOLLY_UNLIKELY(lowEscaped || highEscaped)) {
+
+      if (lowEscaped) {
+        for (size_t j = i; j < i + 8; ++j) {
+          ret += writeJSONChar(str[j]);
+        }
+      } else {
+        out_.push((const uint8_t*)&val, sizeof(uint8x8_t));
+        ret += 8;
+      }
+
+      if (highEscaped) {
+        for (size_t j = i + 8; j < i + 16; ++j) {
+          ret += writeJSONChar(str[j]);
+        }
+      } else {
+        out_.push((const uint8_t*)&val + sizeof(uint8x8_t), sizeof(uint8x8_t));
+        ret += 8;
+      }
+
+    } else {
+      out_.push((const uint8_t*)&val, sizeof(uint8x16_t));
+      ret += 16;
+    }
+  } // end 16 byte per iteration loop
+
+  for (; i + 7 < str.size(); i += 8) {
+    // load 8 bytes per chunk, if available
+    uint8x8_t val = vld1_u8((uint8_t*)&str[i]);
+
+    uint8x8_t lteMask = vcle_u8(val, specialCharsx8); // non-zero for every char that is <= 0x30 ('0')
+    uint8x8_t backslashMask = vceq_u8(val, backslashx8); // non-zero for every char that equals '\\'
+    uint8x8_t mask = vorr_u8(lteMask, backslashMask); // non-zero for every char that requires escape check
+    uint64_t escaped = vget_lane_u64(vreinterpret_u64_u8(mask), 0); // if true, any char needs escape check
+    if (FOLLY_UNLIKELY(escaped)) {
+      // Escape the whole chunk; writeJSONChar() emits plain characters as-is
+      // and its return value keeps ret in sync with the bytes written.
+      for (size_t j = i; j < i + 8; ++j) {
+        ret += writeJSONChar(str[j]);
+      }
+    } else {
+      out_.push((const uint8_t*)&val, sizeof(uint8x8_t));
+      ret += 8;
+    }
+  } // end 8 byte per iteration loop
+
+  // remainder loop
+#pragma unroll
+  for (; i < str.size(); ++i) {
+    ret += writeJSONChar(str[i]);
+  }
+  out_.write(apache::thrift::detail::json::kJSONStringDelimiter);
+#else // __ARM_NEON
   out_.write(apache::thrift::detail::json::kJSONStringDelimiter);
   uint32_t ret = 2;
   for (uint8_t ch : str) {
@@ -288,34 +411,93 @@ inline uint32_t JSONProtocolWriterCommon::writeJSONString(
     }
   }
   out_.write(apache::thrift::detail::json::kJSONStringDelimiter);
+#endif
+  return ret;
+}
+
+// Combines two 32-bit values into one 64-bit value; i0 ends up in the low half.
+inline uint64_t pack_uint64(const uint32_t i0, const uint32_t i1) {
+  uint64_t ret{0};
+  ret |= ((uint64_t)i1) << 32;
+  ret |= i0;
+  return ret;
+}
+// Packs four bytes into a 32-bit value; c0 ends up in the lowest byte.
+inline uint32_t pack_uint32(
+    const uint8_t c0, const uint8_t c1, const uint8_t c2, const uint8_t c3) {
+  uint32_t ret{0};
+  // Widen before shifting: c3 << 24 in int would overflow for c3 >= 0x80.
+  ret |= static_cast<uint32_t>(c3) << 24;
+  ret |= static_cast<uint32_t>(c2) << 16;
+  ret |= static_cast<uint32_t>(c1) << 8;
+  ret |= c0;
   return ret;
 }

+// Packs two bytes into a 16-bit value; c0 ends up in the low byte.
+inline uint16_t pack_uint16(const uint8_t c0, const uint8_t c1) {
+  uint16_t ret{0};
+  ret |= c1 << 8;
+  ret |= c0;
+  return ret;
+}
+
+// The standard base64 alphabet, indexed by 6-bit value.
+static const uint8_t* kBase64EncodeTable =
+    reinterpret_cast<const uint8_t*>("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/");
+
+// Encodes three input bytes as four base64 characters, packed into a uint32_t
+// with the first character in the lowest byte.
+inline uint32_t base64_encode_3_inline(
+    const uint8_t
in0, const uint8_t in1, const uint8_t in2) {
+  return pack_uint32(
+      // 6 bits of in0
+      kBase64EncodeTable[(in0 >> 2) & 0x3f],
+      // 2 bits of in0 and 4 bits of in1
+      kBase64EncodeTable[((in0 << 4) & 0x30) | ((in1 >> 4) & 0x0f)],
+      // 4 bits of in1 and 2 bits of in2
+      kBase64EncodeTable[((in1 << 2) & 0x3c) | ((in2 >> 6) & 0x03)],
+      // 6 bits of in2
+      kBase64EncodeTable[in2 & 0x3f]);
+}
+
 inline uint32_t JSONProtocolWriterCommon::writeJSONBase64(folly::ByteRange v) {
   uint32_t ret = 2;
   out_.write(apache::thrift::detail::json::kJSONStringDelimiter);
   auto bytes = v.data();
   uint32_t len = folly::to_narrow(v.size());
-  uint8_t b[4];
+  while (len >= 6) {
+    // encode 6 bytes at a time
+    out_.write(pack_uint64(
+        base64_encode_3_inline(bytes[0], bytes[1], bytes[2]),
+        base64_encode_3_inline(bytes[3], bytes[4], bytes[5])));
+    ret += 8;
+    bytes += 6;
+    len -= 6;
+  }
   while (len >= 3) {
     // Encode 3 bytes at a time
-    base64_encode(bytes, 3, b);
-    for (int i = 0; i < 4; i++) {
-      out_.write(b[i]);
-    }
+    out_.write(base64_encode_3_inline(bytes[0], bytes[1], bytes[2]));
     ret += 4;
     bytes += 3;
     len -= 3;
   }
-  if (len) { // Handle remainder
+  if (len == 2) {
     DCHECK_LE(len, folly::to_unsigned(std::numeric_limits::max()));
-    base64_encode(bytes, folly::to_narrow(len), b);
-    for (uint32_t i = 0; i < len + 1; i++) {
-      out_.write(b[i]);
-    }
-    ret += len + 1;
+    out_.write(pack_uint32(
+        kBase64EncodeTable[(bytes[0] >> 2) & 0x3f],
+        kBase64EncodeTable[((bytes[0] << 4) & 0x30) | ((bytes[1] >> 4) & 0x0f)],
+        kBase64EncodeTable[(bytes[1] << 2) & 0x3c],
+        apache::thrift::detail::json::kJSONStringDelimiter));
+    return ret + 3;
+  } else if (len == 1) {
+    DCHECK_LE(len, folly::to_unsigned(std::numeric_limits::max()));
+    out_.write(pack_uint16(
+        kBase64EncodeTable[(bytes[0] >> 2) & 0x3f],
+        kBase64EncodeTable[(bytes[0] << 4) & 0x30]));
+    ret += 2;
   }

   out_.write(apache::thrift::detail::json::kJSONStringDelimiter);
@@ -406,19 +581,24 @@ inline void JSONProtocolReaderCommon::readBinary(StrType& str) {

 inline void JSONProtocolReaderCommon::readBinary(
     std::unique_ptr<folly::IOBuf>& str) {
-  std::string tmp;
+  folly::IOBufQueue queue;
+  folly::io::QueueAppender a(&queue, 1000);
   bool keyish;
   ensureAndReadContext(keyish);
-  readJSONBase64(tmp);
-  str = folly::IOBuf::copyBuffer(tmp);
+  readJSONBase64(a);
+  // IOBufQueue::move() yields nullptr if nothing was appended (empty value),
+  // so hand out an empty buffer in that case instead of a null pointer.
+  str = queue.empty() ? folly::IOBuf::create(0) : queue.move();
 }

 inline void JSONProtocolReaderCommon::readBinary(folly::IOBuf& str) {
-  std::string tmp;
+  folly::IOBufQueue queue;
+  folly::io::QueueAppender a(&queue, 1000);
   bool keyish;
   ensureAndReadContext(keyish);
-  readJSONBase64(tmp);
-  str.appendChain(folly::IOBuf::copyBuffer(tmp));
+  readJSONBase64(a);
+  // Never pass appendChain() the nullptr that move() returns for an empty queue.
+  str.appendChain(queue.empty() ? folly::IOBuf::create(0) : queue.move());
 }

 /**
@@ -687,13 +867,363 @@ inline void JSONProtocolReaderCommon::readJSONEscapeChar(uint8_t& out) {
   out = static_cast<uint8_t>((hexVal(b1) << 4) + hexVal(b2));
 }

+#ifdef __ARM_NEON
+// Returns the numeric value of the hex digit ch; any other character is
+// rejected instead of being silently decoded as 0.
+static inline uint8_t hexVal_inl(uint8_t ch) {
+  if ((ch >= '0') && (ch <= '9')) {
+    return ch - '0';
+  } else if ((ch >= 'a') && (ch <= 'f')) {
+    return ch - 'a' + 10;
+  } else if ((ch >= 'A') && (ch <= 'F')) {
+    return ch - 'A' + 10;
+  } else {
+    throw TProtocolException(
+        TProtocolException::INVALID_DATA, "Invalid hex digit in escape");
+  }
+}
+
+// Returns the next input byte. While inIdx lies inside the peeked range the
+// byte comes from there and is counted in skipped; afterwards the cursor is
+// advanced past the counted bytes and read directly.
+inline uint8_t readOrFallback(
+    const folly::ByteRange& input,
+    folly::io::Cursor& fallBackInput,
+    unsigned& inIdx,
+    unsigned& skipped) {
+  if (inIdx < input.size()) {
+    ++skipped;
+    return input[inIdx++];
+  } else {
+    if (skipped > 0) {
+      fallBackInput.skip(skipped);
+      skipped = 0;
+    }
+    return fallBackInput.read<uint8_t>();
+  }
+}
+
+// Maps the character following a backslash to the character it stands for;
+// throws TProtocolException for an unknown escape.
+inline char lookupJSONEscapeChar(char ch) {
+  switch (ch) {
+    case '"': [[fallthrough]];
+    case '\\': [[fallthrough]];
+    case '/':
+      return ch;
+    case 'b':
+      return '\b';
+    case 'f':
+      return '\f';
+    case 'n':
+      return '\n';
+    case 'r':
+      return '\r';
+    case 't':
+      return '\t';
+    default:
+      throw TProtocolException(
+          TProtocolException::INVALID_DATA,
+          "Invalid escaped char " + std::to_string(ch));
+  }
+}
+
+// Combines four hex digits (c0 is the most significant) into one UTF-16
+// code unit.
+inline char16_t hexToChar16(uint8_t c0, uint8_t c1, uint8_t c2, uint8_t c3) {
+  uint16_t result = 0;
+
+  // Convert each character to its numerical equivalent
+
result |= hexVal_inl(c0) << 12; + result |= hexVal_inl(c1) << 8; + result |= hexVal_inl(c2) << 4; + result |= hexVal_inl(c3); + + return result; +} + +template +inline void appendChar32ToString(StrType& s, char32_t wc) { + if (wc < 0x80) { + s += static_cast(wc); + } else if (wc < 0x800) { + s += static_cast((wc >> 6) | 0xC0); + s += static_cast((wc & 0x3F) | 0x80); + } else if (wc < 0x10000) { + s += static_cast((wc >> 12) | 0xE0); + s += static_cast(((wc >> 6) & 0x3F) | 0x80); + s += static_cast((wc & 0x3F) | 0x80); + } else { + s += static_cast((wc >> 18) | 0xF0); + s += static_cast(((wc >> 12) & 0x3F) | 0x80); + s += static_cast(((wc >> 6) & 0x3F) | 0x80); + s += static_cast((wc & 0x3F) | 0x80); + } +} + +template +inline bool decodeJSONStringSequentially( + const folly::ByteRange& input, + folly::io::Cursor& fallBackInput, + StrType& output, + unsigned& inIdx, + unsigned& skipped, + unsigned limit, + char16_t& highSurrogate) { + unsigned t = inIdx + limit; + while ((limit == 0 || inIdx < t) && inIdx < input.size()) { + uint8_t ch = readOrFallback(input, fallBackInput, inIdx, skipped); + if (ch == '"') { + return true; + } + if (ch == '\\') { + ch = readOrFallback(input, fallBackInput, inIdx, skipped); + if (ch == 'u') { + // skip first to chars (expected to be zero) and parse hex chars into + // single char + const uint8_t c0 = readOrFallback(input, fallBackInput, inIdx, skipped); + if constexpr (!allowDecodeUTF8) { + if (c0 != '0') { + throw TProtocolException( + TProtocolException::INVALID_DATA, + "Expected ASCII char but got unicode char"); + } + } + const uint8_t c1 = readOrFallback(input, fallBackInput, inIdx, skipped); + if constexpr (!allowDecodeUTF8) { + if (c1 != '0') { + throw TProtocolException( + TProtocolException::INVALID_DATA, + "Expected ASCII char but got unicode char"); + } + } + const uint8_t c2 = readOrFallback(input, fallBackInput, inIdx, skipped); + const uint8_t c3 = readOrFallback(input, fallBackInput, inIdx, skipped); + // do not 
inline this lookup into the function call, evaluation order of + // args is undefined + const char16_t ch1 = hexToChar16(c0, c1, c2, c3); + + if constexpr (allowDecodeUTF8) { + if (highSurrogate != 0 && + folly::utf16_code_unit_is_low_surrogate(ch1)) { + appendChar32ToString( + output, + folly::unicode_code_point_from_utf16_surrogate_pair( + highSurrogate, ch1)); + highSurrogate = 0; + } else if ( + folly::utf16_code_unit_is_high_surrogate(ch1) && + inIdx < input.size()) { + highSurrogate = ch1; + } else { + appendChar32ToString(output, ch1); + } + } else { + appendChar32ToString(output, ch1); + } + continue; + } else { + ch = lookupJSONEscapeChar(ch); + } + } + if constexpr (allowDecodeUTF8) { + if (highSurrogate != 0) { + appendChar32ToString(output, highSurrogate); + highSurrogate = 0; + } + } + output += (char)ch; + } + return false; +} + +template +inline void decodeJSONStringRemainder( + folly::io::Cursor& fallBackInput, + StrType& output, + char16_t& highSurrogate) { + while (true) { + auto ch = fallBackInput.read(); + if (ch == '"') { + return; + } + if (ch == '\\') { + ch = fallBackInput.read(); + if (ch == 'u') { + auto peek = fallBackInput.peek(); + if (peek.size() < 4) { + throw TProtocolException( + TProtocolException::INVALID_DATA, + "Read beginning of escaped unicode char, expected unicode hex but reached end of stream instead"); + } + if constexpr (!allowDecodeUTF8) { + if (peek[0] != '0' || peek[1] != '0') { + throw TProtocolException( + TProtocolException::INVALID_DATA, + "Expected ASCII char but got unicode char"); + } + } + char16_t ch1 = hexToChar16(peek[0], peek[1], peek[2], peek[3]); + fallBackInput.skip(4); + + if constexpr (allowDecodeUTF8) { + if (highSurrogate != 0 && + folly::utf16_code_unit_is_low_surrogate(ch1)) { + appendChar32ToString( + output, + folly::unicode_code_point_from_utf16_surrogate_pair( + highSurrogate, ch1)); + highSurrogate = 0; + } else if (folly::utf16_code_unit_is_high_surrogate(ch1)) { + if (highSurrogate != 0) { + 
appendChar32ToString(output, highSurrogate); + } + highSurrogate = ch1; + } else { + appendChar32ToString(output, ch1); + } + } else { + appendChar32ToString(output, ch1); + } + continue; + } else { + ch = lookupJSONEscapeChar(ch); + } + } + if constexpr (allowDecodeUTF8) { + if (highSurrogate != 0) { + appendChar32ToString(output, highSurrogate); + highSurrogate = 0; + } + } + output += (char)ch; + } +} + +template +bool decodeJSONStringPeek( + const folly::ByteRange& input, + folly::io::Cursor& fallBackInput, + StrType& output, + bool& dataConsumed, + char16_t& highSurrogate) { + if (input.empty()) { + throw TProtocolException( + TProtocolException::INVALID_DATA, + "Got empty input stream while decoding JSON String"); + } + + unsigned int i = 0; + unsigned int skipped = 0; + + static const uint8x16_t endMask = vdupq_n_u8('"'); + static const uint8x16_t escapeMask = vdupq_n_u8('\\'); + uint8_t buf[16] ; + + bool stringTerminated{false}; + + while (i + 15 < input.size()) { + unsigned int j = 0; + // load 16 consecutive chars into a vector register + uint8x16_t val = vld1q_u8((uint8_t*)&input[i]); + // check each char for equality to termination character + uint8x16_t endRes = vceqq_u8(val, endMask); // holds 0 or 1 for each char + // check each char for equality to the escape character + uint8x16_t escRes = vceqq_u8(val, escapeMask); + + uint8x16_t res = vorrq_u8(endRes, escRes); + uint64x2_t cmp = vreinterpretq_u64_u8(res); + + + bool firstHalfRequiresSeqProcessing = + vgetq_lane_u64(cmp, 0) != 0; + bool secondHalfRequiresSeqProcessing = + vgetq_lane_u64(cmp, 1) != 0; + if constexpr (allowDecodeUTF8) { + if (highSurrogate != 0 && !firstHalfRequiresSeqProcessing) { + appendChar32ToString(output, highSurrogate); + highSurrogate = 0; + } + } + if (FOLLY_UNLIKELY(firstHalfRequiresSeqProcessing || secondHalfRequiresSeqProcessing)) { + // the 16 char contains a JSON escape or termination character - we cannot + // copy the input + // If the first lane of the uint64x2_t 
is zero, the first 8 chars do not + // contain a termination + uint8_t limit = 16; + if (!firstHalfRequiresSeqProcessing) { + // Only the second half requires special processing, we can trivially + // copy the first half and process the second half sequentially + vst1_u8(buf, vget_low_u8(val)); + output.append((const char*)buf, 8); + i += 8; + skipped += 8; + limit = 8; + } + stringTerminated |= decodeJSONStringSequentially( + input, fallBackInput, output, i, skipped, limit, highSurrogate); + if (stringTerminated) { + goto done; + } + } else { + // no escaping required. just copy the input + + /* + * vst1q_u8(buf, val); + * output.append((char*)&buf, 16); + */ + output.append((char*)&val, 16); + + i += 16; + skipped += 16; + } + } + +done: + if (skipped > 0) { + fallBackInput.skip(skipped); + skipped = 0; + dataConsumed = true; + } else { + dataConsumed = false; + } + return stringTerminated; +} + +template +inline void readJSONStringNeon(folly::io::Cursor& in_, StrType& out) { + try { + bool dataConsumed{true}; + char16_t highSurrogate{0}; + for (auto peek = in_.peekBytes(); !peek.empty() && dataConsumed; + peek = in_.peekBytes()) { + if (decodeJSONStringPeek( + peek, in_, out, dataConsumed, highSurrogate)) { + return; + } + } + decodeJSONStringRemainder(in_, out, highSurrogate); + } catch (std::out_of_range& ex) { + throw TProtocolException( + TProtocolException::INVALID_DATA, + "Reached preliminary end of input stream"); + } +} +#endif // __ARM_NEON + template inline void JSONProtocolReaderCommon::readJSONString(StrType& val) { ensureChar(apache::thrift::detail::json::kJSONStringDelimiter); - + val.clear(); +#ifdef __ARM_NEON + if (allowDecodeUTF8_) { + readJSONStringNeon(in_, val); + } else { + readJSONStringNeon(in_, val); + } +#else // __ARM_NEON std::string json = "\""; bool fullDecodeRequired = false; - val.clear(); while (true) { auto ch = in_.read(); if (ch == apache::thrift::detail::json::kJSONStringDelimiter) { @@ -740,16 +1257,19 @@ inline void 
JSONProtocolReaderCommon::readJSONString(StrType& val) { throwUnrecognizableAsString(json, e); } } +#endif // __ARM_NEON } -template -inline void JSONProtocolReaderCommon::readJSONBase64(StrType& str) { +inline void JSONProtocolReaderCommon::readJSONBase64( + folly::io::QueueAppender& s) { +#ifdef __ARM_NEON + readJSONBase64Neon(s); +#else // __ARM_NEON std::string tmp; readJSONString(tmp); uint8_t* b = (uint8_t*)tmp.c_str(); uint32_t len = folly::to_narrow(tmp.length()); - str.clear(); // Allow optional trailing '=' as padding while (len > 0 && b[len - 1] == '=') { @@ -758,7 +1278,7 @@ inline void JSONProtocolReaderCommon::readJSONBase64(StrType& str) { while (len >= 4) { base64_decode(b, 4); - str.append((const char*)b, 3); + s.push(b, 3); b += 4; len -= 4; } @@ -766,8 +1286,9 @@ inline void JSONProtocolReaderCommon::readJSONBase64(StrType& str) { // base64 but legal for skip of regular string type) if (len > 1) { base64_decode(b, len); - str.append((const char*)b, len - 1); + s.push(b, len - 1); } +#endif } // Return the integer value of a hex character ch. 
@@ -809,4 +1330,4 @@ inline int8_t JSONProtocolReaderCommon::peekCharSafe() {
 }

 } // namespace thrift
-} // namespace apache
+} // namespace apache
\ No newline at end of file
diff --git a/thrift/lib/cpp2/protocol/JSONProtocolCommon.h b/thrift/lib/cpp2/protocol/JSONProtocolCommon.h
index c42391db45d..c87e4981f4b 100644
--- a/thrift/lib/cpp2/protocol/JSONProtocolCommon.h
+++ b/thrift/lib/cpp2/protocol/JSONProtocolCommon.h
@@ -255,8 +255,26 @@ class JSONProtocolReaderCommon : public detail::ProtocolBase {
   void readJSONEscapeChar(uint8_t& out);
   template <typename StrType>
   void readJSONString(StrType& val);
+  void readJSONBase64(folly::io::QueueAppender& s);
+#ifdef __ARM_NEON
+  void readJSONBase64Neon(folly::io::QueueAppender& s);
+#endif // __ARM_NEON
+  // Decodes a base64 JSON string and stores the decoded bytes in s, replacing
+  // its previous content.
   template <typename StrType>
-  void readJSONBase64(StrType& s);
+  inline void readJSONBase64(StrType& s) {
+    folly::IOBufQueue queue;
+    folly::io::QueueAppender a(&queue, 1000);
+    readJSONBase64(a);
+    s.clear();
+    // The queue may hold a chain of several buffers (or none at all for an
+    // empty value), so copy every element instead of only the head buffer.
+    if (auto decoded = queue.move()) {
+      for (const auto& piece : *decoded) {
+        s.append(reinterpret_cast<const char*>(piece.data()), piece.size());
+      }
+    }
+  }

   // This string's characters must match up with the elements in kEscapeCharVals
   // I don't have '/' on this list even though it appears on www.json.org --
@@ -318,3 +336,4 @@ class JSONProtocolReaderCommon : public detail::ProtocolBase {
 } // namespace apache::thrift

 #include <thrift/lib/cpp2/protocol/JSONProtocolCommon-inl.h>
+#include <thrift/lib/cpp2/protocol/JSONProtocolCommon-ext-inl.h>
diff --git a/thrift/lib/cpp2/protocol/test/JSONProtocolTest.cpp b/thrift/lib/cpp2/protocol/test/JSONProtocolTest.cpp
index e8aaeea3129..3b0b04fd4e4 100644
--- a/thrift/lib/cpp2/protocol/test/JSONProtocolTest.cpp
+++ b/thrift/lib/cpp2/protocol/test/JSONProtocolTest.cpp
@@ -1,5 +1,6 @@
 /*
  * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -191,12 +192,49 @@ TEST_F(JSONProtocolTest, writeFloat) { } } +TEST_F(JSONProtocolTest, writeEmptyString) { + EXPECT_EQ("\"\"",writing_cpp2([](W& p) { p.writeString(""); })); +} + + TEST_F(JSONProtocolTest, writeString) { auto expected = R"("foobar")"; EXPECT_EQ( expected, writing_cpp2([](W& p) { p.writeString(string("foobar")); })); EXPECT_EQ(expected, writing_cpp2([](W& p) { p.writeString("foobar"); })); } +TEST_F(JSONProtocolTest, writeStringLastEscaped) { + for (int i = 1; i <= 64; ++i) { + std::string input = std::string(i - 1, 'a') + '\\'; + std::string expected = "\"" + std::string(i - 1, 'a') + "\\\\\""; + EXPECT_EQ(expected, writing_cpp2([input](W& p) { p.writeString(input); })) + << "Failed to handle escape at index " << i; + } +} +TEST_F(JSONProtocolTest, writeStringFullyEscaped) { + for (int i = 1; i <= 64; ++i) { + std::string input = std::string(i, '\\'); + std::string expected = "\"" + input + input + "\""; + EXPECT_EQ(expected, writing_cpp2([input](W& p) { p.writeString(input); })) + << "Failed to handle fully escaped string of length " << i; + } +} +TEST_F(JSONProtocolTest, writeStringPairEscaped) { + std::string input = "aaaa\\aaabbb\\bbbb"; + std::string expected = "\"aaaa\\\\aaabbb\\\\bbbb\""; + EXPECT_EQ(expected, writing_cpp2([input](W& p) { p.writeString(input); })); +} +TEST_F(JSONProtocolTest, writeString_large) { + auto expected = R"("eigccbbunjfhnhgehbljlrvjjicebnievrnedgvhrhit")"; + EXPECT_EQ( + expected, writing_cpp2([](W& p) { + p.writeString(string("eigccbbunjfhnhgehbljlrvjjicebnievrnedgvhrhit")); + }) + ); + EXPECT_EQ(expected, writing_cpp2([](W& p) { + p.writeString("eigccbbunjfhnhgehbljlrvjjicebnievrnedgvhrhit"); + })); +} TEST_F(JSONProtocolTest, writeBinary) { auto expected = R"("Zm9vYmFy")"; @@ -210,6 +248,23 @@ TEST_F(JSONProtocolTest, writeBinary) { expected, writing_cpp2([](W& p) { p.writeBinary(*IOBuf::wrapBuffer(ByteRange(StringPiece("foobar")))); })); + + + // Test different remainder lengths + EXPECT_EQ( + 
"\"TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNldGV0dXIgc2FkaXBzY2luZyBlbGl0cg\"", + writing_cpp2([](W& p) { p.writeBinary(string("Lorem ipsum dolor sit amet, consetetur sadipscing elitr"));}) + ); + EXPECT_EQ("\"TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNldGV0dXIgc2FkaXBzY2luZyBlbGl0\"", + writing_cpp2([](W& p) { p.writeBinary(string("Lorem ipsum dolor sit amet, consetetur sadipscing elit")); }) + ); + EXPECT_EQ("\"TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNldGV0dXIgc2FkaXBzY2luZyBlbGk\"", + writing_cpp2([](W& p) { p.writeBinary(string("Lorem ipsum dolor sit amet, consetetur sadipscing eli")); }) + ); + EXPECT_EQ("\"Zm9vYg\"", writing_cpp2([](W& p) { p.writeBinary(string("foob")); })); + EXPECT_EQ("\"Zm9v\"", writing_cpp2([](W& p) { p.writeBinary(string("foo")); })); + EXPECT_EQ("\"Zm8\"", writing_cpp2([](W& p) { p.writeBinary(string("fo")); })); + EXPECT_EQ("\"Zg\"", writing_cpp2([](W& p) { p.writeBinary(string("f")); })); } TEST_F(JSONProtocolTest, writeMessage) { @@ -447,38 +502,96 @@ TEST_F(JSONProtocolTest, readFloat_numeric_limits) { })); } +#define STR_TEST_CASE(expected, input) \ + EXPECT_EQ(expected, reading_cpp2(input, [](R& p) { \ + return returning([&](string& _) { p.readString(_); }); })); + +#define STR_TEST_CASE_INVALID(input) \ + EXPECT_ANY_THROW(reading_cpp2(input, [](R& p) { \ + return returning([&](string& _) { p.readString(_); }); })); + +#define STR_TEST_CASE_NO_UTF8(expected, input) \ + EXPECT_EQ(expected, reading_cpp2(input, [](R& p) { \ + p.setAllowDecodeUTF8(false); \ + return returning([&](string& _) { p.readString(_); }); })); + +TEST_F(JSONProtocolTest, readStringEmpty) { + STR_TEST_CASE("", "\"\""); +} + +TEST_F(JSONProtocolTest, readStringTerminated) { + STR_TEST_CASE("abcdefgh\\asddffff", "\"abcdefgh\\\\asddffff\""); +} TEST_F(JSONProtocolTest, readString) { - auto input = R"("foobar")"; - auto expected = "foobar"; - EXPECT_EQ(expected, reading_cpp2(input, [](R& p) { - return returning([&](string& _) { p.readString(_); }); - })); + 
STR_TEST_CASE("foobar", "\"foobar\""); + STR_TEST_CASE("0123456789abcdef", "\"0123456789abcdef\""); + STR_TEST_CASE("0123456789abcdefm", "\"0123456789abcdefm\""); + STR_TEST_CASE("0123456789abcdef0123456789abcde", "\"0123456789abcdef0123456789abcde\""); + STR_TEST_CASE("0123456789abcdef0123456789abcdef", "\"0123456789abcdef0123456789abcdef\""); + STR_TEST_CASE("0123456789abcdef0123456789abcdefm", "\"0123456789abcdef0123456789abcdefm\""); + STR_TEST_CASE("0123456789abcdef012345\\789abcdef", "\"0123456789abcdef012345\\\\789abcdef\""); + STR_TEST_CASE("foobar\u263A", "\"foobar\\u263A\""); + STR_TEST_CASE("foobaf\U0007263A", "\"foobaf\\uD989\\uDE3A\""); } TEST_F(JSONProtocolTest, readString_raw) { auto input = R"("\u0019\u0002\u0000\u0000")"; auto expected = string("\x19\x02\x00\x00", 4); - EXPECT_EQ(expected, reading_cpp2(input, [](R& p) { - p.setAllowDecodeUTF8(false); - return returning([&](string& _) { p.readString(_); }); - })); - EXPECT_EQ(expected, reading_cpp2(input, [](R& p) { - return returning([&](string& _) { p.readString(_); }); - })); + STR_TEST_CASE_NO_UTF8(expected, input); + STR_TEST_CASE(expected, input); } TEST_F(JSONProtocolTest, readString_utf8) { - auto input = R"("\u263A")"; - EXPECT_EQ("\u263A", reading_cpp2(input, [](R& p) { - return returning([&](string& _) { p.readString(_); }); - })); + STR_TEST_CASE("\u263A", R"("\u263A")"); + STR_TEST_CASE("\u263A\u263A", R"("\u263A\u263A")"); // do not falsely treat as surrogate pair +} +TEST_F(JSONProtocolTest, readString_utf8_offset) { + STR_TEST_CASE("01234567890\u263A", R"("01234567890\u263A")"); + STR_TEST_CASE("01234567890abc\u263A", R"("01234567890abc\u263A")"); + STR_TEST_CASE("01234567890abcdef\u263A", R"("01234567890abcdef\u263A")"); + STR_TEST_CASE("01234567890abcdefg\u263A", R"("01234567890abcdefg\u263A")"); + STR_TEST_CASE("\u263A01234567890", R"("\u263A01234567890")"); } TEST_F(JSONProtocolTest, readString_utf8_surrogate_pair) { - auto input = R"("\uD989\uDE3A")"; - EXPECT_EQ("\U0007263A", 
reading_cpp2(input, [](R& p) { - return returning([&](string& _) { p.readString(_); }); - })); + STR_TEST_CASE("\U0007263A", R"("\uD989\uDE3A")"); + // surrogate pair on edge between 16-byte chunk and remainder + STR_TEST_CASE("0123456789\U0007263A", R"("0123456789\uD989\uDE3A")"); + STR_TEST_CASE("0123456789a\U0007263A", R"("0123456789a\uD989\uDE3A")"); + STR_TEST_CASE("0123456789abcdef\U0007263A0123456789abcdef", R"("0123456789abcdef\uD989\uDE3A0123456789abcdef")"); +} +TEST_F(JSONProtocolTest, readString_invalid_incompleteUnicode) { + STR_TEST_CASE_INVALID(R"("\u")"); +} +TEST_F(JSONProtocolTest, readString_invalid_incompleteSurrogatePair) { + STR_TEST_CASE_INVALID(R"("\uD989\uDE3")"); +} +TEST_F(JSONProtocolTest, readString_invalid_incompleteEscape) { + STR_TEST_CASE_INVALID(R"("\")"); +} +TEST_F(JSONProtocolTest, readString_invalid_invalidEscape) { + STR_TEST_CASE_INVALID(R"("\L")"); +} +TEST_F(JSONProtocolTest, readString_invalid_unterminated) { + STR_TEST_CASE_INVALID("\""); +} + +TEST_F(JSONProtocolTest, readBinary_large) { + + auto input = R"("TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNldGV0dXIgc2FkaXBzY2luZyBlbGl0ciwgc2VkIGRpYW0gbm9udW15IGVpcm1vZCB0ZW1wb3IgaW52aWR1bnQgdXQgbA")"; + auto expected = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor invidunt ut l"; + EXPECT_EQ(expected, reading_cpp2(input, [](R& p) { + return returning([&](string& _) { p.readBinary(_); }); + })); +} + +TEST_F(JSONProtocolTest, readBinary_large_remainder) { + + auto input = R"("TG9yZW0gaXBzdW0gZG9sb3I============")"; + auto expected = "Lorem ipsum dolor"; + EXPECT_EQ(expected, reading_cpp2(input, [](R& p) { + return returning([&](string& _) { p.readBinary(_); }); + })); } TEST_F(JSONProtocolTest, readBinary) { From a866b18b846235556d67872572d32fbf846a8a43 Mon Sep 17 00:00:00 2001 From: Lukas Alt Date: Thu, 19 Dec 2024 18:19:00 +0000 Subject: [PATCH 2/2] Fixed x86 compilation issue --- thrift/lib/cpp2/protocol/JSONProtocolCommon-inl.h | 1 - 
1 file changed, 1 deletion(-) diff --git a/thrift/lib/cpp2/protocol/JSONProtocolCommon-inl.h b/thrift/lib/cpp2/protocol/JSONProtocolCommon-inl.h index 4a74ffb7a7a..d66a73e1e31 100644 --- a/thrift/lib/cpp2/protocol/JSONProtocolCommon-inl.h +++ b/thrift/lib/cpp2/protocol/JSONProtocolCommon-inl.h @@ -387,7 +387,6 @@ inline uint32_t JSONProtocolWriterCommon::writeJSONString( out_.write(apache::thrift::detail::json::kJSONStringDelimiter); #else // __ARM_NEON out_.write(apache::thrift::detail::json::kJSONStringDelimiter); - uint32_t ret = 2; for (uint8_t ch : str) { // Only special characters >= 32 are '\' and '"' if (ch == apache::thrift::detail::json::kJSONBackslash ||