From fa9df458bd62310bd24b28863416cb4e7fc7b97b Mon Sep 17 00:00:00 2001 From: Halulu Date: Mon, 8 Jan 2024 21:01:02 +0800 Subject: [PATCH 01/23] rebase --- libOTe/Tools/SilentPprf.h | 4 + libOTe/Tools/Subfield/ExConvCode.h | 631 +++++++ libOTe/Tools/Subfield/Expander.h | 499 ++++++ libOTe/Tools/Subfield/Subfield.h | 231 +++ libOTe/Tools/Subfield/SubfieldPprf.h | 1444 +++++++++++++++++ .../SoftSpokenOT/SoftSpokenMalOtExt.cpp | 2 +- libOTe/Vole/Subfield/NoisyVoleReceiver.h | 105 ++ libOTe/Vole/Subfield/NoisyVoleSender.h | 97 ++ libOTe/Vole/Subfield/SilentVoleReceiver.h | 788 +++++++++ libOTe/Vole/Subfield/SilentVoleSender.h | 480 ++++++ libOTe_Tests/Subfield_Test.h | 13 + libOTe_Tests/Subfield_Tests.cpp | 462 ++++++ libOTe_Tests/UnitTests.cpp | 6 + 13 files changed, 4761 insertions(+), 1 deletion(-) create mode 100644 libOTe/Tools/Subfield/ExConvCode.h create mode 100644 libOTe/Tools/Subfield/Expander.h create mode 100644 libOTe/Tools/Subfield/Subfield.h create mode 100644 libOTe/Tools/Subfield/SubfieldPprf.h create mode 100644 libOTe/Vole/Subfield/NoisyVoleReceiver.h create mode 100644 libOTe/Vole/Subfield/NoisyVoleSender.h create mode 100644 libOTe/Vole/Subfield/SilentVoleReceiver.h create mode 100644 libOTe/Vole/Subfield/SilentVoleSender.h create mode 100644 libOTe_Tests/Subfield_Test.h create mode 100644 libOTe_Tests/Subfield_Tests.cpp diff --git a/libOTe/Tools/SilentPprf.h b/libOTe/Tools/SilentPprf.h index b5ffbd4d..6a619171 100644 --- a/libOTe/Tools/SilentPprf.h +++ b/libOTe/Tools/SilentPprf.h @@ -72,6 +72,10 @@ namespace osuCrypto //MaliciousFS }; + u64 interleavedPoint(u64 point, u64 treeIdx, u64 totalTrees, u64 domain, PprfOutputFormat format); + void interleavedPoints(span points, u64 domain, PprfOutputFormat format); + u64 getActivePath(const span& choiceBits); + struct TreeAllocator { TreeAllocator() = default; diff --git a/libOTe/Tools/Subfield/ExConvCode.h b/libOTe/Tools/Subfield/ExConvCode.h new file mode 100644 index 00000000..98ba3140 --- /dev/null +++ b/libOTe/Tools/Subfield/ExConvCode.h @@ -0,0 +1,631 @@ +// � 2023 Visa. +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +#pragma once + +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/Timer.h" +#include "libOTe/Tools/Subfield/Expander.h" +#include "libOTe/Tools/EACode/Util.h" + +namespace osuCrypto::Subfield +{ + + // The encoder for the generator matrix G = B * A. dualEncode(...) is the main function + // config(...) should be called first. + // + // B is the expander while A is the convolution. + // + // B has mMessageSize rows and mCodeSize columns. It is sampled uniformly + // with fixed row weight mExpanderWeight. + // + // A is a lower triangular n by n matrix with ones on the diagonal. The + // mAccumulatorSize diagonals left of the main diagonal are uniformly random. + // If mStickyAccumulator, then the first diagonal left of the main is always ones. + // + // See ExConvCodeInstantiations.cpp for how to instantiate new types that + // dualEncode can be called on. + // + // https://eprint.iacr.org/2023/882 + + template + class ExConvCode : public TimerAdapter + { + public: + ExpanderCode mExpander; + + // configure the code. The default parameters are choses to balance security and performance. + // For additional parameter choices see the paper. + void config( + u64 messageSize, + u64 codeSize = 0 /*2 * messageSize is default */, + u64 expanderWeight = 7, + u64 accumulatorSize = 16, + bool systematic = true, + block seed = block(99999, 88888)) + { + if (codeSize == 0) + codeSize = 2 * messageSize; + + if (accumulatorSize % 8) + throw std::runtime_error("ExConvCode accumulator size must be a multiple of 8." LOCATION); + + mSeed = seed; + mMessageSize = messageSize; + mCodeSize = codeSize; + mAccumulatorSize = accumulatorSize; + mSystematic = systematic; + mExpander.config(messageSize, codeSize - messageSize * systematic, expanderWeight, seed ^ CCBlock); + } + + // the seed that generates the code. + block mSeed = ZeroBlock; + + // The message size of the code. K. + u64 mMessageSize = 0; + + // The codeword size of the code. n. + u64 mCodeSize = 0; + + // The size of the accumulator. + u64 mAccumulatorSize = 0; + + // is the code systematic (true=faster) + bool mSystematic = true; + + // return n-k. code size n, message size k. + u64 parityRows() const { return mCodeSize - mMessageSize; } + + // return code size n. + u64 parityCols() const { return mCodeSize; } + + // return message size k. + u64 generatorRows() const { return mMessageSize; } + + // return code size n. + u64 generatorCols() const { return mCodeSize; } + + // Compute w = G * e. e will be modified in the computation. + template + void dualEncode(span e, span w) + { + if (e.size() != mCodeSize) + throw RTE_LOC; + + if (w.size() != mMessageSize) + throw RTE_LOC; + + if (mSystematic) + { + dualEncode(e); + memcpy(w.data(), e.data(), w.size() * sizeof(T)); + setTimePoint("ExConv.encode.memcpy"); + } + else + { + + setTimePoint("ExConv.encode.begin"); + + accumulate(e); + + setTimePoint("ExConv.encode.accumulate"); + + mExpander.template expand(e, w); + setTimePoint("ExConv.encode.expand"); + } + } + + // Compute e[0,...,k-1] = G * e. + template + void dualEncode(span e) + { + if (e.size() != mCodeSize) + throw RTE_LOC; + + if (mSystematic) + { + auto d = e.subspan(mMessageSize); + setTimePoint("ExConv.encode.begin"); + accumulate(d); + setTimePoint("ExConv.encode.accumulate"); + mExpander.template expand(d, e.subspan(0, mMessageSize)); + setTimePoint("ExConv.encode.expand"); + } + else + { + oc::AlignedUnVector w(mMessageSize); + dualEncode(e, w); + memcpy(e.data(), w.data(), w.size() * sizeof(T)); + setTimePoint("ExConv.encode.memcpy"); + + } + } + + + // Compute e[0,...,k-1] = G * e. + template + void dualEncode2(span e0, span e1) + { + if (e0.size() != mCodeSize) + throw RTE_LOC; + if (e1.size() != mCodeSize) + throw RTE_LOC; + + if (mSystematic) + { + auto d0 = e0.subspan(mMessageSize); + auto d1 = e1.subspan(mMessageSize); + setTimePoint("ExConv.encode.begin"); + accumulate(d0, d1); + setTimePoint("ExConv.encode.accumulate"); + mExpander.template expand( + d0, d1, + e0.subspan(0, mMessageSize), + e1.subspan(0, mMessageSize)); + setTimePoint("ExConv.encode.expand"); + } + else + { + //oc::AlignedUnVector w0(mMessageSize); + //dualEncode(e, w); + //memcpy(e.data(), w.data(), w.size() * sizeof(T)); + //setTimePoint("ExConv.encode.memcpy"); + + // not impl. + throw RTE_LOC; + + } + } + + // get the expander matrix + SparseMtx getB() const + { + if (mSystematic) + { + PointList R(mMessageSize, mCodeSize); + auto B = mExpander.getB().points(); + + for (auto p : B) + { + R.push_back(p.mRow, mMessageSize + p.mCol); + } + for (u64 i = 0; i < mMessageSize; ++i) + R.push_back(i, i); + + return R; + } + else + { + return mExpander.getB(); + } + + } + + // Get the parity check version of the accumulator + SparseMtx getAPar() const + { + PRNG prng(mSeed ^ OneBlock); + + auto n = mCodeSize - mSystematic * mMessageSize; + + PointList AP(n, n);; + DenseMtx A = DenseMtx::Identity(n); + + block rnd; + u8* __restrict ptr = (u8*)prng.mBuffer.data(); + auto qe = prng.mBuffer.size() * 128; + u64 q = 0; + + for (u64 i = 0; i < n; ++i) + { + accOne(AP, i, ptr, prng, rnd, q, qe, n); + } + return AP; + } + + // get the accumulator matrix + SparseMtx getA() const + { + auto APar = getAPar(); + + auto A = DenseMtx::Identity(mCodeSize); + + u64 offset = mSystematic ? mMessageSize : 0ull; + + for (u64 i = 0; i < APar.rows(); ++i) + { + for (auto y : APar.col(i)) + { + //std::cout << y << " "; + if (y != i) + { + auto ay = A.row(y + offset); + auto ai = A.row(i + offset); + ay ^= ai; + } + } + + //std::cout << "\n" << A << std::endl; + } + + return A.sparse(); + } + + // Private functions ------------------------------------ + + inline static void refill(PRNG& prng) + { + assert(prng.mBuffer.size() == 256); + //block b[8]; + for (u64 i = 0; i < 256; i += 8) + { + //auto idx = mPrng.mBuffer[i].get(); + block* __restrict b = prng.mBuffer.data() + i; + block* __restrict k = prng.mBuffer.data() + (u8)(i - 8); + //for (u64 j = 0; j < 8; ++j) + //{ + // b = b ^ mPrng.mBuffer.data()[idx[j]]; + //} + b[0] = AES::roundEnc(b[0], k[0]); + b[1] = AES::roundEnc(b[1], k[1]); + b[2] = AES::roundEnc(b[2], k[2]); + b[3] = AES::roundEnc(b[3], k[3]); + b[4] = AES::roundEnc(b[4], k[4]); + b[5] = AES::roundEnc(b[5], k[5]); + b[6] = AES::roundEnc(b[6], k[6]); + b[7] = AES::roundEnc(b[7], k[7]); + + b[0] = b[0] ^ k[0]; + b[1] = b[1] ^ k[1]; + b[2] = b[2] ^ k[2]; + b[3] = b[3] ^ k[3]; + b[4] = b[4] ^ k[4]; + b[5] = b[5] ^ k[5]; + b[6] = b[6] ^ k[6]; + b[7] = b[7] ^ k[7]; + } + } + + // generate the point list for accumulating row i. + void accOne( + PointList& pl, + u64 i, + u8* __restrict& ptr, + PRNG& prng, + block& rnd, + u64& q, + u64 qe, + u64 size) const + { + u64 j = i + 1; + pl.push_back(i, i); + + if (q + mAccumulatorSize > qe) + { + refill(prng); + ptr = (u8*)prng.mBuffer.data(); + q = 0; + } + + + for (u64 k = 0; k < mAccumulatorSize; k += 8, q += 8, j += 8) + { + assert(ptr < (u8*)(prng.mBuffer.data() + prng.mBuffer.size())); + rnd = block::allSame(*ptr); + ++ptr; + + //std::cout << "r " << rnd << std::endl; + auto b0 = rnd; + auto b1 = rnd.slli_epi32<1>(); + auto b2 = rnd.slli_epi32<2>(); + auto b3 = rnd.slli_epi32<3>(); + auto b4 = rnd.slli_epi32<4>(); + auto b5 = rnd.slli_epi32<5>(); + auto b6 = rnd.slli_epi32<6>(); + auto b7 = rnd.slli_epi32<7>(); + //rnd = rnd.mm_slli_epi32<8>(); + + if (j + 0 < size && b0.get(0) < 0) pl.push_back(j + 0, i); + if (j + 1 < size && b1.get(0) < 0) pl.push_back(j + 1, i); + if (j + 2 < size && b2.get(0) < 0) pl.push_back(j + 2, i); + if (j + 3 < size && b3.get(0) < 0) pl.push_back(j + 3, i); + if (j + 4 < size && b4.get(0) < 0) pl.push_back(j + 4, i); + if (j + 5 < size && b5.get(0) < 0) pl.push_back(j + 5, i); + if (j + 6 < size && b6.get(0) < 0) pl.push_back(j + 6, i); + if (j + 7 < size && b7.get(0) < 0) pl.push_back(j + 7, i); + } + + + //if (mWrapping) + { + if (j < size) + pl.push_back(j, i); + ++j; + } + + } + +#ifdef ENABLE_SSE + + using My__m128 = __m128; + +#else + using My__m128 = block; + + inline My__m128 _mm_load_ps(float* b) { return *(block*)b; } + + // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_blendv_ps&ig_expand=557 + inline My__m128 _mm_blendv_ps(My__m128 a, My__m128 b, My__m128 mask) + { + My__m128 dst; + for (u64 j = 0; j < 4; ++j) + { + if (mask.get(j) < 0) + dst.set(j, b.get(j)); + else + dst.set(j, a.get(j)); + } + return dst; + } + + + inline My__m128 _mm_setzero_ps() { return ZeroBlock; } +#endif + + template + OC_FORCEINLINE void accOneHelper( + T* __restrict xx, + My__m128 xii, + u64 j, u64 i, u64 size, + block* b + ) + { + My__m128 Zero = _mm_setzero_ps(); + +// if constexpr (std::is_same::value) +// { +// My__m128 bb[8]; +// bb[0] = _mm_load_ps((float*)&b[0]); +// bb[1] = _mm_load_ps((float*)&b[1]); +// bb[2] = _mm_load_ps((float*)&b[2]); +// bb[3] = _mm_load_ps((float*)&b[3]); +// bb[4] = _mm_load_ps((float*)&b[4]); +// bb[5] = _mm_load_ps((float*)&b[5]); +// bb[6] = _mm_load_ps((float*)&b[6]); +// bb[7] = _mm_load_ps((float*)&b[7]); +// +// +// bb[0] = _mm_blendv_ps(Zero, xii, bb[0]); +// bb[1] = _mm_blendv_ps(Zero, xii, bb[1]); +// bb[2] = _mm_blendv_ps(Zero, xii, bb[2]); +// bb[3] = _mm_blendv_ps(Zero, xii, bb[3]); +// bb[4] = _mm_blendv_ps(Zero, xii, bb[4]); +// bb[5] = _mm_blendv_ps(Zero, xii, bb[5]); +// bb[6] = _mm_blendv_ps(Zero, xii, bb[6]); +// bb[7] = _mm_blendv_ps(Zero, xii, bb[7]); +// +// block tt[8]; +// memcpy(tt, bb, 8 * 16); +// +// if (!rangeCheck || j + 0 < size) xx[j + 0] = TypeTrait::plus(xx[j + 0], tt[0]); +// if (!rangeCheck || j + 1 < size) xx[j + 1] = TypeTrait::plus(xx[j + 1], tt[1]); +// if (!rangeCheck || j + 2 < size) xx[j + 2] = TypeTrait::plus(xx[j + 2], tt[2]); +// if (!rangeCheck || j + 3 < size) xx[j + 3] = TypeTrait::plus(xx[j + 3], tt[3]); +// if (!rangeCheck || j + 4 < size) xx[j + 4] = TypeTrait::plus(xx[j + 4], tt[4]); +// if (!rangeCheck || j + 5 < size) xx[j + 5] = TypeTrait::plus(xx[j + 5], tt[5]); +// if (!rangeCheck || j + 6 < size) xx[j + 6] = TypeTrait::plus(xx[j + 6], tt[6]); +// if (!rangeCheck || j + 7 < size) xx[j + 7] = TypeTrait::plus(xx[j + 7], tt[7]); +// } +// else +// { + if ((!rangeCheck || j + 0 < size) && b[0].get(0) < 0) xx[j + 0] = TypeTrait::plus(xx[j + 0], xx[i]); + if ((!rangeCheck || j + 1 < size) && b[1].get(0) < 0) xx[j + 1] = TypeTrait::plus(xx[j + 1], xx[i]); + if ((!rangeCheck || j + 2 < size) && b[2].get(0) < 0) xx[j + 2] = TypeTrait::plus(xx[j + 2], xx[i]); + if ((!rangeCheck || j + 3 < size) && b[3].get(0) < 0) xx[j + 3] = TypeTrait::plus(xx[j + 3], xx[i]); + if ((!rangeCheck || j + 4 < size) && b[4].get(0) < 0) xx[j + 4] = TypeTrait::plus(xx[j + 4], xx[i]); + if ((!rangeCheck || j + 5 < size) && b[5].get(0) < 0) xx[j + 5] = TypeTrait::plus(xx[j + 5], xx[i]); + if ((!rangeCheck || j + 6 < size) && b[6].get(0) < 0) xx[j + 6] = TypeTrait::plus(xx[j + 6], xx[i]); + if ((!rangeCheck || j + 7 < size) && b[7].get(0) < 0) xx[j + 7] = TypeTrait::plus(xx[j + 7], xx[i]); +// } + } + + // accumulating row i. + template + OC_FORCEINLINE void accOne( + T* __restrict xx, + u64 i, + u8*& ptr, + PRNG& prng, + u64& q, + u64 qe, + u64 size) { + u64 j = i + 1; + if (width) { + if (q + width > qe) { + refill(prng); + ptr = (u8*)prng.mBuffer.data(); + q = 0; + + } + q += width; + + for (u64 k = 0; k < width; ++k, j += 8) { + assert(ptr < (u8*)(prng.mBuffer.data() + prng.mBuffer.size())); + block rnd = block::allSame(*(u8*)ptr++); + + block b[8]; + b[0] = rnd; + b[1] = rnd.slli_epi32<1>(); + b[2] = rnd.slli_epi32<2>(); + b[3] = rnd.slli_epi32<3>(); + b[4] = rnd.slli_epi32<4>(); + b[5] = rnd.slli_epi32<5>(); + b[6] = rnd.slli_epi32<6>(); + b[7] = rnd.slli_epi32<7>(); + +// if constexpr (std::is_same::value) { +// accOneHelper(xx, _mm_setzero_ps(), j, i, size, b); +// } +// else { + My__m128 xii;// = ::_mm_set_ps(0.0f, 0.0f, 0.0f, 0.0f); + memset(&xii, 0, sizeof(My__m128)); + accOneHelper(xx, xii, j, i, size, b); +// } + } + } + + if (!rangeCheck || j < size) { + auto xj = TypeTrait::plus(xx[j], xx[i]); + xx[j] = xj; + } + } + + + // accumulating row i. + template + OC_FORCEINLINE void accOne( + T0* __restrict xx0, + T1* __restrict xx1, + u64 i, + u8*& ptr, + PRNG& prng, + u64& q, + u64 qe, + u64 size) + { + u64 j = i + 1; + if (width) + { + + + if (q + width > qe) + { + refill(prng); + ptr = (u8*)prng.mBuffer.data(); + q = 0; + + } + q += width; + + for (u64 k = 0; k < width; ++k, j += 8) + { + assert(ptr < (u8*)(prng.mBuffer.data() + prng.mBuffer.size())); + block rnd = block::allSame(*(u8*)ptr++); + + block b[8]; + b[0] = rnd; + b[1] = rnd.slli_epi32<1>(); + b[2] = rnd.slli_epi32<2>(); + b[3] = rnd.slli_epi32<3>(); + b[4] = rnd.slli_epi32<4>(); + b[5] = rnd.slli_epi32<5>(); + b[6] = rnd.slli_epi32<6>(); + b[7] = rnd.slli_epi32<7>(); + +// if constexpr (std::is_same::value) { +// auto xii0 = _mm_load_ps((float*)(xx0 + i)); +// accOneHelper(xx0, xii0, j, i, size, b); +// } +// else { + accOneHelper(xx0, _mm_setzero_ps(), j, i, size, b); +// } +// if constexpr (std::is_same::value) { +// auto xii1 = _mm_load_ps((float*)(xx1 + i)); +// accOneHelper(xx1, xii1, j, i, size, b); +// } +// else { + accOneHelper(xx1, _mm_setzero_ps(), j, i, size, b); +// } + } + } + + if (!rangeCheck || j < size) + { + xx0[j] = TypeTrait::plus(xx0[j], xx0[i]); + xx1[j] = TypeTrait::plus(xx1[j], xx1[i]); + } + } + + + // accumulate x onto itself. + template + void accumulate(span x) + { + PRNG prng(mSeed ^ OneBlock); + + u64 i = 0; + auto size = x.size(); + auto main = (u64)std::max(0, size - 1 - mAccumulatorSize); + u8* ptr = (u8*)prng.mBuffer.data(); + auto qe = prng.mBuffer.size() * 128 / 8; + u64 q = 0; + T* __restrict xx = x.data(); + + { + +#define CASE(I) case I:\ + for (; i < main; ++i)\ + accOne(xx, i, ptr, prng, q, qe, size);\ + for (; i < size; ++i)\ + accOne(xx, i, ptr, prng, q, qe, size);\ + break + + switch (mAccumulatorSize / 8) + { + CASE(0); + CASE(1); + CASE(2); + CASE(3); + CASE(4); + default: + throw RTE_LOC; + break; + } +#undef CASE + } + } + + + // accumulate x onto itself. + template + void accumulate(span x0, span x1) + { + PRNG prng(mSeed ^ OneBlock); + + u64 i = 0; + auto size = x0.size(); + auto main = (u64)std::max(0, size - 1 - mAccumulatorSize); + u8* ptr = (u8*)prng.mBuffer.data(); + auto qe = prng.mBuffer.size() * 128 / 8; + u64 q = 0; + T0* __restrict xx0 = x0.data(); + T1* __restrict xx1 = x1.data(); + + { + +#define CASE(I) case I:\ + for (; i < main; ++i)\ + accOne(xx0,xx1, i, ptr, prng, q, qe, size);\ + for (; i < size; ++i)\ + accOne(xx0, xx1, i, ptr, prng, q, qe, size);\ + break + + switch (mAccumulatorSize / 8) + { + CASE(0); + CASE(1); + CASE(2); + CASE(3); + CASE(4); + default: + throw RTE_LOC; + break; + } +#undef CASE + } + } + }; +} diff --git a/libOTe/Tools/Subfield/Expander.h b/libOTe/Tools/Subfield/Expander.h new file mode 100644 index 00000000..4f64d559 --- /dev/null +++ b/libOTe/Tools/Subfield/Expander.h @@ -0,0 +1,499 @@ +// � 2023 Peter Rindal. +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +#pragma once + +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/Range.h" +#include "libOTe/Tools/LDPC/Mtx.h" +#include "libOTe/Tools/EACode/Util.h" + +namespace osuCrypto::Subfield +{ + + // The encoder for the expander matrix B. + // B has mMessageSize rows and mCodeSize columns. It is sampled uniformly + // with fixed row weight mExpanderWeight. + template + class ExpanderCode + { + public: + + void config( + u64 messageSize, + u64 codeSize = 0 /* default is 5* messageSize */, + u64 expanderWeight = 21, + block seed = block(33333, 33333)) + { + mMessageSize = messageSize; + mCodeSize = codeSize; + mExpanderWeight = expanderWeight; + mSeed = seed; + + } + + // the seed that generates the code. + block mSeed = block(0, 0); + + // The message size of the code. K. + u64 mMessageSize = 0; + + // The codeword size of the code. n. + u64 mCodeSize = 0; + + // The row weight of the B matrix. + u64 mExpanderWeight = 0; + + u64 parityRows() const { return mCodeSize - mMessageSize; } + u64 parityCols() const { return mCodeSize; } + + u64 generatorRows() const { return mMessageSize; } + u64 generatorCols() const { return mCodeSize; } + + + + template + typename std::enable_if::type + expandOne(const T* __restrict ee, detail::ExpanderModd& prng) const + { + auto r = prng.get(); + return ee[r]; + } + + template + typename std::enable_if<(count == 1)>::type + expandOne( + const T* __restrict ee1, + const T2* __restrict ee2, + T* __restrict y1, + T2* __restrict y2, + detail::ExpanderModd& prng)const + { + auto r = prng.get(); + + if (Add) + { + *y1 = TypeTrait::plus(*y1, ee1[r]); + *y2 = TypeTrait::plus(*y2, ee2[r]); + } + else + { + + *y1 = ee1[r]; + *y2 = ee2[r]; + } + } + + template + OC_FORCEINLINE typename std::enable_if<(count > 1), T>::type + expandOne(const T* __restrict ee, detail::ExpanderModd& prng) const + { + if constexpr (count >= 8) + { + u64 rr[8]; + T w[8]; + rr[0] = prng.get(); + rr[1] = prng.get(); + rr[2] = prng.get(); + rr[3] = prng.get(); + rr[4] = prng.get(); + rr[5] = prng.get(); + rr[6] = prng.get(); + rr[7] = prng.get(); + + w[0] = ee[rr[0]]; + w[1] = ee[rr[1]]; + w[2] = ee[rr[2]]; + w[3] = ee[rr[3]]; + w[4] = ee[rr[4]]; + w[5] = ee[rr[5]]; + w[6] = ee[rr[6]]; + w[7] = ee[rr[7]]; + + auto ww = + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + w[0], + w[1]), + w[2]), + w[3]), + w[4]), + w[5]), + w[6]), + w[7]); + + if constexpr (count > 8) + ww = TypeTrait::plus(ww, expandOne(ee, prng)); + return ww; + } + else + { + + auto r = prng.get(); + auto ww = expandOne(ee, prng); + return TypeTrait::plus(ww, ee[r]); + } + } + + template + OC_FORCEINLINE typename std::enable_if<(count > 1)>::type + expandOne( + const T* __restrict ee1, + const T2* __restrict ee2, + T* __restrict y1, + T2* __restrict y2, + detail::ExpanderModd& prng) const + { + if constexpr (count >= 8) + { + u64 rr[8]; + T w1[8]; + T2 w2[8]; + rr[0] = prng.get(); + rr[1] = prng.get(); + rr[2] = prng.get(); + rr[3] = prng.get(); + rr[4] = prng.get(); + rr[5] = prng.get(); + rr[6] = prng.get(); + rr[7] = prng.get(); + + w1[0] = ee1[rr[0]]; + w1[1] = ee1[rr[1]]; + w1[2] = ee1[rr[2]]; + w1[3] = ee1[rr[3]]; + w1[4] = ee1[rr[4]]; + w1[5] = ee1[rr[5]]; + w1[6] = ee1[rr[6]]; + w1[7] = ee1[rr[7]]; + + w2[0] = ee2[rr[0]]; + w2[1] = ee2[rr[1]]; + w2[2] = ee2[rr[2]]; + w2[3] = ee2[rr[3]]; + w2[4] = ee2[rr[4]]; + w2[5] = ee2[rr[5]]; + w2[6] = ee2[rr[6]]; + w2[7] = ee2[rr[7]]; + + auto ww1 = + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + w1[0], + w1[1]), + w1[2]), + w1[3]), + w1[4]), + w1[5]), + w1[6]), + w1[7]); + auto ww2 = + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + TypeTrait::plus( + w2[0], + w2[1]), + w2[2]), + w2[3]), + w2[4]), + w2[5]), + w2[6]), + w2[7]); + + if constexpr (count > 8) + { + T yy1; + T2 yy2; + expandOne(ee1, ee2, &yy1, &yy2, prng); + ww1 = TypeTrait::plus(ww1, yy1); + ww2 = TypeTrait::plus(ww2, yy2); + } + + if constexpr (Add) + { + *y1 = TypeTrait::plus(*y1, ww1); + *y2 = TypeTrait::plus(*y2, ww2); + } + else + { + *y1 = ww1; + *y2 = ww2; + } + + } + else + { + + auto r = prng.get(); + if constexpr (Add) + { + auto w1 = ee1[r]; + auto w2 = ee2[r]; + expandOne(ee1, ee2, y1, y2, prng); + *y1 = TypeTrait::plus(*y1, w1); + *y2 = TypeTrait::plus(*y2, w2); + + } + else + { + + T yy1; + T2 yy2; + expandOne(ee1, ee2, &yy1, &yy2, prng); + *y1 = TypeTrait::plus(yy1, ee1[r]); + *y2 = TypeTrait::plus(yy2, ee2[r]); + } + } + } + + template + void expand( + span e, + span w) const + { + assert(w.size() == mMessageSize); + assert(e.size() == mCodeSize); + detail::ExpanderModd prng(mSeed, mCodeSize); + + const T* __restrict ee = e.data(); + T* __restrict ww = w.data(); + + auto main = mMessageSize / 8 * 8; + u64 i = 0; + + for (; i < main; i += 8) + { +#define CASE(I) \ + case I:\ + if constexpr(Add)\ + {\ + ww[i + 0] = TypeTrait::plus(ww[i + 0], expandOne(ee, prng));\ + ww[i + 1] = TypeTrait::plus(ww[i + 1], expandOne(ee, prng));\ + ww[i + 2] = TypeTrait::plus(ww[i + 2], expandOne(ee, prng));\ + ww[i + 3] = TypeTrait::plus(ww[i + 3], expandOne(ee, prng));\ + ww[i + 4] = TypeTrait::plus(ww[i + 4], expandOne(ee, prng));\ + ww[i + 5] = TypeTrait::plus(ww[i + 5], expandOne(ee, prng));\ + ww[i + 6] = TypeTrait::plus(ww[i + 6], expandOne(ee, prng));\ + ww[i + 7] = TypeTrait::plus(ww[i + 7], expandOne(ee, prng));\ + }\ + else\ + {\ + ww[i + 0] = expandOne(ee, prng);\ + ww[i + 1] = expandOne(ee, prng);\ + ww[i + 2] = expandOne(ee, prng);\ + ww[i + 3] = expandOne(ee, prng);\ + ww[i + 4] = expandOne(ee, prng);\ + ww[i + 5] = expandOne(ee, prng);\ + ww[i + 6] = expandOne(ee, prng);\ + ww[i + 7] = expandOne(ee, prng);\ + }\ + break + + switch (mExpanderWeight) + { + CASE(5); + CASE(7); + CASE(9); + CASE(11); + CASE(21); + CASE(40); + default: + for (u64 jj = 0; jj < 8; ++jj) + { + auto r = prng.get(); + auto wv = ee[r]; + + for (auto j = 1ull; j < mExpanderWeight; ++j) + { + r = prng.get(); + wv = TypeTrait::plus(wv, ee[r]); + } + if constexpr (Add) + ww[i + jj] = TypeTrait::plus(ww[i + jj], wv); + else + ww[i + jj] = wv; + + } + } +#undef CASE + } + + for (; i < mMessageSize; ++i) + { + auto wv = ee[prng.get()]; + for (auto j = 1ull; j < mExpanderWeight; ++j) + wv = TypeTrait::plus(wv, ee[prng.get()]); + + if constexpr (Add) + ww[i] = TypeTrait::plus(ww[i], wv); + else + ww[i] = wv; + } + } + + template + void expand( + span e1, + span e2, + span w1, + span w2 + ) const + { + assert(w1.size() == mMessageSize); + assert(w2.size() == mMessageSize); + assert(e1.size() == mCodeSize); + assert(e2.size() == mCodeSize); + detail::ExpanderModd prng(mSeed, mCodeSize); + + const T* __restrict ee1 = e1.data(); + const T2* __restrict ee2 = e2.data(); + T* __restrict ww1 = w1.data(); + T2* __restrict ww2 = w2.data(); + + auto main = mMessageSize / 8 * 8; + u64 i = 0; + + for (; i < main; i += 8) + { +#define CASE(I) \ + case I:\ + expandOne(ee1, ee2, &ww1[i + 0], &ww2[i + 0], prng);\ + expandOne(ee1, ee2, &ww1[i + 1], &ww2[i + 1], prng);\ + expandOne(ee1, ee2, &ww1[i + 2], &ww2[i + 2], prng);\ + expandOne(ee1, ee2, &ww1[i + 3], &ww2[i + 3], prng);\ + expandOne(ee1, ee2, &ww1[i + 4], &ww2[i + 4], prng);\ + expandOne(ee1, ee2, &ww1[i + 5], &ww2[i + 5], prng);\ + expandOne(ee1, ee2, &ww1[i + 6], &ww2[i + 6], prng);\ + expandOne(ee1, ee2, &ww1[i + 7], &ww2[i + 7], prng);\ + break + + switch (mExpanderWeight) + { + CASE(5); + CASE(7); + CASE(9); + CASE(11); + CASE(21); + CASE(40); + default: + for (u64 jj = 0; jj < 8; ++jj) + { + auto r = prng.get(); + auto wv1 = ee1[r]; + auto wv2 = ee2[r]; + + for (auto j = 1ull; j < mExpanderWeight; ++j) + { + r = prng.get(); + wv1 = TypeTrait::plus(wv1, ee1[r]); + wv2 = TypeTrait::plus(wv2, ee2[r]); + } + if constexpr (Add) + { + ww1[i + jj] = TypeTrait::plus(ww1[i + jj], wv1); + ww2[i + jj] = TypeTrait::plus(ww2[i + jj], wv2); + } + else + { + + ww1[i + jj] = wv1; + ww2[i + jj] = wv2; + } + } + } +#undef CASE + } + + for (; i < mMessageSize; ++i) + { + auto r = prng.get(); + auto wv1 = ee1[r]; + auto wv2 = ee2[r]; + for (auto j = 1ull; j < mExpanderWeight; ++j) + { + r = prng.get(); + wv1 = TypeTrait::plus(wv1, ee1[r]); + wv2 = TypeTrait::plus(wv2, ee2[r]); + + } + if constexpr (Add) + { + ww1[i] = TypeTrait::plus(ww1[i], wv1); + ww2[i] = TypeTrait::plus(ww2[i], wv2); + } + else + { + ww1[i] = wv1; + ww2[i] = wv2; + } + } + } + + + SparseMtx getB() const + { + //PRNG prng(mSeed); + detail::ExpanderModd prng(mSeed, mCodeSize); + PointList points(mMessageSize, mCodeSize); + + std::vector row(mExpanderWeight); + + { + + for (auto i : rng(mMessageSize)) + { + row[0] = prng.get(); + //points.push_back(i, row[0]); + for (auto j : rng(1, mExpanderWeight)) + { + //do { + row[j] = prng.get(); + //} while + auto iter = std::find(row.data(), row.data() + j, row[j]); + if (iter != row.data() + j) + { + row[j] = ~0ull; + *iter = ~0ull; + } + //throw RTE_LOC; + + } + for (auto j : rng(mExpanderWeight)) + { + + if (row[j] != ~0ull) + { + //std::cout << row[j] << " "; + points.push_back(i, row[j]); + } + else + { + //std::cout << "* "; + } + } + //std::cout << std::endl; + } + } + + return points; + } + + }; +} diff --git a/libOTe/Tools/Subfield/Subfield.h b/libOTe/Tools/Subfield/Subfield.h new file mode 100644 index 00000000..4d7878d8 --- /dev/null +++ b/libOTe/Tools/Subfield/Subfield.h @@ -0,0 +1,231 @@ +#include "libOTe/Vole/Noisy/NoisyVoleSender.h" +#include "cryptoTools/Common/BitIterator.h" +#include "cryptoTools/Common/BitVector.h" + +namespace osuCrypto::Subfield { + + struct F128 { + block b; + F128() = default; + explicit F128(const block& b) : b(b) {} +// OC_FORCEINLINE F128 operator+(const F128& rhs) const { +// F128 ret; +// ret.b = b ^ rhs.b; +// return ret; +// } +// OC_FORCEINLINE F128 operator-(const F128& rhs) const { +// F128 ret; +// ret.b = b ^ rhs.b; +// return ret; +// } +// OC_FORCEINLINE F128 operator*(const F128& rhs) const { +// F128 ret; +// ret.b = b.gf128Mul(rhs.b); +// return ret; +// } +// OC_FORCEINLINE bool operator==(const F128& rhs) const { +// return b == rhs.b; +// } +// OC_FORCEINLINE bool operator!=(const F128& rhs) const { +// return b != rhs.b; +// } + }; + + /* + * Primitive TypeTrait for integers + */ + template + struct TypeTraitPrimitive { + using G = T; + using F = T; + + static constexpr size_t bitsG = sizeof(G) * 8; + static constexpr size_t bitsF = sizeof(F) * 8; + static constexpr size_t bytesG = sizeof(G); + static constexpr size_t bytesF = sizeof(F); + + static OC_FORCEINLINE F plus(const F& lhs, const F& rhs) { + return lhs + rhs; + } + static OC_FORCEINLINE F minus(const F& lhs, const F& rhs) { + return lhs - rhs; + } + static OC_FORCEINLINE F mul(const F& lhs, const F& rhs) { + return lhs * rhs; + } + static OC_FORCEINLINE bool eq(const F& lhs, const F& rhs) { + return lhs == rhs; + } + + static OC_FORCEINLINE BitVector BitVectorF(F& x) { + return {(u8*)&x, bitsF}; + } + + static OC_FORCEINLINE F fromBlock(const block& b) { + return b.get()[0]; + } + static OC_FORCEINLINE F pow(u64 power) { + F ret = 1; + ret <<= power; + return ret; + } + }; + + using TypeTrait64 = TypeTraitPrimitive; + + /* + * TypeTrait for GF(2^128) + */ + struct TypeTraitF128 { + using G = block; + using F = block; + + static constexpr size_t bitsG = sizeof(G) * 8; + static constexpr size_t bitsF = sizeof(F) * 8; + static constexpr size_t bytesG = sizeof(G); + static constexpr size_t bytesF = sizeof(F); + + static OC_FORCEINLINE F plus(const F& lhs, const F& rhs) { + return lhs ^ rhs; + } + static OC_FORCEINLINE F minus(const F& lhs, const F& rhs) { + return lhs ^ rhs; + } + static OC_FORCEINLINE F mul(const F& lhs, const F& rhs) { + return lhs.gf128Mul(rhs); + } + static OC_FORCEINLINE bool eq(const F& lhs, const F& rhs) { + return lhs == rhs; + } + + static OC_FORCEINLINE BitVector BitVectorF(F& x) { + return {(u8*)&x, bitsF}; + } + + static OC_FORCEINLINE F fromBlock(const block& b) { + return b; + } + static OC_FORCEINLINE F pow(u64 power) { + F ret = ZeroBlock; + *BitIterator((u8*)&ret, power) = 1; + return ret; + } + }; + + // array + template + struct Vec { + std::array v; + OC_FORCEINLINE Vec operator+(const Vec& rhs) const { + Vec ret; + for (u64 i = 0; i < N; ++i) { + ret.v[i] = v[i] + rhs.v[i]; + } + return ret; + } + OC_FORCEINLINE Vec operator-(const Vec& rhs) const { + Vec ret; + for (u64 i = 0; i < N; ++i) { + ret.v[i] = v[i] - rhs.v[i]; + } + return ret; + } + OC_FORCEINLINE Vec operator*(const T& rhs) const { + Vec ret; + for (u64 i = 0; i < N; ++i) { + ret.v[i] = v[i] * rhs; + } + return ret; + } + OC_FORCEINLINE T operator[](u64 idx) const { + return v[idx]; + } + OC_FORCEINLINE T& operator[](u64 idx) { + return v[idx]; + } + OC_FORCEINLINE bool operator==(const Vec& rhs) const { + for (u64 i = 0; i < N; ++i) { + if (v[i] != rhs.v[i]) return false; + } + return true; + } + OC_FORCEINLINE bool operator!=(const Vec& rhs) const { + return !(*this == rhs); + } + }; + + // TypeTraitVec for array of integers + template + struct TypeTraitVec { + using G = T; + using F = Vec; + + static constexpr size_t bitsG = sizeof(G) * 8; + static constexpr size_t bitsF = sizeof(F) * 8; + static constexpr size_t bytesG = sizeof(G); + static constexpr size_t bytesF = sizeof(F); + + static constexpr size_t sizeBlocks = (bytesF + sizeof(block) - 1) / sizeof(block); + static constexpr size_t size = N; + + static OC_FORCEINLINE F plus(const F& lhs, const F& rhs) { + F ret; + for (u64 i = 0; i < N; ++i) { + ret.v[i] = lhs.v[i] + rhs.v[i]; + } + return ret; + } + static OC_FORCEINLINE F minus(const F& lhs, const F& rhs) { + F ret; + for (u64 i = 0; i < N; ++i) { + ret.v[i] = lhs.v[i] - rhs.v[i]; + } + return ret; + } + static OC_FORCEINLINE F mul(const F& lhs, const G& rhs) { + F ret; + for (u64 i = 0; i < N; ++i) { + ret.v[i] = lhs.v[i] * rhs; + } + return ret; + } + static OC_FORCEINLINE bool eq(const F& lhs, const F& rhs) { + for (u64 i = 0; i < N; ++i) { + if (lhs.v[i] != rhs.v[i]) return false; + } + return true; + } + static OC_FORCEINLINE G plus(const G& lhs, const G& rhs) { + return lhs + rhs; + } + + static OC_FORCEINLINE BitVector BitVectorF(F& x) { + return {(u8*)&x, bitsF}; + } + + static OC_FORCEINLINE F fromBlock(const block& b) { + F ret; + if (N * sizeof(T) <= sizeof(block)) { + memcpy(ret.v.data(), &b, bytesF); + return ret; + } + else { + std::array buf; + for (u64 i = 0; i < sizeBlocks; ++i) { + buf[i] = b + block(i, i); + } + mAesFixedKey.hashBlocks(buf.data(), buf.data()); + memcpy(&ret, &buf, sizeof(F)); + return ret; + } + } + + static OC_FORCEINLINE F pow(u64 power) { + F ret; + memset(&ret, 0, sizeof(ret)); + *BitIterator((u8*)&ret, power) = 1; + return ret; + } + }; + +} diff --git a/libOTe/Tools/Subfield/SubfieldPprf.h b/libOTe/Tools/Subfield/SubfieldPprf.h new file mode 100644 index 00000000..bcd8bbc4 --- /dev/null +++ b/libOTe/Tools/Subfield/SubfieldPprf.h @@ -0,0 +1,1444 @@ +#pragma once +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/BitVector.h" +#include "cryptoTools/Common/Matrix.h" +#include "cryptoTools/Common/Timer.h" +#include "cryptoTools/Common/Aligned.h" +#include "cryptoTools/Common/Range.h" +#include "cryptoTools/Crypto/PRNG.h" +#include "libOTe/Tools/Coproto.h" +#include "libOTe/Tools/SilentPprf.h" +#include "SubfieldPprf.h" +#include + +namespace osuCrypto::Subfield +{ + template + void copyOut( + span> lvl, + MatrixView output, + u64 totalTrees, + u64 tIdx, + PprfOutputFormat oFormat, + std::function> lvl)>& callback) + { + + if (oFormat == PprfOutputFormat::InterleavedTransposed) + { + // not having an even (8) number of trees is not supported. + if (totalTrees % 8) + throw RTE_LOC; + if (lvl.size() % 16) + throw RTE_LOC; + + // + //auto rowsPer = 16; + //auto step = lvl.size() + + //auto sectionSize = + + if (lvl.size() < 16) + throw RTE_LOC; + + + auto setIdx = tIdx / 8; + auto blocksPerSet = lvl.size() * 8 / 128; + + + + auto numSets = totalTrees / 8; + auto begin = setIdx; + auto step = numSets; + + if (oFormat == PprfOutputFormat::InterleavedTransposed) + { + // todo + throw RTE_LOC; + // auto end = std::min(begin + step * blocksPerSet, output.cols()); + + // for (u64 i = begin, k = 0; i < end; i += step, ++k) + // { + // auto& io = *(std::array*)(&lvl[k * 16]); + // transpose128(io.data()); + // for (u64 j = 0; j < 128; ++j) + // output(j, i) = io[j]; + // } + } + else + { + // no op + } + + + } + else if (oFormat == PprfOutputFormat::Plain) + { + + auto curSize = std::min(totalTrees - tIdx, 8); + if (curSize == 8) + { + + for (u64 i = 0; i < output.rows(); ++i) + { + auto oi = output[i].subspan(tIdx, 8); + auto& ii = lvl[i]; + oi[0] = ii[0]; + oi[1] = ii[1]; + oi[2] = ii[2]; + oi[3] = ii[3]; + oi[4] = ii[4]; + oi[5] = ii[5]; + oi[6] = ii[6]; + oi[7] = ii[7]; + } + } + else + { + for (u64 i = 0; i < output.rows(); ++i) + { + auto oi = output[i].subspan(tIdx, curSize); + auto& ii = lvl[i]; + for (u64 j = 0; j < curSize; ++j) + oi[j] = ii[j]; + } + } + + } + else if (oFormat == PprfOutputFormat::BlockTransposed) + { + + auto curSize = std::min(totalTrees - tIdx, 8); + if (curSize == 8) + { + for (u64 i = 0; i < output.cols(); ++i) + { + auto& ii = lvl[i]; + output(tIdx + 0, i) = ii[0]; + output(tIdx + 1, i) = ii[1]; + output(tIdx + 2, i) = ii[2]; + output(tIdx + 3, i) = ii[3]; + output(tIdx + 4, i) = ii[4]; + output(tIdx + 5, i) = ii[5]; + output(tIdx + 6, i) = ii[6]; + output(tIdx + 7, i) = ii[7]; + } + } + else + { + for (u64 i = 0; i < output.cols(); ++i) + { + auto& ii = lvl[i]; + for (u64 j = 0; j < curSize; ++j) + output(tIdx + j, i) = ii[j]; + } + } + + } + else if (oFormat == PprfOutputFormat::Interleaved) + { + // no op + } + else if (oFormat == PprfOutputFormat::Callback) + callback(tIdx, lvl); + else + throw RTE_LOC; + } + + template + class SilentSubfieldPprfSender : public TimerAdapter + { + public: + using F = typename TypeTrait::F; + u64 mDomain = 0, mDepth = 0, mPntCount = 0; + std::vector mValue; + bool mPrint = false; + TreeAllocator mTreeAlloc; + Matrix> mBaseOTs; + + std::function>)> mOutputFn; + + + SilentSubfieldPprfSender() = default; + SilentSubfieldPprfSender(const SilentSubfieldPprfSender&) = delete; + SilentSubfieldPprfSender(SilentSubfieldPprfSender&&) = delete; + + SilentSubfieldPprfSender(u64 domainSize, u64 pointCount) + { + configure(domainSize, pointCount); + } + + void configure(u64 domainSize, u64 pointCount) + { + mDomain = domainSize; + mDepth = log2ceil(mDomain); + mPntCount = pointCount; + //mPntCount8 = roundUpTo(pointCount, 8); + + mBaseOTs.resize(0, 0); + } + + + // the number of base OTs that should be set. + u64 baseOtCount() const + { + return mDepth * mPntCount; + } + + // returns true if the base OTs are currently set. + bool hasBaseOts() const + { + return mBaseOTs.size(); + } + + + void setBase(span> baseMessages) { + if (baseOtCount() != static_cast(baseMessages.size())) + throw RTE_LOC; + + mBaseOTs.resize(mPntCount, mDepth); + for (u64 i = 0; i < static_cast(mBaseOTs.size()); ++i) + mBaseOTs(i) = baseMessages[i]; + } + + task<> expand(Socket& chls, span value, PRNG& prng, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) + { + MatrixView o(output.data(), output.size(), 1); + return expand(chls, value, prng, o, oFormat, activeChildXorDelta, numThreads); + } + + task<> expand( + Socket& chl, + span value, + PRNG& prng, + MatrixView output, + PprfOutputFormat oFormat, + bool activeChildXorDelta, + u64 numThreads) + { + if (activeChildXorDelta) + setValue(value); + setTimePoint("SilentMultiPprfSender.start"); + //gTimer.setTimePoint("send.enter"); + + + + if (oFormat == PprfOutputFormat::Plain) + { + if (output.rows() != mDomain) + throw RTE_LOC; + + if (output.cols() != mPntCount) + throw RTE_LOC; + } + else if (oFormat == PprfOutputFormat::BlockTransposed) + { + if (output.cols() != mDomain) + throw RTE_LOC; + + if (output.rows() != mPntCount) + throw RTE_LOC; + } + else if (oFormat == PprfOutputFormat::InterleavedTransposed) + { + if (output.rows() != 128) + throw RTE_LOC; + + //if (output.cols() > (mDomain * mPntCount + 127) / 128) + // throw RTE_LOC; + + if (mPntCount & 7) + throw RTE_LOC; + } + else if + (oFormat == PprfOutputFormat::Interleaved) + { + if (output.cols() != 1) + throw RTE_LOC; + if (mDomain & 1) + throw RTE_LOC; + + auto rows = output.rows(); + if (rows > (mDomain * mPntCount) || + rows / 128 != (mDomain * mPntCount) / 128) + throw RTE_LOC; + if (mPntCount & 7) + throw RTE_LOC; + } + else if (oFormat == PprfOutputFormat::Callback) + { + if (mDomain & 1) + throw RTE_LOC; + if (mPntCount & 7) + throw RTE_LOC; + } + else + { + throw RTE_LOC; + } + + + MC_BEGIN(task<>, this, numThreads, oFormat, output, &prng, &chl, activeChildXorDelta, + i = u64{}, + dd = u64{} + ); + + + if (oFormat == PprfOutputFormat::Callback && numThreads > 1) + throw RTE_LOC; + + dd = mDepth + (oFormat == PprfOutputFormat::Interleaved ? 0 : 1); + mTreeAlloc.reserve(numThreads, (1ull << (dd + 1)) + (32 * (dd+1))); + setTimePoint("SilentMultiPprfSender.reserve"); + + mExps.clear(); mExps.reserve(divCeil(mPntCount, 8)); + for (i = 0; i < mPntCount; i += 8) + { + mExps.emplace_back(*this, prng.get(), i, oFormat, output, activeChildXorDelta, chl.fork()); + mExps.back().mFuture = macoro::make_eager(mExps.back().run()); + //MC_AWAIT(mExps.back().run()); + } + + for (i = 0; i < mExps.size(); ++i) + MC_AWAIT(mExps[i].mFuture); + + mExps.clear(); + setTimePoint("SilentMultiPprfSender.join"); + + mBaseOTs = {}; + //mTreeAlloc.clear(); + setTimePoint("SilentMultiPprfSender.de-alloc"); + + MC_END(); + + + } + + void setValue(span value) + { + + mValue.resize(mPntCount); + + if (value.size() == 1) + { + std::fill(mValue.begin(), mValue.end(), value[0]); + } + else + { + if ((u64)value.size() != mPntCount) + throw RTE_LOC; + + std::copy(value.begin(), value.end(), mValue.begin()); + } + } + + void clear() + { + mBaseOTs.resize(0, 0); + mDomain = 0; + mDepth = 0; + mPntCount = 0; + } + + struct Expander + { + SilentSubfieldPprfSender& pprf; + Socket chl; + std::array aes; + PRNG prng; + u64 dd, treeIdx, min, d; + bool mActiveChildXorDelta = true; + + macoro::eager_task mFuture; + std::vector>> mLevels; + + //std::unique_ptr uPtr_; + + // tree will hold the full GGM tree. Note that there are 8 + // indepenendent trees that are being processed together. + // The trees are flattenned to that the children of j are + // located at 2*j and 2*j+1. + span> tree; + + // sums will hold the left and right GGM tree sums + // for each level. For example sums[0][i][5] will + // hold the sum of the left children for level i of + // the 5th tree. + std::array>, 2> sums; + // sums for the last level + std::array, 2> lastSums; + std::vector> lastOts; + + PprfOutputFormat oFormat; + + MatrixView output; + + // The number of real trees for this iteration. + // Returns the i'th level of the current 8 trees. The + // children of node j on level i are located at 2*j and + // 2*j+1 on level i+1. + span> getLevel(u64 i, u64 g) + { + return mLevels[i]; + }; + + span> getLastLevel(u64 i, u64 g) { + if (oFormat == PprfOutputFormat::Interleaved && i == pprf.mDepth) + { + auto b = (AlignedArray*)output.data(); + auto forest = g / 8; + assert(g % 8 == 0); + b += forest * pprf.mDomain; + return span>(b, pprf.mDomain); + } + + throw RTE_LOC; + } + + Expander(SilentSubfieldPprfSender& p, block seed, u64 treeIdx_, + PprfOutputFormat of, MatrixViewo, bool activeChildXorDelta, Socket&& s) + : pprf(p) + , chl(std::move(s)) + , mActiveChildXorDelta(activeChildXorDelta) + { + treeIdx = treeIdx_; + assert((treeIdx & 7) == 0); + output = o; + oFormat = of; + // A public PRF/PRG that we will use for deriving the GGM tree. + aes[0].setKey(toBlock(3242342)); + aes[1].setKey(toBlock(8993849)); + prng.SetSeed(seed); + dd = pprf.mDepth + (oFormat == PprfOutputFormat::Interleaved ? 0 : 1); + } + + task<> run() + { + MC_BEGIN(task<>, this); + + + #ifdef DEBUG_PRINT_PPRF + chl.asyncSendCopy(mValue); + #endif + // pprf.setTimePoint("SilentMultiPprfSender.begin " + std::to_string(treeIdx)); + { + tree = pprf.mTreeAlloc.get(); + assert(tree.size() >= 1ull << (dd)); + assert((u64)tree.data() % 32 == 0); + mLevels.resize(dd+1); + mLevels[0] = tree.subspan(0, 1); + auto rem = tree.subspan(mLevels[0].size()); + for (u64 i = 1; i < dd + 1; i++) + { + while ((u64)rem.data() % 32) + rem = rem.subspan(1); + + mLevels[i] = rem.subspan(0, mLevels[i - 1].size() * 2); + rem = rem.subspan(mLevels[i].size()); + } + } + // pprf.setTimePoint("SilentMultiPprfSender.alloc " + std::to_string(treeIdx)); + + // This thread will process 8 trees at a time. It will interlace + // the sets of trees are processed with the other threads. + { + memset(lastSums[0].data(), 0, lastSums[0].size() * sizeof(F)); + memset(lastSums[1].data(), 0, lastSums[1].size() * sizeof(F)); + + // The number of real trees for this iteration. + min = std::min(8, pprf.mPntCount - treeIdx); + //gTimer.setTimePoint("send.start" + std::to_string(treeIdx)); + + // Populate the zeroth level of the GGM tree with random seeds. + prng.get(getLevel(0, treeIdx)); + + // Allocate space for our sums of each level. + sums[0].resize(pprf.mDepth); + sums[1].resize(pprf.mDepth); + + // For each level perform the following. + for (u64 d = 0; d < pprf.mDepth; ++d) + { + // The previous level of the GGM tree. + auto level0 = getLevel(d, treeIdx); + + // The next level of theGGM tree that we are populating. + auto level1 = getLevel(d + 1, treeIdx); + + // The total number of children in this level. + auto width = static_cast(level1.size()); + + // For each child, populate the child by expanding the parent. + for (u64 childIdx = 0; childIdx < width; ) + { + // Index of the parent in the previous level. + auto parentIdx = childIdx >> 1; + + // The value of the parent. + auto& parent = level0[parentIdx]; + + // The bit that indicates if we are on the left child (0) + // or on the right child (1). + for (u64 keep = 0; keep < 2; ++keep, ++childIdx) + { + // The child that we will write in this iteration. + auto& child = level1[childIdx]; + + // The sum that this child node belongs to. + auto& sum = sums[keep][d]; + + // Each parent is expanded into the left and right children + // using a different AES fixed-key. Therefore our OWF is: + // + // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); + // + // where each half defines one of the children. + aes[keep].template hashBlocks<8>(parent.data(), child.data()); + + if (d < pprf.mDepth - 1) { + // for intermediate levels, same as before + // Update the running sums for this level. We keep + // a left and right totals for each level. + sum[0] = sum[0] ^ child[0]; + sum[1] = sum[1] ^ child[1]; + sum[2] = sum[2] ^ child[2]; + sum[3] = sum[3] ^ child[3]; + sum[4] = sum[4] ^ child[4]; + sum[5] = sum[5] ^ child[5]; + sum[6] = sum[6] ^ child[6]; + sum[7] = sum[7] ^ child[7]; + } else { + if (getLastLevel(pprf.mDepth, treeIdx).size() <= childIdx) { + childIdx = width; + break; + } + auto& realChild = getLastLevel(pprf.mDepth, treeIdx)[childIdx]; + auto& lastSum = lastSums[keep]; + realChild[0] = TypeTrait::fromBlock(child[0]); + lastSum[0] = TypeTrait::plus(lastSum[0], realChild[0]); + realChild[1] = TypeTrait::fromBlock(child[1]); + lastSum[1] = TypeTrait::plus(lastSum[1], realChild[1]); + realChild[2] = TypeTrait::fromBlock(child[2]); + lastSum[2] = TypeTrait::plus(lastSum[2], realChild[2]); + realChild[3] = TypeTrait::fromBlock(child[3]); + lastSum[3] = TypeTrait::plus(lastSum[3], realChild[3]); + realChild[4] = TypeTrait::fromBlock(child[4]); + lastSum[4] = TypeTrait::plus(lastSum[4], realChild[4]); + realChild[5] = TypeTrait::fromBlock(child[5]); + lastSum[5] = TypeTrait::plus(lastSum[5], realChild[5]); + realChild[6] = TypeTrait::fromBlock(child[6]); + lastSum[6] = TypeTrait::plus(lastSum[6], realChild[6]); + realChild[7] = TypeTrait::fromBlock(child[7]); + lastSum[7] = TypeTrait::plus(lastSum[7], realChild[7]); + } + } + } + } + + + #ifdef DEBUG_PRINT_PPRF + // If we are debugging, then send over the full tree + // to make sure its correct on the other side. + chl.asyncSendCopy(tree); + #endif + + // For all but the last level, mask the sums with the + // OT strings and send them over. + for (u64 d = 0; d < pprf.mDepth - mActiveChildXorDelta; ++d) + { + for (u64 j = 0; j < min; ++j) + { + #ifdef DEBUG_PRINT_PPRF + if (mPrint) + { + std::cout << "c[" << treeIdx + j << "][" << d << "][0] " << sums[0][d][j] << " " << mBaseOTs[treeIdx + j][d][0] << std::endl;; + std::cout << "c[" << treeIdx + j << "][" << d << "][1] " << sums[1][d][j] << " " << mBaseOTs[treeIdx + j][d][1] << std::endl;; + } + #endif + sums[0][d][j] = sums[0][d][j] ^ pprf.mBaseOTs[treeIdx + j][d][0]; + sums[1][d][j] = sums[1][d][j] ^ pprf.mBaseOTs[treeIdx + j][d][1]; + } + } + // pprf.setTimePoint("SilentMultiPprfSender.expand " + std::to_string(treeIdx)); + + if (mActiveChildXorDelta) + { + // For the last level, we are going to do something special. + // The other party is currently missing both leaf children of + // the active parent. Since this is the last level, we want + // the inactive child to just be the normal value but the + // active child should be the correct value XOR the delta. + // This will be done by sending the sums and the sums plus + // delta and ensure that they can only decrypt the correct ones. + d = pprf.mDepth - 1; + //std::vector>& lastOts = lastOts; + lastOts.resize(min); + for (u64 j = 0; j < min; ++j) + { + // Construct the sums where we will allow the delta (mValue) + // to either be on the left child or right child depending + // on which has the active path. + lastOts[j][0] = lastSums[0][j]; + lastOts[j][1] = TypeTrait::plus(lastSums[1][j], pprf.mValue[treeIdx + j]); + lastOts[j][2] = lastSums[1][j]; + lastOts[j][3] = TypeTrait::plus(lastSums[0][j], pprf.mValue[treeIdx + j]); + + // We are going to expand the 128 bit OT string + // into a 256 bit OT string using AES. + std::array masks, maskIn; + maskIn[0] = pprf.mBaseOTs[treeIdx + j][d][0]; + maskIn[1] = pprf.mBaseOTs[treeIdx + j][d][0] ^ AllOneBlock; + maskIn[2] = pprf.mBaseOTs[treeIdx + j][d][1]; + maskIn[3] = pprf.mBaseOTs[treeIdx + j][d][1] ^ AllOneBlock; + mAesFixedKey.hashBlocks<4>(maskIn.data(), masks.data()); + + #ifdef DEBUG_PRINT_PPRF + if (mPrint) { + std::cout << "c[" << treeIdx + j << "][" << d << "][0] " << sums[0][d][j] << " " << mBaseOTs[treeIdx + j][d][0] << std::endl;; + std::cout << "c[" << treeIdx + j << "][" << d << "][1] " << sums[1][d][j] << " " << mBaseOTs[treeIdx + j][d][1] << std::endl;; + } + #endif + + // Add the OT masks to the sums and send them over. + lastOts[j][0] = TypeTrait::plus(lastOts[j][0], TypeTrait::fromBlock(masks[0])); + lastOts[j][1] = TypeTrait::plus(lastOts[j][1], TypeTrait::fromBlock(masks[1])); + lastOts[j][2] = TypeTrait::plus(lastOts[j][2], TypeTrait::fromBlock(masks[2])); + lastOts[j][3] = TypeTrait::plus(lastOts[j][3], TypeTrait::fromBlock(masks[3])); + } + + // pprf.setTimePoint("SilentMultiPprfSender.last " + std::to_string(treeIdx)); + + // Resize the sums to that they dont include + // the unmasked sums on the last level! + sums[0].resize(pprf.mDepth - 1); + sums[1].resize(pprf.mDepth - 1); + } + + // Send the sums to the other party. + //sendOne(treeGrp); + //chl.asyncSend(std::move(sums[0])); + //chl.asyncSend(std::move(sums[1])); + + MC_AWAIT(chl.send(std::move(sums[0]))); + MC_AWAIT(chl.send(std::move(sums[1]))); + + if (mActiveChildXorDelta) + MC_AWAIT(chl.send(std::move(lastOts))); + + + //// send the special OT messages for the last level. + //chl.asyncSend(std::move(lastOts)); + //gTimer.setTimePoint("send.expand_send"); + + // copy the last level to the output. If desired, this is + // where the transpose is performed. + auto lvl = getLastLevel(pprf.mDepth, treeIdx); + + // s is a checksum that is used for malicious security. + copyOut(lvl, output, pprf.mPntCount, treeIdx, oFormat, pprf.mOutputFn); + + // pprf.setTimePoint("SilentMultiPprfSender.copyOut " + std::to_string(treeIdx)); + + } + + //uPtr_ = {}; + //tree = {}; + pprf.mTreeAlloc.del(tree); + // pprf.setTimePoint("SilentMultiPprfSender.delete " + std::to_string(treeIdx)); + + MC_END(); + } + }; + + std::vector mExps; + }; + + + template + class SilentSubfieldPprfReceiver : public TimerAdapter + { + public: + using F = typename TypeTrait::F; + u64 mDomain = 0, mDepth = 0, mPntCount = 0; + + std::vector mPoints; + + Matrix mBaseOTs; + Matrix mBaseChoices; + bool mPrint = false; + TreeAllocator mTreeAlloc; + block mDebugValue; + std::function>)> mOutputFn; + std::function fromBlock; + + SilentSubfieldPprfReceiver() = default; + SilentSubfieldPprfReceiver(const SilentSubfieldPprfReceiver&) = delete; + SilentSubfieldPprfReceiver(SilentSubfieldPprfReceiver&&) = delete; + + void configure(u64 domainSize, u64 pointCount) + { + mDomain = domainSize; + mDepth = log2ceil(mDomain); + mPntCount = pointCount; + + mBaseOTs.resize(0, 0); + } + + + // For output format Plain or BlockTransposed, the choice bits it + // samples are in blocks of mDepth, with mPntCount blocks total (one for + // each punctured point). For Plain these blocks encode the punctured + // leaf index in big endian, while for BlockTransposed they are in + // little endian. + BitVector sampleChoiceBits(u64 modulus, PprfOutputFormat format, PRNG& prng) + { + BitVector choices(mPntCount * mDepth); + + // The points are read in blocks of 8, so make sure that there is a + // whole number of blocks. + mBaseChoices.resize(roundUpTo(mPntCount, 8), mDepth); + for (u64 i = 0; i < mPntCount; ++i) + { + u64 idx; + switch (format) + { + case osuCrypto::PprfOutputFormat::Plain: + case osuCrypto::PprfOutputFormat::BlockTransposed: + do { + for (u64 j = 0; j < mDepth; ++j) + mBaseChoices(i, j) = prng.getBit(); + idx = getActivePath(mBaseChoices[i]); + } while (idx >= modulus); + + break; + case osuCrypto::PprfOutputFormat::Interleaved: + case osuCrypto::PprfOutputFormat::InterleavedTransposed: + case osuCrypto::PprfOutputFormat::Callback: + + // make sure that at least the first element of this tree + // is within the modulus. + idx = interleavedPoint(0, i, mPntCount, mDomain, format); + if (idx >= modulus) + throw RTE_LOC; + + + do { + for (u64 j = 0; j < mDepth; ++j) + mBaseChoices(i, j) = prng.getBit(); + idx = getActivePath(mBaseChoices[i]); + + idx = interleavedPoint(idx, i, mPntCount, mDomain, format); + } while (idx >= modulus); + + + break; + default: + throw RTE_LOC; + break; + } + + } + + for (u64 i = 0; i < mBaseChoices.size(); ++i) + { + choices[i] = mBaseChoices(i); + } + + return choices; + } + + // choices is in the same format as the output from sampleChoiceBits. + void setChoiceBits(PprfOutputFormat format, BitVector choices) + { + // Make sure we're given the right number of OTs. + if (choices.size() != baseOtCount()) + throw RTE_LOC; + + mBaseChoices.resize(roundUpTo(mPntCount, 8), mDepth); + for (u64 i = 0; i < mPntCount; ++i) + { + switch (format) + { + case osuCrypto::PprfOutputFormat::Plain: + case osuCrypto::PprfOutputFormat::BlockTransposed: + for (u64 j = 0; j < mDepth; ++j) + mBaseChoices(i, j) = choices[mDepth * i + j]; + break; + + // Not sure what ordering would be good for Interleaved or + // InterleavedTransposed. + + default: + throw RTE_LOC; + break; + } + + if (getActivePath(mBaseChoices[i]) >= mDomain) + throw RTE_LOC; + } + } + + + // the number of base OTs that should be set. + u64 baseOtCount() const + { + return mDepth * mPntCount; + } + + // returns true if the base OTs are currently set. + bool hasBaseOts() const + { + return mBaseOTs.size(); + } + + + void setBase(span baseMessages) + { + if (baseOtCount() != static_cast(baseMessages.size())) + throw RTE_LOC; + + // The OTs are used in blocks of 8, so make sure that there is a whole + // number of blocks. + mBaseOTs.resize(roundUpTo(mPntCount, 8), mDepth); + memcpy(mBaseOTs.data(), baseMessages.data(), baseMessages.size() * sizeof(block)); + } + + std::vector getPoints(PprfOutputFormat format) + { + std::vector pnts(mPntCount); + getPoints(pnts, format); + return pnts; + } + void getPoints(span points, PprfOutputFormat format) + { + switch (format) + { + case PprfOutputFormat::Plain: + case PprfOutputFormat::BlockTransposed: + + memset(points.data(), 0, points.size() * sizeof(u64)); + for (u64 j = 0; j < mPntCount; ++j) + { + points[j] = getActivePath(mBaseChoices[j]); + } + + break; + case PprfOutputFormat::InterleavedTransposed: + case PprfOutputFormat::Interleaved: + case PprfOutputFormat::Callback: + + if ((u64)points.size() != mPntCount) + throw RTE_LOC; + if (points.size() % 8) + throw RTE_LOC; + + getPoints(points, PprfOutputFormat::Plain); + interleavedPoints(points, mDomain, format); + + break; + default: + throw RTE_LOC; + break; + } + } + + task<> expand(Socket& chl, PRNG& prng, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) + { + MatrixView o(output.data(), output.size(), 1); + return expand(chl, prng, o, oFormat, activeChildXorDelta, numThreads); + } + + // activeChildXorDelta says whether the sender is trying to program the + // active child to be its correct value XOR delta. If it is not, the + // active child will just take a random value. + task<> expand(Socket& chl, PRNG& prng, MatrixView output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) + { + setTimePoint("SilentMultiPprfReceiver.start"); + + //lout << " d " << mDomain << " p " << mPntCount << " do " << mDepth << std::endl; + + if (oFormat == PprfOutputFormat::Plain) + { + if (output.rows() != mDomain) + throw RTE_LOC; + + if (output.cols() != mPntCount) + throw RTE_LOC; + } + else if (oFormat == PprfOutputFormat::BlockTransposed) + { + if (output.cols() != mDomain) + throw RTE_LOC; + + if (output.rows() != mPntCount) + throw RTE_LOC; + } + else if (oFormat == PprfOutputFormat::InterleavedTransposed) + { + if (output.rows() != 128) + throw RTE_LOC; + + //if (output.cols() > (mDomain * mPntCount + 127) / 128) + // throw RTE_LOC; + + if (mPntCount & 7) + throw RTE_LOC; + } + else if (oFormat == PprfOutputFormat::Interleaved) + { + if (output.cols() != 1) + throw RTE_LOC; + if (mDomain & 1) + throw RTE_LOC; + auto rows = output.rows(); + if (rows > (mDomain * mPntCount) || + rows / 128 != (mDomain * mPntCount) / 128) + throw RTE_LOC; + if (mPntCount & 7) + throw RTE_LOC; + } + else if (oFormat == PprfOutputFormat::Callback) + { + if (mDomain & 1) + throw RTE_LOC; + if (mPntCount & 7) + throw RTE_LOC; + } + else + { + throw RTE_LOC; + } + + mPoints.resize(roundUpTo(mPntCount, 8)); + getPoints(mPoints, PprfOutputFormat::Plain); + + + MC_BEGIN(task<>, this, numThreads, oFormat, output, &chl, activeChildXorDelta, + i = u64{}, + dd = u64{} + ); + + + dd = mDepth + (oFormat == PprfOutputFormat::Interleaved ? 0 : 1); + mTreeAlloc.reserve(numThreads, (1ull << (dd+1)) + (32 * (dd+1))); + setTimePoint("SilentMultiPprfReceiver.reserve"); + + mExps.clear(); mExps.reserve(divCeil(mPntCount, 8)); + for (i = 0; i < mPntCount; i += 8) + { + mExps.emplace_back(*this, chl.fork(), oFormat, output, activeChildXorDelta, i); + mExps.back().mFuture = macoro::make_eager(mExps.back().run()); + + //MC_AWAIT(mExps.back().run()); + } + + for (i = 0; i < mExps.size(); ++i) + MC_AWAIT(mExps[i].mFuture); + setTimePoint("SilentMultiPprfReceiver.join"); + + mBaseOTs = {}; + setTimePoint("SilentMultiPprfReceiver.de-alloc"); + + MC_END(); + } + + void clear() + { + mBaseOTs.resize(0, 0); + mBaseChoices.resize(0, 0); + mDomain = 0; + mDepth = 0; + mPntCount = 0; + } + + + + struct Expander + { + SilentSubfieldPprfReceiver& pprf; + Socket chl; + + bool mActiveChildXorDelta = false; + std::array aes; + + PprfOutputFormat oFormat; + MatrixView output; + + macoro::eager_task mFuture; + + std::vector>> mLevels; + + // mySums will hold the left and right GGM tree sums + // for each level. For example mySums[5][0] will + // hold the sum of the left children for the 5th tree. This + // sum will be "missing" the children of the active parent. + // The sender will give of one of the full somes so we can + // compute the missing inactive child. + std::array, 2> mySums; + + // sums for the last level + std::array, 2> lastSums; + std::vector> lastOts; + + // A buffer for receiving the sums from the other party. + // These will be masked by the OT strings. + std::array>, 2> theirSums; + + u64 dd, treeIdx; + // tree will hold the full GGM tree. Not that there are 8 + // indepenendent trees that are being processed together. + // The trees are flattenned to that the children of j are + // located at 2*j and 2*j+1. + //std::unique_ptr uPtr_; + span> tree; + + // Returns the i'th level of the current 8 trees. The + // children of node j on level i are located at 2*j and + // 2*j+1 on level i+1. + span> getLevel(u64 i, u64 g, bool f = false) + { + //auto size = (1ull << i); + #ifdef DEBUG_PRINT_PPRF + //auto offset = (size - 1); + //auto b = (f ? ftree.begin() : tree.begin()) + offset; + #else + return mLevels[i]; + #endif + //return span>(b,e); + }; + + span> getLastLevel(u64 i, u64 g, bool f = false) + { + //auto size = (1ull << i); + #ifdef DEBUG_PRINT_PPRF + //auto offset = (size - 1); + //auto b = (f ? ftree.begin() : tree.begin()) + offset; + #else + if (oFormat == PprfOutputFormat::Interleaved && i == pprf.mDepth) + { + auto b = (AlignedArray*)output.data(); + auto forest = g / 8; + assert(g % 8 == 0); + b += forest * pprf.mDomain; + auto zone = span>(b, pprf.mDomain); + return zone; + } + + //assert(tree.size()); + //auto b = tree.begin() + offset; + + throw RTE_LOC; + #endif + //return span>(b,e); + }; + + + Expander(SilentSubfieldPprfReceiver& p, Socket&& s, PprfOutputFormat of, MatrixView o, bool activeChildXorDelta, u64 ti) + : pprf(p) + , chl(std::move(s)) + , mActiveChildXorDelta(activeChildXorDelta) + , oFormat(of) + , output(o) + , treeIdx(ti) + //, threadIdx(tIdx) + { + assert((treeIdx & 7) == 0); + // A public PRF/PRG that we will use for deriving the GGM tree. + aes[0].setKey(toBlock(3242342)); + aes[1].setKey(toBlock(8993849)); + + + theirSums[0].resize(p.mDepth - mActiveChildXorDelta); + theirSums[1].resize(p.mDepth - mActiveChildXorDelta); + + dd = p.mDepth + (oFormat == PprfOutputFormat::Interleaved ? 0 : 1); + + } + task<> run() + { + + MC_BEGIN(task<>, this); + + + { + tree = pprf.mTreeAlloc.get(); + assert(tree.size() >= 1ull << (dd)); + mLevels.resize(dd+1); // todo: last level block are kept + mLevels[0] = tree.subspan(0, 1); + auto rem = tree.subspan(1); + for (u64 i = 1; i < dd + 1; i++) + { + while ((u64)rem.data() % 32) + rem = rem.subspan(1); + + mLevels[i] = rem.subspan(0, mLevels[i - 1].size() * 2); + rem = rem.subspan(mLevels[i].size()); + } + } + + + #ifdef DEBUG_PRINT_PPRF + // This will be the full tree and is sent by the receiver to help debug. + std::vector> ftree(1ull << (mDepth + 1)); + + // The delta value on the active path. + //block deltaValue; + chl.recv(mDebugValue); + #endif + + + + #ifdef DEBUG_PRINT_PPRF + // prints out the contents of the d'th level. + auto printLevel = [&](u64 d) + { + + auto level0 = getLevel(d); + auto flevel0 = getLevel(d, true); + + std::cout + << "---------------------\nlevel " << d + << "\n---------------------" << std::endl; + + std::array sums{ ZeroBlock ,ZeroBlock }; + for (i64 i = 0; i < level0.size(); ++i) + { + for (u64 j = 0; j < 8; ++j) + { + + if (neq(level0[i][j], flevel0[i][j])) + std::cout << Color::Red; + + std::cout << "p[" << i << "][" << j << "] " + << level0[i][j] << " " << flevel0[i][j] << std::endl << Color::Default; + + if (i == 0 && j == 0) + sums[i & 1] = sums[i & 1] ^ flevel0[i][j]; + } + } + + std::cout << "sums[0] = " << sums[0] << " " << sums[1] << std::endl; + }; + #endif + + + // The number of real trees for this iteration. + memset(lastSums[0].data(), 0, lastSums[0].size() * sizeof(F)); + memset(lastSums[1].data(), 0, lastSums[1].size() * sizeof(F)); + memset(mySums[0].data(), 0, mySums[0].size() * sizeof(F)); + memset(mySums[1].data(), 0, mySums[1].size() * sizeof(F)); + lastOts.resize(8); + + // This thread will process 8 trees at a time. It will interlace + // the sets of trees are processed with the other threads. + { + #ifdef DEBUG_PRINT_PPRF + chl.recv(ftree); + auto l1f = getLevel(1, true); + #endif + + //timer.setTimePoint("recv.start" + std::to_string(treeIdx)); + // Receive their full set of sums for these 8 trees. + MC_AWAIT(chl.recv(theirSums[0])); + MC_AWAIT(chl.recv(theirSums[1])); + + if (mActiveChildXorDelta) + MC_AWAIT(chl.recv(lastOts)); + // pprf.setTimePoint("SilentMultiPprfReceiver.recv " + std::to_string(treeIdx)); + + tree = pprf.mTreeAlloc.get(); + assert(tree.size() >= 1ull << (dd)); + assert((u64)tree.data() % 32 == 0); + + // pprf.setTimePoint("SilentMultiPprfReceiver.alloc " + std::to_string(treeIdx)); + + auto l1 = getLevel(1, treeIdx); + + for (u64 i = 0; i < 8; ++i) + { + // For the non-active path, set the child of the root node + // as the OT message XOR'ed with the correction sum. + int notAi = pprf.mBaseChoices[i + treeIdx][0]; + l1[notAi][i] = pprf.mBaseOTs[i + treeIdx][0] ^ theirSums[notAi][0][i]; + l1[notAi ^ 1][i] = ZeroBlock; + + #ifdef DEBUG_PRINT_PPRF + if (neq(l1[notAi][i], l1f[notAi][i])) { + std::cout << "l1[" << notAi << "][" << i << "] " << l1[notAi][i] << " = " + << (mBaseOTs[i + treeIdx][0]) << " ^ " + << theirSums[notAi][0][i] << " vs " << l1f[notAi][i] << std::endl; + } + #endif + } + + #ifdef DEBUG_PRINT_PPRF + if (mPrint) + printLevel(1); + #endif + + // For all other levels, expand the GGM tree and add in + // the correction along the active path. + for (u64 d = 1; d < pprf.mDepth; ++d) + { + // The already constructed level. Only missing the + // GGM tree node value along the active path. + auto level0 = getLevel(d, treeIdx); + + // The next level that we want to construct. + auto level1 = getLevel(d + 1, treeIdx); + + // Zero out the previous sums. + memset(mySums[0].data(), 0, mySums[0].size() * sizeof(block)); + memset(mySums[1].data(), 0, mySums[1].size() * sizeof(block)); + + // We will iterate over each node on this level and + // expand it into it's two children. Note that the + // active node will also be expanded. Later we will just + // overwrite whatever the value was. This is an optimization. + auto width = static_cast(level1.size()); + for (u64 childIdx = 0; childIdx < width; ) + { + + // Index of the parent in the previous level. + auto parentIdx = childIdx >> 1; + + // The value of the parent. + auto parent = level0[parentIdx]; + + for (u64 keep = 0; keep < 2; ++keep, ++childIdx) + { + + //// The bit that indicates if we are on the left child (0) + //// or on the right child (1). + //u8 keep = childIdx & 1; + + + // The child that we will write in this iteration. + auto& child = level1[childIdx]; + + // Each parent is expanded into the left and right children + // using a different AES fixed-key. Therefore our OWF is: + // + // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); + // + // where each half defines one of the children. + aes[keep].template hashBlocks<8>(parent.data(), child.data()); + + + + #ifdef DEBUG_PRINT_PPRF + // For debugging, set the active path to zero. + for (u64 i = 0; i < 8; ++i) + if (eq(parent[i], ZeroBlock)) + child[i] = ZeroBlock; + #endif + + if (d < pprf.mDepth - 1) { + // Same as before + // Update the running sums for this level. We keep + // a left and right totals for each level. Note that + // we are actually XOR in the incorrect value of the + // children of the active parent (assuming !DEBUG_PRINT_PPRF). + // This is ok since we will later XOR off these incorrect values. + auto& sum = mySums[keep]; + sum[0] = sum[0] ^ child[0]; + sum[1] = sum[1] ^ child[1]; + sum[2] = sum[2] ^ child[2]; + sum[3] = sum[3] ^ child[3]; + sum[4] = sum[4] ^ child[4]; + sum[5] = sum[5] ^ child[5]; + sum[6] = sum[6] ^ child[6]; + sum[7] = sum[7] ^ child[7]; + } else { + if (getLastLevel(pprf.mDepth, treeIdx).size() <= childIdx) { + childIdx = width; + break; + } + auto& realChild = getLastLevel(pprf.mDepth, treeIdx)[childIdx]; + auto& lastSum = lastSums[keep]; + realChild[0] = TypeTrait::fromBlock(child[0]); + lastSum[0] = TypeTrait::plus(lastSum[0], realChild[0]); + realChild[1] = TypeTrait::fromBlock(child[1]); + lastSum[1] = TypeTrait::plus(lastSum[1], realChild[1]); + realChild[2] = TypeTrait::fromBlock(child[2]); + lastSum[2] = TypeTrait::plus(lastSum[2], realChild[2]); + realChild[3] = TypeTrait::fromBlock(child[3]); + lastSum[3] = TypeTrait::plus(lastSum[3], realChild[3]); + realChild[4] = TypeTrait::fromBlock(child[4]); + lastSum[4] = TypeTrait::plus(lastSum[4], realChild[4]); + realChild[5] = TypeTrait::fromBlock(child[5]); + lastSum[5] = TypeTrait::plus(lastSum[5], realChild[5]); + realChild[6] = TypeTrait::fromBlock(child[6]); + lastSum[6] = TypeTrait::plus(lastSum[6], realChild[6]); + realChild[7] = TypeTrait::fromBlock(child[7]); + lastSum[7] = TypeTrait::plus(lastSum[7], realChild[7]); + } + } + } + + // For everything but the last level we have to + // 1) fix our sums so they dont include the incorrect + // values that are the children of the active parent + // 2) Update the non-active child of the active parent. + if (!mActiveChildXorDelta || d != pprf.mDepth - 1) + { + + for (u64 i = 0; i < 8; ++i) + { + // the index of the leaf node that is active. + auto leafIdx = pprf.mPoints[i + treeIdx]; + + // The index of the active child node. + auto activeChildIdx = leafIdx >> (pprf.mDepth - 1 - d); + + // The index of the active child node sibling. + auto inactiveChildIdx = activeChildIdx ^ 1; + + // The indicator as to the left or right child is inactive + auto notAi = inactiveChildIdx & 1; + #ifdef DEBUG_PRINT_PPRF + auto prev = level1[inactiveChildIdx][i]; + #endif + + auto& inactiveChild = level1[inactiveChildIdx][i]; + + + // correct the sum value by XORing off the incorrect + auto correctSum = + inactiveChild ^ + theirSums[notAi][d][i]; + + inactiveChild = + correctSum ^ + mySums[notAi][i] ^ + pprf.mBaseOTs[i + treeIdx][d]; + + #ifdef DEBUG_PRINT_PPRF + if (mPrint) + std::cout << "up[" << i << "] = level1[" << inactiveChildIdx << "][" << i << "] " + << prev << " -> " << level1[inactiveChildIdx][i] << " " << activeChildIdx << " " << inactiveChildIdx << " ~~ " + << mBaseOTs[i + treeIdx][d] << " " << theirSums[notAi][d][i] << " @ " << (i + treeIdx) << " " << d << std::endl; + + auto fLevel1 = getLevel(d + 1, true); + if (neq(fLevel1[inactiveChildIdx][i], inactiveChild)) + throw RTE_LOC; + #endif + } + } + #ifdef DEBUG_PRINT_PPRF + if (mPrint) + printLevel(d + 1); + #endif + + } + + // pprf.setTimePoint("SilentMultiPprfReceiver.expand " + std::to_string(treeIdx)); + + //timer.setTimePoint("recv.expanded"); + + + // copy the last level to the output. If desired, this is + // where the transpose is performed. + auto lvl = getLastLevel(pprf.mDepth, treeIdx); + + if (mActiveChildXorDelta) + { + // Now processes the last level. This one is special + // because we must XOR in the correction value as + // before but we must also fixed the child value for + // the active child. To do this, we will receive 4 + // values. Two for each case (left active or right active). + //timer.setTimePoint("recv.recvLast"); + + auto d = pprf.mDepth - 1; + for (u64 j = 0; j < 8; ++j) + { + // The index of the child on the active path. + auto activeChildIdx = pprf.mPoints[j + treeIdx]; + + // The index of the other (inactive) child. + auto inactiveChildIdx = activeChildIdx ^ 1; + + // The indicator as to the left or right child is inactive + auto notAi = inactiveChildIdx & 1; + + std::array masks, maskIn; + + // We are going to expand the 128 bit OT string + // into a 256 bit OT string using AES. + maskIn[0] = pprf.mBaseOTs[j + treeIdx][d]; + maskIn[1] = pprf.mBaseOTs[j + treeIdx][d] ^ AllOneBlock; + mAesFixedKey.template hashBlocks<2>(maskIn.data(), masks.data()); + + // now get the chosen message OT strings by XORing + // the expended (random) OT strings with the lastOts values. + auto& ot0 = lastOts[j][2 * notAi + 0]; + auto& ot1 = lastOts[j][2 * notAi + 1]; + ot0 = TypeTrait::minus(ot0, TypeTrait::fromBlock(masks[0])); + ot1 = TypeTrait::minus(ot1, TypeTrait::fromBlock(masks[1])); + + #ifdef DEBUG_PRINT_PPRF + auto prev = level[inactiveChildIdx][j]; + #endif + + auto& inactiveChild = lvl[inactiveChildIdx][j]; + auto& activeChild = lvl[activeChildIdx][j]; + + // Fix the sums we computed previously to not include the + // incorrect child values. + auto inactiveSum = TypeTrait::minus(lastSums[notAi][j], inactiveChild); + auto activeSum = TypeTrait::minus(lastSums[notAi ^ 1][j], activeChild); + + // Update the inactive and active child to have to correct + // value by XORing their full sum with out partial sum, which + // gives us exactly the value we are missing. + inactiveChild = TypeTrait::minus(ot0, inactiveSum); + activeChild = TypeTrait::minus(ot1, activeSum); + + #ifdef DEBUG_PRINT_PPRF + auto fLevel1 = getLevel(d + 1, true); + if (neq(fLevel1[inactiveChildIdx][j], inactiveChild)) + throw RTE_LOC; + if (neq(fLevel1[activeChildIdx][j], activeChild ^ mDebugValue)) + throw RTE_LOC; + + if (mPrint) + std::cout << "up[" << d << "] = level1[" << (inactiveChildIdx / mPntCount) << "][" << (inactiveChildIdx % mPntCount) << " " + << prev << " -> " << level[inactiveChildIdx][j] << " ~~ " + << mBaseOTs[j + treeIdx][d] << " " << ot0 << " @ " << (j + treeIdx) << " " << d << std::endl; + #endif + } + // pprf.setTimePoint("SilentMultiPprfReceiver.last " + std::to_string(treeIdx)); + + //timer.setTimePoint("recv.expandLast"); + } + else + { + for (auto j : rng(std::min(8, pprf.mPntCount - treeIdx))) + { + + // The index of the child on the active path. + auto activeChildIdx = pprf.mPoints[j + treeIdx]; + lvl[activeChildIdx][j] = F{}; + } + } + + // s is a checksum that is used for malicious security. + copyOut(lvl, output, pprf.mPntCount, treeIdx, oFormat, pprf.mOutputFn); + + // pprf.setTimePoint("SilentMultiPprfReceiver.copy " + std::to_string(treeIdx)); + + //uPtr_ = {}; + //tree = {}; + pprf.mTreeAlloc.del(tree); + + // pprf.setTimePoint("SilentMultiPprfReceiver.delete " + std::to_string(treeIdx)); + + } + + MC_END(); + } + }; + + std::vector mExps; + }; +} \ No newline at end of file diff --git a/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp b/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp index 767efa5d..14a93af5 100644 --- a/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp +++ b/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp @@ -53,7 +53,7 @@ namespace osuCrypto MC_AWAIT(runBatch(chl, scratch.subspan(0, messagesFullChunks * chunkSize()))); - assert(scratch[0] != ZeroBlock); + assert(messagesFullChunks == 0 || scratch[0] != ZeroBlock); // Extra blocks MC_AWAIT(runBatch(chl, mExtraW.subspan(0, numExtra * chunkSize()))); diff --git a/libOTe/Vole/Subfield/NoisyVoleReceiver.h b/libOTe/Vole/Subfield/NoisyVoleReceiver.h new file mode 100644 index 00000000..5cdeb35c --- /dev/null +++ b/libOTe/Vole/Subfield/NoisyVoleReceiver.h @@ -0,0 +1,105 @@ +#pragma once +// © 2022 Visa. +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// This code implements features described in [Silver: Silent VOLE and Oblivious +// Transfer from Hardness of Decoding Structured LDPC Codes, +// https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative +// Commons Attribution 4.0 International Public License +// (https://creativecommons.org/licenses/by/4.0/legalcode). + +#include +#if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) + +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/Timer.h" +#include "cryptoTools/Crypto/PRNG.h" +#include "libOTe/Tools/Coproto.h" +#include "libOTe/TwoChooseOne/OTExtInterface.h" + +namespace osuCrypto::Subfield { + + template + class NoisySubfieldVoleReceiver : public TimerAdapter { + public: + using F = typename TypeTrait::F; + using G = typename TypeTrait::G; + task<> receive(span y, span z, PRNG& prng, + OtSender& ot, Socket& chl) { + MC_BEGIN(task<>, this, y, z, &prng, &ot, &chl, + otMsg = AlignedUnVector>{ TypeTrait::bitsF }); + + setTimePoint("NoisyVoleReceiver.ot.begin"); + + MC_AWAIT(ot.send(otMsg, prng, chl)); + + setTimePoint("NoisyVoleReceiver.ot.end"); + + MC_AWAIT(receive(y, z, prng, otMsg, chl)); + + MC_END(); + } + + task<> receive(span y, span z, PRNG& _, + span> otMsg, + Socket& chl) { + MC_BEGIN(task<>, this, y, z, otMsg, &chl, + msg = Matrix{}, + prng = std::move(PRNG{}) + ); + + if (otMsg.size() != TypeTrait::bitsF) throw RTE_LOC; + if (y.size() != z.size()) throw RTE_LOC; + if (z.size() == 0) throw RTE_LOC; + + setTimePoint("NoisyVoleReceiver.begin"); + + memset(z.data(), 0, TypeTrait::bytesF * z.size()); + msg.resize(otMsg.size(), z.size(), AllocType::Uninitialized); + + for (size_t ii = 0; ii < TypeTrait::bitsF; ++ii) { + prng.SetSeed(otMsg[ii][0], z.size()); + auto& buffer = prng.mBuffer; + auto pow = TypeTrait::pow(ii); + for (size_t j = 0; j < y.size(); ++j) { + auto bufj = TypeTrait::fromBlock(buffer[j]); + z[j] = TypeTrait::plus(z[j], bufj); + F yy = TypeTrait::mul(pow, y[j]); + + msg(ii, j) = TypeTrait::plus(yy, bufj); + } + + prng.SetSeed(otMsg[ii][1], z.size()); + + for (size_t j = 0; j < y.size(); ++j) { + // enc one message under the OT msg. + msg(ii, j) = TypeTrait::plus(msg(ii, j), TypeTrait::fromBlock(prng.mBuffer[j])); + } + } + + MC_AWAIT(chl.send(std::move(msg))); + setTimePoint("NoisyVoleReceiver.done"); + + MC_END(); + } + + }; + +} // namespace osuCrypto +#endif diff --git a/libOTe/Vole/Subfield/NoisyVoleSender.h b/libOTe/Vole/Subfield/NoisyVoleSender.h new file mode 100644 index 00000000..7de8b989 --- /dev/null +++ b/libOTe/Vole/Subfield/NoisyVoleSender.h @@ -0,0 +1,97 @@ +#pragma once +// © 2022 Visa. +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// This code implements features described in [Silver: Silent VOLE and Oblivious +// Transfer from Hardness of Decoding Structured LDPC Codes, +// https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative +// Commons Attribution 4.0 International Public License +// (https://creativecommons.org/licenses/by/4.0/legalcode). + +#include +#if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) + +#include "cryptoTools/Common/BitVector.h" +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/Timer.h" +#include "cryptoTools/Crypto/PRNG.h" +#include "libOTe/Tools/Coproto.h" +#include "libOTe/TwoChooseOne/OTExtInterface.h" + +namespace osuCrypto::Subfield { + template + class NoisySubfieldVoleSender : public TimerAdapter { + public: + using F = typename TypeTrait::F; + using G = typename TypeTrait::G; + task<> send(F x, span z, PRNG& prng, + OtReceiver& ot, Socket& chl) { + MC_BEGIN(task<>, this, x, z, &prng, &ot, &chl, + bv = TypeTrait::BitVectorF(x), + otMsg = AlignedUnVector{ TypeTrait::bitsF }); + + setTimePoint("NoisyVoleSender.ot.begin"); + + MC_AWAIT(ot.receive(bv, otMsg, prng, chl)); + setTimePoint("NoisyVoleSender.ot.end"); + + MC_AWAIT(send(x, z, prng, otMsg, chl)); + + MC_END(); + } + + task<> send(F x, span z, PRNG& _, + span otMsg, Socket& chl) { + MC_BEGIN(task<>, this, x, z, otMsg, &chl, + prng = std::move(PRNG{}), + msg = Matrix{}, + xb = BitVector{}); + + if (otMsg.size() != TypeTrait::bitsF) + throw RTE_LOC; + setTimePoint("NoisyVoleSender.main"); + + memset(z.data(), 0, TypeTrait::bytesF * z.size()); + msg.resize(otMsg.size(), z.size(), AllocType::Uninitialized); + + MC_AWAIT(chl.recv(msg)); + + setTimePoint("NoisyVoleSender.recvMsg"); + + xb = TypeTrait::BitVectorF(x); + for (size_t i = 0; i < TypeTrait::bitsF; ++i) + { + prng.SetSeed(otMsg[i], z.size()); + + for (u64 j = 0; j < (u64)z.size(); ++j) + { + F bufj = TypeTrait::fromBlock(prng.mBuffer[j]); + F data = xb[i] ? TypeTrait::minus(msg(i, j), bufj) : bufj; + z[j] = TypeTrait::plus(z[j], data); + } + } + setTimePoint("NoisyVoleSender.done"); + + MC_END(); + } + + }; +} // namespace osuCrypto + +#endif \ No newline at end of file diff --git a/libOTe/Vole/Subfield/SilentVoleReceiver.h b/libOTe/Vole/Subfield/SilentVoleReceiver.h new file mode 100644 index 00000000..3a4e25fc --- /dev/null +++ b/libOTe/Vole/Subfield/SilentVoleReceiver.h @@ -0,0 +1,788 @@ +#pragma once +// © 2022 Visa. +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). + +#include +#ifdef ENABLE_SILENT_VOLE + +#include +#include +#include +#include "libOTe/Tools/Subfield/SubfieldPprf.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace osuCrypto::Subfield +{ + + + template + class SilentSubfieldVoleReceiver : public TimerAdapter + { + public: + using F = typename TypeTrait::F; + using G = typename TypeTrait::G; + + static constexpr u64 mScaler = 2; + + enum class State + { + Default, + Configured, + HasBase + }; + + // The current state of the protocol + State mState = State::Default; + + // The number of OTs the user requested. + u64 mRequestedNumOTs = 0; + + // The number of OTs actually produced (at least the number requested). + u64 mN = 0; + + // The length of the noisy vectors (2 * mN for the silver codes). + u64 mN2 = 0; + + // We perform regular LPN, so this is the + // size of the each chunk. + u64 mSizePer = 0; + + u64 mNumPartitions = 0; + + // The noisy coordinates. + std::vector mS; + + // What type of Base OTs should be performed. + SilentBaseType mBaseType; + + // The matrix multiplication type which compresses + // the sparse vector. + MultType mMultType = DefaultMultType; + + ExConvCode mExConvEncoder; + + // The multi-point punctured PRF for generating + // the sparse vectors. + SilentSubfieldPprfReceiver mGen; + + // The internal buffers for holding the expanded vectors. + // mA + mB = mC * delta + AlignedUnVector mA; + + // mA + mB = mC * delta + AlignedUnVector mC; + + std::vector mGapOts; + + u64 mNumThreads = 1; + + bool mDebug = false; + + BitVector mIknpSendBaseChoice, mGapBaseChoice; + + SilentSecType mMalType = SilentSecType::SemiHonest; + + block mMalCheckSeed, mMalCheckX, mDeltaShare; + + AlignedVector mNoiseDeltaShare; + AlignedVector mNoiseValues; + + +#ifdef ENABLE_SOFTSPOKEN_OT + SoftSpokenMalOtSender mOtExtSender; + SoftSpokenMalOtReceiver mOtExtRecver; +#endif + + // // sets the Iknp base OTs that are then used to extend + // void setBaseOts( + // span> baseSendOts); + // + // // return the number of base OTs IKNP needs + // u64 baseOtCount() const; + + u64 baseVoleCount() const + { + return mNumPartitions + mGapOts.size() + 1 * (mMalType == SilentSecType::Malicious); + } + + // // returns true if the IKNP base OTs are currently set. + // bool hasBaseOts() const; + // + // returns true if the silent base OTs are set. + bool hasSilentBaseOts() const { + return mGen.hasBaseOts(); + }; + // + // // Generate the IKNP base OTs + // task<> genBaseOts(PRNG& prng, Socket& chl) ; + + // Generate the silent base OTs. If the Iknp + // base OTs are set then we do an IKNP extend, + // otherwise we perform a base OT protocol to + // generate the needed OTs. + task<> genSilentBaseOts(PRNG& prng, Socket& chl) + { + using BaseOT = DefaultBaseOT; + + + MC_BEGIN(task<>, this, &prng, &chl, + choice = BitVector{}, + bb = BitVector{}, + msg = AlignedUnVector{}, + baseVole = std::vector{}, + baseOt = BaseOT{}, + chl2 = Socket{}, + prng2 = std::move(PRNG{}), + noiseVals = std::vector{}, + noiseDeltaShares = std::vector{}, + nv = NoisySubfieldVoleReceiver{} + + ); + + setTimePoint("SilentVoleReceiver.genSilent.begin"); + if (isConfigured() == false) + throw std::runtime_error("configure must be called first"); + + choice = sampleBaseChoiceBits(prng); + msg.resize(choice.size()); + + // sample the noise vector noiseVals such that we will compute + // + // C = (000 noiseVals[0] 0000 ... 000 noiseVals[p] 000) + // + // and then we want secret shares of C * delta. As a first step + // we will compute secret shares of + // + // delta * noiseVals + // + // and store our share in voleDeltaShares. This party will then + // compute their share of delta * C as what comes out of the PPRF + // plus voleDeltaShares[i] added to the appreciate spot. Similarly, the + // other party will program the PPRF to output their share of delta * noiseVals. + // + noiseVals = sampleBaseVoleVals(prng); + noiseDeltaShares.resize(noiseVals.size()); + if (mTimer) + nv.setTimer(*mTimer); + + if (mBaseType == SilentBaseType::BaseExtend) + { +#ifdef ENABLE_SOFTSPOKEN_OT + + if (mOtExtSender.hasBaseOts() == false) + { + msg.resize(msg.size() + mOtExtSender.baseOtCount()); + bb.resize(mOtExtSender.baseOtCount()); + bb.randomize(prng); + choice.append(bb); + + MC_AWAIT(mOtExtRecver.receive(choice, msg, prng, chl)); + + mOtExtSender.setBaseOts( + span(msg).subspan( + msg.size() - mOtExtSender.baseOtCount(), + mOtExtSender.baseOtCount()), + bb); + + msg.resize(msg.size() - mOtExtSender.baseOtCount()); + MC_AWAIT(nv.receive(noiseVals, noiseDeltaShares, prng, mOtExtSender, chl)); + } + else + { + chl2 = chl.fork(); + prng2.SetSeed(prng.get()); + + + MC_AWAIT( + macoro::when_all_ready( + nv.receive(noiseVals, noiseDeltaShares, prng2, mOtExtSender, chl2), + mOtExtRecver.receive(choice, msg, prng, chl) + )); + } +#else + throw std::runtime_error("soft spoken must be enabled"); +#endif + } + else + { + chl2 = chl.fork(); + prng2.SetSeed(prng.get()); + MC_AWAIT(baseOt.receive(choice, msg, prng, chl)); + MC_AWAIT(nv.receive(noiseVals, noiseDeltaShares, prng2, baseOt, chl2)); + } + + setSilentBaseOts(msg, noiseDeltaShares); + setTimePoint("SilentVoleReceiver.genSilent.done"); + MC_END(); + }; + + // configure the silent OT extension. This sets + // the parameters and figures out how many base OT + // will be needed. These can then be ganerated for + // a different OT extension or using a base OT protocol. + void configure( + u64 numOTs, + SilentBaseType type = SilentBaseType::BaseExtend, + u64 secParam = 128) + { + mState = State::Configured; + u64 gap = 0; + mBaseType = type; + + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + case osuCrypto::MultType::ExConv21x24: + + SubfieldExConvConfigure(numOTs, 128, mMultType, mRequestedNumOTs, mNumPartitions, mSizePer, mN2, mN, mExConvEncoder); + break; + default: + throw RTE_LOC; + break; + } + + mGapOts.resize(gap); + mGen.configure(mSizePer, mNumPartitions); + } + + // return true if this instance has been configured. + bool isConfigured() const { return mState != State::Default; } + + // Returns how many base OTs the silent OT extension + // protocol will needs. + u64 silentBaseOtCount() const + { + if (isConfigured() == false) + throw std::runtime_error("configure must be called first"); + + return mGen.baseOtCount() + mGapOts.size(); + + } + + // The silent base OTs must have specially set base OTs. + // This returns the choice bits that should be used. + // Call this is you want to use a specific base OT protocol + // and then pass the OT messages back using setSilentBaseOts(...). + BitVector sampleBaseChoiceBits(PRNG& prng) { + + if (isConfigured() == false) + throw std::runtime_error("configure(...) must be called first"); + + auto choice = mGen.sampleChoiceBits(mN2, getPprfFormat(), prng); + + mGapBaseChoice.resize(mGapOts.size()); + mGapBaseChoice.randomize(prng); + choice.append(mGapBaseChoice); + + return choice; + } + + std::vector sampleBaseVoleVals(PRNG& prng) + { + if (isConfigured() == false) + throw RTE_LOC; + if (mGapBaseChoice.size() != mGapOts.size()) + throw std::runtime_error("sampleBaseChoiceBits must be called before sampleBaseVoleVals. " LOCATION); + + // sample the values of the noisy coordinate of c + // and perform a noicy vole to get x+y = mD * c + auto w = mNumPartitions + mGapOts.size(); + //std::vector y(w); + mNoiseValues.resize(w); + prng.get(mNoiseValues.data(), mNoiseValues.size()); + + mS.resize(mNumPartitions); + mGen.getPoints(mS, getPprfFormat()); + + // todo + std::vector tmp = mS; + std::sort(tmp.begin(), tmp.end()); + + auto j = mNumPartitions * mSizePer; + + for (u64 i = 0; i < (u64)mGapBaseChoice.size(); ++i) + { + if (mGapBaseChoice[i]) + { + mS.push_back(j + i); + } + } + + // if (mMalType == SilentSecType::Malicious) + // { + // + // mMalCheckSeed = prng.get(); + // mMalCheckX = ZeroBlock; + // auto yIter = mNoiseValues.begin(); + // + // for (u64 i = 0; i < mNumPartitions; ++i) + // { + // auto s = mS[i]; + // auto xs = mMalCheckSeed.gf128Pow(s + 1); + // mMalCheckX = mMalCheckX ^ xs.gf128Mul(*yIter); + // ++yIter; + // } + // + // auto sIter = mS.begin() + mNumPartitions; + // for (u64 i = 0; i < mGapBaseChoice.size(); ++i) + // { + // if (mGapBaseChoice[i]) + // { + // auto s = *sIter; + // auto xs = mMalCheckSeed.gf128Pow(s + 1); + // mMalCheckX = mMalCheckX ^ xs.gf128Mul(*yIter); + // ++sIter; + // } + // ++yIter; + // } + // + // + // std::vector y(mNoiseValues.begin(), mNoiseValues.end()); + // y.push_back(mMalCheckX); + // return y; + // } + + return std::vector(mNoiseValues.begin(), mNoiseValues.end()); + } + + // Set the externally generated base OTs. This choice + // bits must be the one return by sampleBaseChoiceBits(...). + void setSilentBaseOts(span recvBaseOts, + span noiseDeltaShare) + { + if (isConfigured() == false) + throw std::runtime_error("configure(...) must be called first."); + + if (static_cast(recvBaseOts.size()) != silentBaseOtCount()) + throw std::runtime_error("wrong number of silent base OTs"); + + auto genOts = recvBaseOts.subspan(0, mGen.baseOtCount()); + auto gapOts = recvBaseOts.subspan(mGen.baseOtCount(), mGapOts.size()); + + mGen.setBase(genOts); + std::copy(gapOts.begin(), gapOts.end(), mGapOts.begin()); + + // if (mMalType == SilentSecType::Malicious) + // { + // mDeltaShare = noiseDeltaShare.back(); + // noiseDeltaShare = noiseDeltaShare.subspan(0, noiseDeltaShare.size() - 1); + // } + + mNoiseDeltaShare = AlignedVector(noiseDeltaShare.begin(), noiseDeltaShare.end()); + + mState = State::HasBase; + } + + // Perform the actual OT extension. If silent + // base OTs have been generated or set, then + // this function is non-interactive. Otherwise + // the silent base OTs will automatically be performed. + task<> silentReceive( + span c, + span b, + PRNG& prng, + Socket& chl) + { + MC_BEGIN(task<>, this, c, b, &prng, &chl); + if (c.size() != b.size()) + throw RTE_LOC; + + MC_AWAIT(silentReceiveInplace(c.size(), prng, chl)); + + std::memcpy(c.data(), mC.data(), c.size() * TypeTrait::bytesG); + std::memcpy(b.data(), mA.data(), b.size() * TypeTrait::bytesF); + clear(); + MC_END(); + } + + // Perform the actual OT extension. If silent + // base OTs have been generated or set, then + // this function is non-interactive. Otherwise + // the silent base OTs will automatically be performed. + task<> silentReceiveInplace( + u64 n, + PRNG& prng, + Socket& chl) + { + MC_BEGIN(task<>, this, n, &prng, &chl, + gapVals = std::vector{}, + myHash = std::array{}, + theirHash = std::array{} + ); + gTimer.setTimePoint("SilentVoleReceiver.ot.enter"); + + if (isConfigured() == false) + { + // first generate 128 normal base OTs + configure(n, SilentBaseType::BaseExtend); + // configure(n, SilentBaseType::Base); + } + + if (mRequestedNumOTs != n) + throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); + + if (hasSilentBaseOts() == false) + { + MC_AWAIT(genSilentBaseOts(prng, chl)); + } + + // allocate mA + mA.resize(0); + mA.resize(mN2); + + setTimePoint("SilentVoleReceiver.alloc"); + + // allocate the space for mC + mC.resize(0); + mC.resize(mN2, AllocType::Zeroed); + setTimePoint("SilentVoleReceiver.alloc.zero"); + + // derandomize the random OTs for the gap + // to have the desired correlation. + gapVals.resize(mGapOts.size()); + + if (gapVals.size()) + MC_AWAIT(chl.recv(gapVals)); + + for (auto g : rng(mGapOts.size())) + { + auto aa = mA.subspan(mNumPartitions * mSizePer); + auto cc = mC.subspan(mNumPartitions * mSizePer); + + auto noise = mNoiseValues.subspan(mNumPartitions); + auto noiseShares = mNoiseDeltaShare.subspan(mNumPartitions); + + if (mGapBaseChoice[g]) + { + cc[g] = noise[g]; + aa[g] = TypeTrait::minus( + TypeTrait::minus(gapVals[g], TypeTrait::fromBlock(AES(mGapOts[g]).ecbEncBlock(ZeroBlock))), + noiseShares[g]); + } + else + { + aa[g] = TypeTrait::fromBlock(mGapOts[g]); + } + } + + setTimePoint("SilentVoleReceiver.recvGap"); + + + + if (mTimer) + mGen.setTimer(*mTimer); + // expand the seeds into mA + MC_AWAIT(mGen.expand(chl, prng, mA.subspan(0, mNumPartitions * mSizePer), PprfOutputFormat::Interleaved, true, mNumThreads)); + + setTimePoint("SilentVoleReceiver.expand.pprf_transpose"); + + // populate the noisy coordinates of mC and + // update mA to be a secret share of mC * delta + for (u64 i = 0; i < mNumPartitions; ++i) + { + auto pnt = mS[i]; + mC[pnt] = mNoiseValues[i]; + mA[pnt] = TypeTrait::minus(mA[pnt], mNoiseDeltaShare[i]); + } + + + if (mDebug) + { + MC_AWAIT(checkRT(chl)); + setTimePoint("SilentVoleReceiver.expand.checkRT"); + } + + + // if (mMalType == SilentSecType::Malicious) + // { + // MC_AWAIT(chl.send(std::move(mMalCheckSeed))); + // + // myHash = ferretMalCheck(mDeltaShare, mNoiseValues); + // + // MC_AWAIT(chl.recv(theirHash)); + // + // if (theirHash != myHash) + // throw RTE_LOC; + // } + + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + case osuCrypto::MultType::ExConv21x24: + if (mTimer) { + mExConvEncoder.setTimer(getTimer()); + } + + mExConvEncoder.template dualEncode2( + mA.subspan(0, mExConvEncoder.mCodeSize), + mC.subspan(0, mExConvEncoder.mCodeSize) + ); + + break; + default: + throw RTE_LOC; + break; + } + + // resize the buffers down to only contain the real elements. + mA.resize(mRequestedNumOTs); + mC.resize(mRequestedNumOTs); + + mNoiseValues = {}; + mNoiseDeltaShare = {}; + + // make the protocol as done and that + // mA,mC are ready to be consumed. + mState = State::Default; + + MC_END(); + } + + + + // internal. + task<> checkRT(Socket& chl) const + { + MC_BEGIN(task<>, this, &chl, + B = AlignedVector(mA.size()), + sparseNoiseDelta = std::vector(mA.size()), + noiseDeltaShare2 = std::vector(), + delta = F{} + ); + //std::vector mB(mA.size()); + MC_AWAIT(chl.recv(delta)); + MC_AWAIT(chl.recv(B)); + MC_AWAIT(chl.recvResize(noiseDeltaShare2)); + + for (u64 i = 0; i < mA.size(); i++) { + F left = TypeTrait::mul(delta, mC[i]); + F right = TypeTrait::minus(mA[i], B[i]); + if (left != right) { + throw RTE_LOC; + } + } + + //check that at locations mS[0],...,mS[..] + // that we hold a sharing mA, mB of + // + // delta * mC = delta * (00000 noiseDeltaShare2[0] 0000 .... 0000 noiseDeltaShare2[m] 0000) + // + // where noiseDeltaShare2[i] is at position mS[i] of mC + // + // That is, I hold mA, mC s.t. + // + // delta * mC = mA + mB + // + +// if (noiseDeltaShare2.size() != mNoiseDeltaShare.size()) +// throw RTE_LOC; +// +// for (auto i : rng(mNoiseDeltaShare.size())) +// { +// if ((mNoiseDeltaShare[i] ^ noiseDeltaShare2[i]) != mNoiseValues[i].gf128Mul(delta)) +// throw RTE_LOC; +// } +// +// { +// +// for (auto i : rng(mNumPartitions* mSizePer)) +// { +// auto iter = std::find(mS.begin(), mS.end(), i); +// if (iter != mS.end()) +// { +// auto d = iter - mS.begin(); +// +// if (mC[i] != mNoiseValues[d]) +// throw RTE_LOC; +// +// if (mNoiseValues[d].gf128Mul(delta) != (mA[i] ^ B[i])) +// { +// std::cout << "bad vole base correlation, mA[i] + mB[i] != mC[i] * delta" << std::endl; +// std::cout << "i " << i << std::endl; +// std::cout << "mA[i] " << mA[i] << std::endl; +// std::cout << "mB[i] " << B[i] << std::endl; +// std::cout << "mC[i] " << mC[i] << std::endl; +// std::cout << "delta " << delta << std::endl; +// std::cout << "mA[i] + mB[i] " << (mA[i] ^ B[i]) << std::endl; +// std::cout << "mC[i] * delta " << (mC[i].gf128Mul(delta)) << std::endl; +// +// throw RTE_LOC; +// } +// } +// else +// { +// if (mA[i] != B[i]) +// { +// std::cout << mA[i] << " " << B[i] << std::endl; +// throw RTE_LOC; +// } +// +// if (mC[i] != oc::ZeroBlock) +// throw RTE_LOC; +// } +// } +// +// u64 d = mNumPartitions; +// for (auto j : rng(mGapBaseChoice.size())) +// { +// auto idx = j + mNumPartitions * mSizePer; +// auto aa = mA.subspan(mNumPartitions * mSizePer); +// auto bb = B.subspan(mNumPartitions * mSizePer); +// auto cc = mC.subspan(mNumPartitions * mSizePer); +// auto noise = mNoiseValues.subspan(mNumPartitions); +// //auto noiseShare = mNoiseValues.subspan(mNumPartitions); +// if (mGapBaseChoice[j]) +// { +// if (mS[d++] != idx) +// throw RTE_LOC; +// +// if (cc[j] != noise[j]) +// { +// std::cout << "sparse noise vector mC is not the expected value" << std::endl; +// std::cout << "i j " << idx << " " << j << std::endl; +// std::cout << "mC[i] " << cc[j] << std::endl; +// std::cout << "noise[j] " << noise[j] << std::endl; +// throw RTE_LOC; +// } +// +// if (noise[j].gf128Mul(delta) != (aa[j] ^ bb[j])) +// { +// +// std::cout << "bad vole base GAP correlation, mA[i] + mB[i] != mC[i] * delta" << std::endl; +// std::cout << "i " << idx << std::endl; +// std::cout << "mA[i] " << aa[j] << std::endl; +// std::cout << "mB[i] " << bb[j] << std::endl; +// std::cout << "mC[i] " << cc[j] << std::endl; +// std::cout << "delta " << delta << std::endl; +// std::cout << "mA[i] + mB[i] " << (aa[j] ^ bb[j]) << std::endl; +// std::cout << "mC[i] * delta " << (cc[j].gf128Mul(delta)) << std::endl; +// std::cout << "noise * delta " << (noise[j].gf128Mul(delta)) << std::endl; +// throw RTE_LOC; +// } +// +// } +// else +// { +// if (aa[j] != bb[j]) +// throw RTE_LOC; +// +// if (cc[j] != oc::ZeroBlock) +// throw RTE_LOC; +// } +// } +// +// if (d != mS.size()) +// throw RTE_LOC; +// } + + + //{ + + // auto cDelta = B; + // for (u64 i = 0; i < cDelta.size(); ++i) + // cDelta[i] = cDelta[i] ^ mA[i]; + + // std::vector exp(mN2); + // for (u64 i = 0; i < mNumPartitions; ++i) + // { + // auto j = mS[i]; + // exp[j] = noiseDeltaShare2[i]; + // } + + // auto iter = mS.begin() + mNumPartitions; + // for (u64 i = 0, j = mNumPartitions * mSizePer; i < mGapOts.size(); ++i, ++j) + // { + // if (mGapBaseChoice[i]) + // { + // if (*iter != j) + // throw RTE_LOC; + // ++iter; + + // exp[j] = noiseDeltaShare2[mNumPartitions + i]; + // } + // } + + // if (iter != mS.end()) + // throw RTE_LOC; + + // bool failed = false; + // for (u64 i = 0; i < mN2; ++i) + // { + // if (neq(cDelta[i], exp[i])) + // { + // std::cout << i << " / " << mN2 << + // " cd = " << cDelta[i] << + // " exp= " << exp[i] << std::endl; + // failed = true; + // } + // } + + // if (failed) + // throw RTE_LOC; + + // std::cout << "debug check ok" << std::endl; + //} + + MC_END(); + } + + std::array ferretMalCheck( + block deltaShare, + span y) + { + + block xx = mMalCheckSeed; + block sum0 = ZeroBlock; + block sum1 = ZeroBlock; + + + for (u64 i = 0; i < (u64)mA.size(); ++i) + { + block low, high; + xx.gf128Mul(mA[i], low, high); + sum0 = sum0 ^ low; + sum1 = sum1 ^ high; + //mySum = mySum ^ xx.gf128Mul(mA[i]); + + // xx = mMalCheckSeed^{i+1} + xx = xx.gf128Mul(mMalCheckSeed); + } + block mySum = sum0.gf128Reduce(sum1); + + std::array myHash; + RandomOracle ro(32); + ro.Update(mySum ^ deltaShare); + ro.Final(myHash); + return myHash; + } + + PprfOutputFormat getPprfFormat() + { + return PprfOutputFormat::Interleaved; + } + + void clear() + { + mS = {}; + mA = {}; + mC = {}; + mGen.clear(); + mGapBaseChoice = {}; + } + }; +} +#endif \ No newline at end of file diff --git a/libOTe/Vole/Subfield/SilentVoleSender.h b/libOTe/Vole/Subfield/SilentVoleSender.h new file mode 100644 index 00000000..fa4428a8 --- /dev/null +++ b/libOTe/Vole/Subfield/SilentVoleSender.h @@ -0,0 +1,480 @@ +#pragma once +// © 2022 Visa. +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). + +#include +#ifdef ENABLE_SILENT_VOLE + +#include +#include +#include +#include +#include "libOTe/Tools/Subfield/SubfieldPprf.h" +#include +#include +#include +#include +#include +#include +#include +#include +//#define NO_HASH + +namespace osuCrypto::Subfield +{ + + template + inline void SubfieldExConvConfigure( + u64 numOTs, u64 secParam, + MultType mMultType, + u64& mRequestedNumOTs, + u64& mNumPartitions, + u64& mSizePer, + u64& mN2, + u64& mN, + ExConvCode& mEncoder + ) + { + u64 a = 24; + auto mScaler = 2; + u64 w; + double minDist; + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + w = 7; + minDist = 0.1; + break; + case osuCrypto::MultType::ExConv21x24: + w = 21; + minDist = 0.15; + break; + default: + throw RTE_LOC; + break; + } + + mRequestedNumOTs = numOTs; + mNumPartitions = getRegNoiseWeight(minDist, secParam); + mSizePer = roundUpTo((numOTs * mScaler + mNumPartitions - 1) / mNumPartitions, 8); + mN2 = mSizePer * mNumPartitions; + mN = mN2 / mScaler; + + mEncoder.config(numOTs, numOTs * mScaler, w, a, true); + } + + + template + class SilentSubfieldVoleSender : public TimerAdapter + { + public: + using F = typename TypeTrait::F; + using G = typename TypeTrait::G; + + static constexpr u64 mScaler = 2; + + enum class State + { + Default, + Configured, + HasBase + }; + + + State mState = State::Default; + + SilentSubfieldPprfSender mGen; + + u64 mRequestedNumOTs = 0; + u64 mN2 = 0; + u64 mN = 0; + u64 mNumPartitions = 0; + u64 mSizePer = 0; + u64 mNumThreads = 1; + std::vector> mGapOts; + SilentBaseType mBaseType; + std::vector mNoiseDeltaShares; + + SilentSecType mMalType = SilentSecType::SemiHonest; + +#ifdef ENABLE_SOFTSPOKEN_OT + SoftSpokenMalOtSender mOtExtSender; + SoftSpokenMalOtReceiver mOtExtRecver; +#endif + + MultType mMultType = DefaultMultType; +#ifdef ENABLE_INSECURE_SILVER + SilverEncoder mEncoder; +#endif + ExConvCode mExConvEncoder; + + AlignedUnVector mB; + + ///////////////////////////////////////////////////// + // The standard OT extension interface + ///////////////////////////////////////////////////// + +// // the number of IKNP base OTs that should be set. +// u64 baseOtCount() const; +// +// // returns true if the IKNP base OTs are currently set. +// bool hasBaseOts() const; +// +// // sets the IKNP base OTs that are then used to extend +// void setBaseOts( +// span baseRecvOts, +// const BitVector& choices); + + // use the default base OT class to generate the + // IKNP base OTs that are required. +// task<> genBaseOts(PRNG& prng, Socket& chl) +// { +// return mOtExtSender.genBaseOts(prng, chl); +// } + + ///////////////////////////////////////////////////// + // The native silent OT extension interface + ///////////////////////////////////////////////////// + + u64 baseVoleCount() const { + return mNumPartitions + mGapOts.size() + 1 * (mMalType == SilentSecType::Malicious); + } + + // Generate the silent base OTs. If the Iknp + // base OTs are set then we do an IKNP extend, + // otherwise we perform a base OT protocol to + // generate the needed OTs. + task<> genSilentBaseOts(PRNG& prng, Socket& chl, cp::optional delta = {}) + { + using BaseOT = DefaultBaseOT; + + + MC_BEGIN(task<>, this, delta, &prng, &chl, + msg = AlignedUnVector>(silentBaseOtCount()), + baseOt = BaseOT{}, + prng2 = std::move(PRNG{}), + xx = BitVector{}, + chl2 = Socket{}, + nv = NoisySubfieldVoleSender{}, + noiseDeltaShares = std::vector{} + ); + setTimePoint("SilentVoleSender.genSilent.begin"); + + if (isConfigured() == false) + throw std::runtime_error("configure must be called first"); + + + delta = delta.value_or(TypeTrait::fromBlock(prng.get())); + xx = TypeTrait::BitVectorF(*delta); + + // compute the correlation for the noisy coordinates. + noiseDeltaShares.resize(baseVoleCount()); + + + if (mBaseType == SilentBaseType::BaseExtend) + { +#ifdef ENABLE_SOFTSPOKEN_OT + + if (mOtExtRecver.hasBaseOts() == false) + { + msg.resize(msg.size() + mOtExtRecver.baseOtCount()); + MC_AWAIT(mOtExtSender.send(msg, prng, chl)); + + mOtExtRecver.setBaseOts( + span>(msg).subspan( + msg.size() - mOtExtRecver.baseOtCount(), + mOtExtRecver.baseOtCount())); + msg.resize(msg.size() - mOtExtRecver.baseOtCount()); + + MC_AWAIT(nv.send(*delta, noiseDeltaShares, prng, mOtExtRecver, chl)); + } + else + { + chl2 = chl.fork(); + prng2.SetSeed(prng.get()); + + MC_AWAIT( + macoro::when_all_ready( + nv.send(*delta, noiseDeltaShares, prng2, mOtExtRecver, chl2), + mOtExtSender.send(msg, prng, chl))); + } +#else + +#endif + } + else + { + chl2 = chl.fork(); + prng2.SetSeed(prng.get()); + MC_AWAIT(baseOt.send(msg, prng, chl)); + MC_AWAIT(nv.send(*delta, noiseDeltaShares, prng2, baseOt, chl2)); + // MC_AWAIT( + // macoro::when_all_ready( + // nv.send(*delta, noiseDeltaShares, prng2, baseOt, chl2), + // baseOt.send(msg, prng, chl))); + } + + + setSilentBaseOts(msg, noiseDeltaShares); + setTimePoint("SilentVoleSender.genSilent.done"); + MC_END(); + } + + // configure the silent OT extension. This sets + // the parameters and figures out how many base OT + // will be needed. These can then be ganerated for + // a different OT extension or using a base OT protocol. + void configure( + u64 numOTs, + SilentBaseType type = SilentBaseType::BaseExtend, + u64 secParam = 128) + { + mBaseType = type; + u64 gap = 0; + + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + case osuCrypto::MultType::ExConv21x24: + + SubfieldExConvConfigure(numOTs, 128, mMultType, mRequestedNumOTs, mNumPartitions, mSizePer, mN2, mN, mExConvEncoder); + break; + default: + throw RTE_LOC; + break; + } + + mGapOts.resize(gap); + mGen.configure(mSizePer, mNumPartitions); + + mState = State::Configured; + } + + // return true if this instance has been configured. + bool isConfigured() const { return mState != State::Default; } + + // Returns how many base OTs the silent OT extension + // protocol will needs. + u64 silentBaseOtCount() const + { + if (isConfigured() == false) + throw std::runtime_error("configure must be called first"); + + return mGen.baseOtCount() + mGapOts.size(); + } + + // Set the externally generated base OTs. This choice + // bits must be the one return by sampleBaseChoiceBits(...). + void setSilentBaseOts( + span> sendBaseOts, + span noiseDeltaShares) + { + if ((u64)sendBaseOts.size() != silentBaseOtCount()) + throw RTE_LOC; + + if (noiseDeltaShares.size() != baseVoleCount()) + throw RTE_LOC; + + auto genOt = sendBaseOts.subspan(0, mGen.baseOtCount()); + auto gapOt = sendBaseOts.subspan(genOt.size(), mGapOts.size()); + + mGen.setBase(genOt); + std::copy(gapOt.begin(), gapOt.end(), mGapOts.begin()); + mNoiseDeltaShares.resize(noiseDeltaShares.size()); + std::copy(noiseDeltaShares.begin(), noiseDeltaShares.end(), mNoiseDeltaShares.begin()); + } + + // The native OT extension interface of silent + // OT. The receiver does not get to specify + // which OT message they receiver. Instead + // the protocol picks them at random. Use the + // send(...) interface for the normal behavior. + task<> silentSend( + F delta, + span b, + PRNG& prng, + Socket& chl) + { + MC_BEGIN(task<>, this, delta, b, &prng, &chl); + + MC_AWAIT(silentSendInplace(delta, b.size(), prng, chl)); + + std::memcpy(b.data(), mB.data(), b.size() * TypeTrait::bytesF); + clear(); + + setTimePoint("SilentVoleSender.expand.ldpc.msgCpy"); + MC_END(); + } + + // The native OT extension interface of silent + // OT. The receiver does not get to specify + // which OT message they receiver. Instead + // the protocol picks them at random. Use the + // send(...) interface for the normal behavior. + task<> silentSendInplace( + F delta, + u64 n, + PRNG& prng, + Socket& chl) + { + MC_BEGIN(task<>, this, delta, n, &prng, &chl, + gapVals = std::vector{}, + deltaShare = block{}, + X = block{}, + hash = std::array{}, + noiseShares = span{}, + mbb = span{} + ); + setTimePoint("SilentVoleSender.ot.enter"); + + + if (isConfigured() == false) + { + // first generate 128 normal base OTs + configure(n, SilentBaseType::BaseExtend); + // configure(n, SilentBaseType::Base); + } + + if (mRequestedNumOTs != n) + throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); + + if (mGen.hasBaseOts() == false) + { + // recvs data + MC_AWAIT(genSilentBaseOts(prng, chl, delta)); + } + + setTimePoint("SilentVoleSender.start"); + //gTimer.setTimePoint("SilentVoleSender.iknp.base2"); + +// if (mMalType == SilentSecType::Malicious) +// { +// deltaShare = mNoiseDeltaShares.back(); +// mNoiseDeltaShares.pop_back(); +// } + + // allocate B + mB.resize(0); + mB.resize(mN2); + + // derandomize the random OTs for the gap + // to have the desired correlation. + gapVals.resize(mGapOts.size()); + for (u64 i = mNumPartitions * mSizePer, j = 0; i < mN2; ++i, ++j) + { + auto t = TypeTrait::fromBlock(mGapOts[j][0]); + auto v = TypeTrait::plus(t, mNoiseDeltaShares[mNumPartitions + j]); + gapVals[j] = TypeTrait::plus( + TypeTrait::fromBlock(AES(mGapOts[j][1]).ecbEncBlock(ZeroBlock)), + v); + mB[i] = t; + } + + if (gapVals.size()) + MC_AWAIT(chl.send(std::move(gapVals))); + + + if (mTimer) + mGen.setTimer(*mTimer); + + // program the output the PPRF to be secret shares of + // our secret share of delta * noiseVals. The receiver + // can then manually add their shares of this to the + // output of the PPRF at the correct locations. + noiseShares = span(mNoiseDeltaShares.data(), mNumPartitions); + mbb = mB.subspan(0, mNumPartitions * mSizePer); + MC_AWAIT(mGen.expand(chl, noiseShares, prng, mbb, + PprfOutputFormat::Interleaved, true, mNumThreads)); + + setTimePoint("SilentVoleSender.expand.pprf_transpose"); + if (mDebug) + { + MC_AWAIT(checkRT(chl, delta)); + setTimePoint("SilentVoleSender.expand.checkRT"); + } + + + // if (mMalType == SilentSecType::Malicious) + // { + // MC_AWAIT(chl.recv(X)); + // hash = ferretMalCheck(X, deltaShare); + // MC_AWAIT(chl.send(std::move(hash))); + // } + + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + case osuCrypto::MultType::ExConv21x24: + if (mTimer) { + mExConvEncoder.setTimer(getTimer()); + } + mExConvEncoder.template dualEncode(mB.subspan(0, mExConvEncoder.mCodeSize)); + break; + default: + throw RTE_LOC; + break; + } + + + mB.resize(mRequestedNumOTs); + + mState = State::Default; + mNoiseDeltaShares.clear(); + + MC_END(); + } + + bool mDebug = false; + + task<> checkRT(Socket& chl, F delta) const + { + MC_BEGIN(task<>, this, &chl, delta); + MC_AWAIT(chl.send(delta)); + MC_AWAIT(chl.send(mB)); + MC_AWAIT(chl.send(mNoiseDeltaShares)); + MC_END(); + } + + std::array ferretMalCheck(block X, block deltaShare) + { + + auto xx = X; + block sum0 = ZeroBlock; + block sum1 = ZeroBlock; + for (u64 i = 0; i < (u64)mB.size(); ++i) + { + block low, high; + xx.gf128Mul(mB[i], low, high); + sum0 = sum0 ^ low; + sum1 = sum1 ^ high; + + xx = xx.gf128Mul(X); + } + + block mySum = sum0.gf128Reduce(sum1); + + std::array myHash; + RandomOracle ro(32); + ro.Update(mySum ^ deltaShare); + ro.Final(myHash); + + return myHash; + //chl.send(myHash); + } + + void clear() + { + mB = {}; + mGen.clear(); + } + }; + +} + +#endif \ No newline at end of file diff --git a/libOTe_Tests/Subfield_Test.h b/libOTe_Tests/Subfield_Test.h new file mode 100644 index 00000000..92d8babb --- /dev/null +++ b/libOTe_Tests/Subfield_Test.h @@ -0,0 +1,13 @@ +#include "cryptoTools/Common/CLP.h" + + +namespace osuCrypto::Subfield +{ + + +void Subfield_ExConvCode_encode_test(const oc::CLP& cmd); +void Subfield_Tools_Pprf_test(const oc::CLP& cmd); +void Subfield_Noisy_Vole_test(const oc::CLP& cmd); +void Subfield_Silent_Vole_test(const oc::CLP& cmd); + +} \ No newline at end of file diff --git a/libOTe_Tests/Subfield_Tests.cpp b/libOTe_Tests/Subfield_Tests.cpp new file mode 100644 index 00000000..963e03bc --- /dev/null +++ b/libOTe_Tests/Subfield_Tests.cpp @@ -0,0 +1,462 @@ +#include "Subfield_Test.h" +#include "libOTe/Tools/Subfield/Subfield.h" +#include "libOTe/Tools/Subfield/ExConvCode.h" +#include "libOTe/Vole/Subfield/NoisyVoleSender.h" +#include "libOTe/Vole/Subfield/NoisyVoleReceiver.h" +#include "libOTe/Vole/Subfield/SilentVoleSender.h" +#include "libOTe/Vole/Subfield/SilentVoleReceiver.h" + +#include "Common.h" + +namespace osuCrypto::Subfield +{ + + using tests_libOTe::eval; + + void Subfield_ExConvCode_encode_test(const oc::CLP& cmd) + { + { + u64 n = 1024; + ExConvCode code; + code.config(n / 2, n, 7, 24, true); + + PRNG prng(ZeroBlock); + block delta = prng.get(); + std::vector y(n), z0(n), z1(n); + prng.get(y.data(), y.size()); + prng.get(z0.data(), z0.size()); + for (u64 i = 0; i < n; ++i) + { + z1[i] = z0[i] ^ delta.gf128Mul(y[i]); + } + + code.dualEncode(z1); +// code.dualEncode(z0); +// code.dualEncode(y); + code.dualEncode2(z0, y); + + for (u64 i = 0; i < n; ++i) + { + block left = delta.gf128Mul(y[i]); + block right = z1[i] ^ z0[i]; + if (left != right) + throw RTE_LOC; + } + } + + { + u64 n = 1024; + ExConvCode> code; + code.config(n / 2, n, 7, 24, true); + + PRNG prng(ZeroBlock); + u8 delta = 111; + std::vector y(n), z0(n), z1(n); + prng.get(y.data(), y.size()); + prng.get(z0.data(), z0.size()); + for (u64 i = 0; i < n; ++i) + { + z1[i] = z0[i] + delta * y[i]; + } + + code.dualEncode(z1); + code.dualEncode2(z0, y); + + for (u64 i = 0; i < n; ++i) + { + u8 left = delta * y[i]; + u8 right = z1[i] - z0[i]; + if (left != right) + throw RTE_LOC; + } + } + + { + u64 n = 1024; + ExConvCode code; + code.config(n / 2, n, 7, 24, true); + + PRNG prng(ZeroBlock); + u64 delta = 111; + std::vector y(n), z0(n), z1(n); + prng.get(y.data(), y.size()); + prng.get(z0.data(), z0.size()); + for (u64 i = 0; i < n; ++i) + { + z1[i] = z0[i] + delta * y[i]; + } + + code.dualEncode(z1); + code.dualEncode2(z0, y); + + for (u64 i = 0; i < n; ++i) + { + u64 left = delta * y[i]; + u64 right = z1[i] - z0[i]; + if (left != right) + throw RTE_LOC; + } + } + } + + void Subfield_Tools_Pprf_test(const oc::CLP& cmd) { +#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) + + //{ + // u64 domain = cmd.getOr("d", 16); + // auto threads = cmd.getOr("t", 1); + // u64 numPoints = cmd.getOr("s", 1) * 8; + + // PRNG prng(ZeroBlock); + + // auto sockets = cp::LocalAsyncSocket::makePair(); + + // auto format = PprfOutputFormat::Interleaved; + // SilentSubfieldPprfSender sender; + // SilentSubfieldPprfReceiver recver; + + // sender.configure(domain, numPoints); + // recver.configure(domain, numPoints); + + // auto numOTs = sender.baseOtCount(); + // std::vector> sendOTs(numOTs); + // std::vector recvOTs(numOTs); + // BitVector recvBits = recver.sampleChoiceBits(domain * numPoints, format, prng); + // //recvBits.randomize(prng); + + // //recvBits[16] = 1; + // prng.get(sendOTs.data(), sendOTs.size()); + // for (u64 i = 0; i < numOTs; ++i) { + // //recvBits[i] = 0; + // recvOTs[i] = sendOTs[i][recvBits[i]]; + // } + // sender.setBase(sendOTs); + // recver.setBase(recvOTs); + + // //auto cols = (numPoints * domain + 127) / 128; + // Matrix sOut2(numPoints * domain, 1); + // Matrix rOut2(numPoints * domain, 1); + // std::vector points(numPoints); + // recver.getPoints(points, format); + + // std::vector arr(numPoints); + // prng.get(arr.data(), arr.size()); + // auto p0 = sender.expand(sockets[0], arr, prng, sOut2, format, true, threads); + // auto p1 = recver.expand(sockets[1], prng, rOut2, format, true, threads); + + // eval(p0, p1); + // for (u64 i = 0; i < numPoints; i++) { + // u64 point = points[i]; + // auto exp = sOut2(point) + arr[i]; + // if (exp != rOut2(point)) { + // throw RTE_LOC; + // } + // } + //} + +#else + throw UnitTestSkipped("ENABLE_SILENTOT not defined."); +#endif + } + + void Subfield_Noisy_Vole_test(const oc::CLP& cmd) { + + { + Timer timer; + timer.setTimePoint("start"); + u64 n = cmd.getOr("n", 400); + block seed = block(0, cmd.getOr("seed", 0)); + PRNG prng(seed); + + u64 x = prng.get(); + std::vector y(n); + std::vector z0(n), z1(n); + prng.get(y.data(), y.size()); + + NoisySubfieldVoleReceiver recv; + NoisySubfieldVoleSender send; + + recv.setTimer(timer); + send.setTimer(timer); + + auto chls = cp::LocalAsyncSocket::makePair(); + timer.setTimePoint("net"); + + BitVector recvChoice((u8*)&x, 64); + std::vector otRecvMsg(64); + std::vector> otSendMsg(64); + prng.get>(otSendMsg); + for (u64 i = 0; i < 64; ++i) + otRecvMsg[i] = otSendMsg[i][recvChoice[i]]; + timer.setTimePoint("ot"); + + auto p0 = recv.receive(y, z0, prng, otSendMsg, chls[0]); + auto p1 = send.send(x, z1, prng, otRecvMsg, chls[1]); + + eval(p0, p1); + + for (u64 i = 0; i < n; ++i) + { + if (x * y[i] != (z1[i] - z0[i])) + { + throw RTE_LOC; + } + } + timer.setTimePoint("done"); + + //std::cout << timer << std::endl; + } + + { + Timer timer; + timer.setTimePoint("start"); + u64 n = cmd.getOr("n", 400); + block seed = block(0, cmd.getOr("seed", 0)); + PRNG prng(seed); + + constexpr size_t N = 3; + using TypeTrait = TypeTraitVec; + u64 bitsF = TypeTrait::bitsF; + using F = TypeTrait::F; + using G = TypeTrait::G; + + F x = TypeTrait::fromBlock(prng.get()); + std::vector y(n); + std::vector z0(n), z1(n); + prng.get(y.data(), y.size()); + + NoisySubfieldVoleReceiver recv; + NoisySubfieldVoleSender send; + + recv.setTimer(timer); + send.setTimer(timer); + + auto chls = cp::LocalAsyncSocket::makePair(); + timer.setTimePoint("net"); + + BitVector recvChoice((u8*)&x, bitsF); + std::vector otRecvMsg(bitsF); + std::vector> otSendMsg(bitsF); + prng.get>(otSendMsg); + for (u64 i = 0; i < bitsF; ++i) + otRecvMsg[i] = otSendMsg[i][recvChoice[i]]; + timer.setTimePoint("ot"); + + auto p0 = recv.receive(y, z0, prng, otSendMsg, chls[0]); + auto p1 = send.send(x, z1, prng, otRecvMsg, chls[1]); + + eval(p0, p1); + // std::cout << "transferred " << (chls[0].bytesSent() + chls[0].bytesReceived()) << std::endl; + timer.setTimePoint("verify"); + + for (u64 i = 0; i < n; ++i) + { + for (u64 j = 0; j < N; j++) { + G left = x[j] * y[i]; + G right = z1[i][j] - z0[i][j]; + if (left != right) + { + throw RTE_LOC; + } + } + } + timer.setTimePoint("done"); + + // std::cout << timer << std::endl; + } + + { + Timer timer; + timer.setTimePoint("start"); + u64 n = cmd.getOr("n", 400); + block seed = block(0, cmd.getOr("seed", 0)); + PRNG prng(seed); + + block x = prng.get(); + std::vector y(n); + std::vector z0(n), z1(n); + prng.get(y.data(), y.size()); + + NoisySubfieldVoleReceiver recv; + NoisySubfieldVoleSender send; + + recv.setTimer(timer); + send.setTimer(timer); + + auto chls = cp::LocalAsyncSocket::makePair(); + timer.setTimePoint("net"); + + size_t k = 128; + BitVector recvChoice((u8*)&x, k); + std::vector otRecvMsg(k); + std::vector> otSendMsg(k); + prng.get>(otSendMsg); + for (u64 i = 0; i < k; ++i) + otRecvMsg[i] = otSendMsg[i][recvChoice[i]]; + timer.setTimePoint("ot"); + + auto p0 = recv.receive(y, z0, prng, otSendMsg, chls[0]); + auto p1 = send.send(x, z1, prng, otRecvMsg, chls[1]); + + eval(p0, p1); + + for (u64 i = 0; i < n; ++i) + { + if (x.gf128Mul(y[i]) != (z1[i] ^ z0[i])) + { + throw RTE_LOC; + } + } + timer.setTimePoint("done"); + + //std::cout << timer << std::endl; + } + } + + void Subfield_Silent_Vole_test(const oc::CLP& cmd) { + using namespace oc::Subfield; +#if defined(ENABLE_SILENTOT) + Timer timer; + timer.setTimePoint("start"); + u64 n = cmd.getOr("n", 102043); + u64 nt = cmd.getOr("nt", std::thread::hardware_concurrency()); + block seed = block(0, cmd.getOr("seed", 0)); + + { + PRNG prng(seed); + u64 x = TypeTrait64::fromBlock(prng.get()); + std::vector c(n), z0(n), z1(n); + + SilentSubfieldVoleReceiver recv; + SilentSubfieldVoleSender send; + + recv.mMultType = MultType::ExConv7x24; + send.mMultType = MultType::ExConv7x24; + + recv.setTimer(timer); + send.setTimer(timer); + +// recv.mDebug = true; +// send.mDebug = true; + + auto chls = cp::LocalAsyncSocket::makePair(); + + timer.setTimePoint("net"); + + timer.setTimePoint("ot"); + // fakeBase(n, nt, prng, x, recv, send); + + auto p0 = send.silentSend(x, span(z0), prng, chls[0]); + auto p1 = recv.silentReceive(span(c), span(z1), prng, chls[1]); + + eval(p0, p1); + timer.setTimePoint("send"); + for (u64 i = 0; i < n; ++i) { + u64 left = c[i] * x; + u64 right = z1[i] - z0[i]; + if (left != right) { + std::cout << "bad " << i << "\n c[i] " << c[i] << " * x " << x << " = " << left << std::endl; + std::cout << "z0[i] " << z0[i] << " - z1 " << z1[i] << " = " << right << std::endl; + throw RTE_LOC; + } + } + } + + { + PRNG prng(seed); + constexpr size_t N = 10; + using TypeTrait = TypeTraitVec; + using F = TypeTrait::F; + using G = TypeTrait::G; + F x = TypeTrait::fromBlock(prng.get()); + std::vector c(n); + std::vector z0(n), z1(n); + + SilentSubfieldVoleReceiver recv; + SilentSubfieldVoleSender send; + + recv.mMultType = MultType::ExConv7x24; + send.mMultType = MultType::ExConv7x24; + + recv.setTimer(timer); + send.setTimer(timer); + + // recv.mDebug = true; + // send.mDebug = true; + + auto chls = cp::LocalAsyncSocket::makePair(); + + timer.setTimePoint("net"); + + timer.setTimePoint("ot"); + // fakeBase(n, nt, prng, x, recv, send); + + auto p0 = send.silentSend(x, span(z0), prng, chls[0]); + auto p1 = recv.silentReceive(span(c), span(z1), prng, chls[1]); + + eval(p0, p1); + // std::cout << "transferred " << (chls[0].bytesSent() + chls[0].bytesReceived()) << std::endl; + timer.setTimePoint("verify"); + + timer.setTimePoint("send"); + for (u64 i = 0; i < n; i++) { + for (u64 j = 0; j < N; j++) { + G left = c[i] * x[j]; + G right = z1[i][j] - z0[i][j]; + if (left != right) { + std::cout << "bad " << i << "\n c[i] " << c[i] << " * x[j] " << x[j] << " = " << left << std::endl; + std::cout << "z0[i][j] " << z0[i][j] << " - z1 " << z1[i][j] << " = " << right << std::endl; + throw RTE_LOC; + } + } + } + } + + { + PRNG prng(seed); + block x = prng.get(); + std::vector c(n), z0(n), z1(n); + + SilentSubfieldVoleReceiver recv; + SilentSubfieldVoleSender send; + + recv.mMultType = MultType::ExConv7x24; + send.mMultType = MultType::ExConv7x24; + + recv.setTimer(timer); + send.setTimer(timer); + +// recv.mDebug = true; +// send.mDebug = true; + + auto chls = cp::LocalAsyncSocket::makePair(); + + timer.setTimePoint("net"); + + timer.setTimePoint("ot"); + // fakeBase(n, nt, prng, x, recv, send); + + auto p0 = send.silentSend(x, span(z0), prng, chls[0]); + auto p1 = recv.silentReceive(span(c), span(z1), prng, chls[1]); + + eval(p0, p1); + timer.setTimePoint("send"); + for (u64 i = 0; i < n; ++i) { + block left = x.gf128Mul(c[i]); + block right = z1[i] ^ z0[i]; + if (left != right) { + std::cout << "bad " << i << "\n c[i] " << c[i] << " * x " << x << " = " << left << std::endl; + std::cout << "z0[i] " << z0[i] << " - z1 " << z1[i] << " = " << right << std::endl; + throw RTE_LOC; + } + } + } + + timer.setTimePoint("done"); + // std::cout << timer << std::endl; +#else + throw UnitTestSkipped("not defined." LOCATION); +#endif + } + +} \ No newline at end of file diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 1c1e72a0..97d0c5ac 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -16,6 +16,7 @@ #include "libOTe_Tests/ExConvCode_Tests.h" #include "libOTe_Tests/EACode_Tests.h" #include "libOTe/Tools/LDPC/Mtx.h" +#include "libOTe_Tests/Subfield_Test.h" using namespace osuCrypto; namespace tests_libOTe @@ -116,5 +117,10 @@ namespace tests_libOTe tc.add("NcoOt_Oos_Test ", NcoOt_Oos_Test); tc.add("NcoOt_genBaseOts_Test ", NcoOt_genBaseOts_Test); + tc.add("Subfield_ExConvCode_encode_test ", Subfield::Subfield_ExConvCode_encode_test); + tc.add("Subfield_Tools_Pprf_test ", Subfield::Subfield_Tools_Pprf_test); + tc.add("Subfield_Noisy_Vole_test ", Subfield::Subfield_Noisy_Vole_test); + tc.add("Subfield_Silent_Vole_test ", Subfield::Subfield_Silent_Vole_test); + }); } From eb99b6286497a89714ffe1312dc01fbcbbbd773c Mon Sep 17 00:00:00 2001 From: Halulu Date: Thu, 11 Jan 2024 12:53:35 +0800 Subject: [PATCH 02/23] merge new pprf --- libOTe/Tools/Subfield/SubfieldPprf.h | 1716 +++++++++------------ libOTe/Vole/Subfield/SilentVoleReceiver.h | 2 +- libOTe/Vole/Subfield/SilentVoleSender.h | 2 +- 3 files changed, 689 insertions(+), 1031 deletions(-) diff --git a/libOTe/Tools/Subfield/SubfieldPprf.h b/libOTe/Tools/Subfield/SubfieldPprf.h index bcd8bbc4..24549b86 100644 --- a/libOTe/Tools/Subfield/SubfieldPprf.h +++ b/libOTe/Tools/Subfield/SubfieldPprf.h @@ -13,6 +13,17 @@ namespace osuCrypto::Subfield { + namespace + { + // A public PRF/PRG that we will use for deriving the GGM tree. + const std::array gAes = []() { + std::array aes; + aes[0].setKey(toBlock(3242342)); + aes[1].setKey(toBlock(8993849)); + return aes; + }(); + } + template void copyOut( span> lvl, @@ -22,56 +33,7 @@ namespace osuCrypto::Subfield PprfOutputFormat oFormat, std::function> lvl)>& callback) { - - if (oFormat == PprfOutputFormat::InterleavedTransposed) - { - // not having an even (8) number of trees is not supported. - if (totalTrees % 8) - throw RTE_LOC; - if (lvl.size() % 16) - throw RTE_LOC; - - // - //auto rowsPer = 16; - //auto step = lvl.size() - - //auto sectionSize = - - if (lvl.size() < 16) - throw RTE_LOC; - - - auto setIdx = tIdx / 8; - auto blocksPerSet = lvl.size() * 8 / 128; - - - - auto numSets = totalTrees / 8; - auto begin = setIdx; - auto step = numSets; - - if (oFormat == PprfOutputFormat::InterleavedTransposed) - { - // todo - throw RTE_LOC; - // auto end = std::min(begin + step * blocksPerSet, output.cols()); - - // for (u64 i = begin, k = 0; i < end; i += step, ++k) - // { - // auto& io = *(std::array*)(&lvl[k * 16]); - // transpose128(io.data()); - // for (u64 j = 0; j < 128; ++j) - // output(j, i) = io[j]; - // } - } - else - { - // no op - } - - - } - else if (oFormat == PprfOutputFormat::Plain) + if (oFormat == PprfOutputFormat::ByLeafIndex) { auto curSize = std::min(totalTrees - tIdx, 8); @@ -104,7 +66,7 @@ namespace osuCrypto::Subfield } } - else if (oFormat == PprfOutputFormat::BlockTransposed) + else if (oFormat == PprfOutputFormat::ByTreeIndex) { auto curSize = std::min(totalTrees - tIdx, 8); @@ -134,19 +96,91 @@ namespace osuCrypto::Subfield } } + else if (oFormat == PprfOutputFormat::Callback) + callback(tIdx, lvl); + else + throw RTE_LOC; + } + + template + void allocateExpandBuffer( + u64 depth, + u64 activeChildXorDelta, + std::vector& buff, + span< std::array, 2>>& sums, + span< std::array>& last) + { + + using SumType = std::array, 2>; + using LastType = std::array; + u64 numSums = depth - activeChildXorDelta; + u64 numLast = activeChildXorDelta * 8; + u64 numBytes = numSums * 16 * 16 + numLast * 4 * sizeof(F); + buff.resize(numBytes); + sums = span((SumType*)buff.data(), numSums); + last = span((LastType*)(sums.data() + sums.size()), numLast); + + void* sEnd = sums.data() + sums.size(); + void* lEnd = last.data() + last.size(); + void* end = buff.data() + buff.size(); + if (sEnd > end || lEnd > end) + throw RTE_LOC; + } + + template + void validateExpandFormat( + PprfOutputFormat oFormat, + MatrixView output, + u64 domain, + u64 pntCount + ) + { + + if (oFormat == PprfOutputFormat::ByLeafIndex) + { + if (output.rows() != domain) + throw RTE_LOC; + + if (output.cols() != pntCount) + throw RTE_LOC; + } + else if (oFormat == PprfOutputFormat::ByTreeIndex) + { + if (output.cols() != domain) + throw RTE_LOC; + + if (output.rows() != pntCount) + throw RTE_LOC; + } else if (oFormat == PprfOutputFormat::Interleaved) { - // no op + if (output.cols() != 1) + throw RTE_LOC; + if (domain & 1) + throw RTE_LOC; + + auto rows = output.rows(); + if (rows > (domain * pntCount) || + rows / 128 != (domain * pntCount) / 128) + throw RTE_LOC; + if (pntCount & 7) + throw RTE_LOC; } else if (oFormat == PprfOutputFormat::Callback) - callback(tIdx, lvl); + { + if (domain & 1) + throw RTE_LOC; + if (pntCount & 7) + throw RTE_LOC; + } else + { throw RTE_LOC; + } } template - class SilentSubfieldPprfSender : public TimerAdapter - { + class SilentSubfieldPprfSender : public TimerAdapter { public: using F = typename TypeTrait::F; u64 mDomain = 0, mDepth = 0, mPntCount = 0; @@ -154,21 +188,24 @@ namespace osuCrypto::Subfield bool mPrint = false; TreeAllocator mTreeAlloc; Matrix> mBaseOTs; - - std::function>)> mOutputFn; + + std::function< + void(u64 + treeIdx, span >)> + mOutputFn; SilentSubfieldPprfSender() = default; - SilentSubfieldPprfSender(const SilentSubfieldPprfSender&) = delete; - SilentSubfieldPprfSender(SilentSubfieldPprfSender&&) = delete; - SilentSubfieldPprfSender(u64 domainSize, u64 pointCount) - { + SilentSubfieldPprfSender(const SilentSubfieldPprfSender &) = delete; + + SilentSubfieldPprfSender(SilentSubfieldPprfSender &&) = delete; + + SilentSubfieldPprfSender(u64 domainSize, u64 pointCount) { configure(domainSize, pointCount); } - void configure(u64 domainSize, u64 pointCount) - { + void configure(u64 domainSize, u64 pointCount) { mDomain = domainSize; mDepth = log2ceil(mDomain); mPntCount = pointCount; @@ -179,14 +216,12 @@ namespace osuCrypto::Subfield // the number of base OTs that should be set. - u64 baseOtCount() const - { + u64 baseOtCount() const { return mDepth * mPntCount; } // returns true if the base OTs are currently set. - bool hasBaseOts() const - { + bool hasBaseOts() const { return mBaseOTs.size(); } @@ -200,456 +235,311 @@ namespace osuCrypto::Subfield mBaseOTs(i) = baseMessages[i]; } - task<> expand(Socket& chls, span value, PRNG& prng, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) - { + task<> expand(Socket &chls, span value, block seed, span output, PprfOutputFormat oFormat, + bool activeChildXorDelta, u64 numThreads) { MatrixView o(output.data(), output.size(), 1); - return expand(chls, value, prng, o, oFormat, activeChildXorDelta, numThreads); + return expand(chls, value, seed, o, oFormat, activeChildXorDelta, numThreads); } task<> expand( - Socket& chl, - span value, - PRNG& prng, - MatrixView output, - PprfOutputFormat oFormat, - bool activeChildXorDelta, - u64 numThreads) - { + Socket &chl, + span value, + block seed, + MatrixView output, + PprfOutputFormat oFormat, + bool activeChildXorDelta, + u64 numThreads) { if (activeChildXorDelta) setValue(value); - setTimePoint("SilentMultiPprfSender.start"); - //gTimer.setTimePoint("send.enter"); - - - if (oFormat == PprfOutputFormat::Plain) - { - if (output.rows() != mDomain) - throw RTE_LOC; + setTimePoint("SilentMultiPprfSender.start"); - if (output.cols() != mPntCount) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::BlockTransposed) - { - if (output.cols() != mDomain) - throw RTE_LOC; + validateExpandFormat(oFormat, output, mDomain, mPntCount); + + MC_BEGIN(task<>, this, numThreads, oFormat, output, seed, &chl, activeChildXorDelta, + i = u64{}, + mTreeAllocDepth = u64{}, + tree = span < AlignedArray>{}, + levels = std::vector> > {}, + lastLevel = span < AlignedArray>{}, + buff = std::vector{}, + sums = span < std::array, 2>>{}, + last = span < std::array>{} + ); - if (output.rows() != mPntCount) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::InterleavedTransposed) - { - if (output.rows() != 128) - throw RTE_LOC; + //if (oFormat == PprfOutputFormat::Callback && numThreads > 1) + // throw RTE_LOC; - //if (output.cols() > (mDomain * mPntCount + 127) / 128) - // throw RTE_LOC; + mTreeAllocDepth = mDepth + 1; // Subfield + mTreeAlloc.reserve(numThreads, (1ull << mTreeAllocDepth) + 2); + setTimePoint("SilentMultiPprfSender.reserve"); - if (mPntCount & 7) - throw RTE_LOC; - } - else if - (oFormat == PprfOutputFormat::Interleaved) - { - if (output.cols() != 1) - throw RTE_LOC; - if (mDomain & 1) - throw RTE_LOC; - - auto rows = output.rows(); - if (rows > (mDomain * mPntCount) || - rows / 128 != (mDomain * mPntCount) / 128) - throw RTE_LOC; - if (mPntCount & 7) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::Callback) - { - if (mDomain & 1) - throw RTE_LOC; - if (mPntCount & 7) - throw RTE_LOC; - } - else - { - throw RTE_LOC; - } + levels.resize(mDepth + 1); + allocateExpandTree(mTreeAllocDepth, mTreeAlloc, tree, levels); + for (i = 0; i < mPntCount; i += 8) { + // for interleaved format, the last level of the tree + // is simply the output. + // Subfield: use lastLevel + if (oFormat == PprfOutputFormat::Interleaved) { + auto b = (AlignedArray *) output.data(); + auto forest = i / 8; + b += forest * mDomain; + lastLevel = span < AlignedArray>(b, mDomain); - MC_BEGIN(task<>, this, numThreads, oFormat, output, &prng, &chl, activeChildXorDelta, - i = u64{}, - dd = u64{} - ); +// auto b = (AlignedArray *) output.data(); +// auto forest = i / 8; +// b += forest * mDomain; +// +// levels.back() = span < AlignedArray> +// (b, mDomain); + } else { + throw RTE_LOC; + } + // allocate the send buffer and partition it. + allocateExpandBuffer(mDepth, activeChildXorDelta, buff, sums, last); - if (oFormat == PprfOutputFormat::Callback && numThreads > 1) - throw RTE_LOC; + // exapnd the tree + expandOne(seed, i, activeChildXorDelta, levels, lastLevel, sums, last); - dd = mDepth + (oFormat == PprfOutputFormat::Interleaved ? 0 : 1); - mTreeAlloc.reserve(numThreads, (1ull << (dd + 1)) + (32 * (dd+1))); - setTimePoint("SilentMultiPprfSender.reserve"); + MC_AWAIT(chl.send(std::move(buff))); - mExps.clear(); mExps.reserve(divCeil(mPntCount, 8)); - for (i = 0; i < mPntCount; i += 8) - { - mExps.emplace_back(*this, prng.get(), i, oFormat, output, activeChildXorDelta, chl.fork()); - mExps.back().mFuture = macoro::make_eager(mExps.back().run()); - //MC_AWAIT(mExps.back().run()); - } + // if we aren't interleaved, we need to copy the + // last layer to the output. + if (oFormat != PprfOutputFormat::Interleaved) { + // Subfield: no need to copyOut + throw RTE_LOC; +// copyOut(levels.back(), output, mPntCount, i, oFormat, mOutputFn); + } - for (i = 0; i < mExps.size(); ++i) - MC_AWAIT(mExps[i].mFuture); + } - mExps.clear(); - setTimePoint("SilentMultiPprfSender.join"); + mBaseOTs = {}; + mTreeAlloc.del(tree); + mTreeAlloc.clear(); - mBaseOTs = {}; - //mTreeAlloc.clear(); - setTimePoint("SilentMultiPprfSender.de-alloc"); + setTimePoint("SilentMultiPprfSender.de-alloc"); MC_END(); - - } - void setValue(span value) - { + void setValue(span value) { mValue.resize(mPntCount); - if (value.size() == 1) - { + if (value.size() == 1) { std::fill(mValue.begin(), mValue.end(), value[0]); - } - else - { + } else { if ((u64)value.size() != mPntCount) - throw RTE_LOC; + throw RTE_LOC; std::copy(value.begin(), value.end(), mValue.begin()); } } - void clear() - { + void clear() { mBaseOTs.resize(0, 0); mDomain = 0; mDepth = 0; mPntCount = 0; } - struct Expander - { - SilentSubfieldPprfSender& pprf; - Socket chl; - std::array aes; - PRNG prng; - u64 dd, treeIdx, min, d; - bool mActiveChildXorDelta = true; - - macoro::eager_task mFuture; - std::vector>> mLevels; - - //std::unique_ptr uPtr_; - - // tree will hold the full GGM tree. Note that there are 8 - // indepenendent trees that are being processed together. - // The trees are flattenned to that the children of j are - // located at 2*j and 2*j+1. - span> tree; - - // sums will hold the left and right GGM tree sums - // for each level. For example sums[0][i][5] will - // hold the sum of the left children for level i of - // the 5th tree. - std::array>, 2> sums; - // sums for the last level - std::array, 2> lastSums; - std::vector> lastOts; - - PprfOutputFormat oFormat; - - MatrixView output; - + void expandOne( + block aesSeed, + u64 treeIdx, + bool programActivePath, + span >> levels, + span < AlignedArray > lastLevel, + span , 2>> encSums, + span > lastOts) { // The number of real trees for this iteration. - // Returns the i'th level of the current 8 trees. The - // children of node j on level i are located at 2*j and - // 2*j+1 on level i+1. - span> getLevel(u64 i, u64 g) - { - return mLevels[i]; - }; - - span> getLastLevel(u64 i, u64 g) { - if (oFormat == PprfOutputFormat::Interleaved && i == pprf.mDepth) - { - auto b = (AlignedArray*)output.data(); - auto forest = g / 8; - assert(g % 8 == 0); - b += forest * pprf.mDomain; - return span>(b, pprf.mDomain); - } - - throw RTE_LOC; - } - - Expander(SilentSubfieldPprfSender& p, block seed, u64 treeIdx_, - PprfOutputFormat of, MatrixViewo, bool activeChildXorDelta, Socket&& s) - : pprf(p) - , chl(std::move(s)) - , mActiveChildXorDelta(activeChildXorDelta) - { - treeIdx = treeIdx_; - assert((treeIdx & 7) == 0); - output = o; - oFormat = of; - // A public PRF/PRG that we will use for deriving the GGM tree. - aes[0].setKey(toBlock(3242342)); - aes[1].setKey(toBlock(8993849)); - prng.SetSeed(seed); - dd = pprf.mDepth + (oFormat == PprfOutputFormat::Interleaved ? 0 : 1); - } - - task<> run() - { - MC_BEGIN(task<>, this); - - - #ifdef DEBUG_PRINT_PPRF - chl.asyncSendCopy(mValue); - #endif - // pprf.setTimePoint("SilentMultiPprfSender.begin " + std::to_string(treeIdx)); - { - tree = pprf.mTreeAlloc.get(); - assert(tree.size() >= 1ull << (dd)); - assert((u64)tree.data() % 32 == 0); - mLevels.resize(dd+1); - mLevels[0] = tree.subspan(0, 1); - auto rem = tree.subspan(mLevels[0].size()); - for (u64 i = 1; i < dd + 1; i++) - { - while ((u64)rem.data() % 32) - rem = rem.subspan(1); - - mLevels[i] = rem.subspan(0, mLevels[i - 1].size() * 2); - rem = rem.subspan(mLevels[i].size()); - } - } - // pprf.setTimePoint("SilentMultiPprfSender.alloc " + std::to_string(treeIdx)); - - // This thread will process 8 trees at a time. It will interlace - // the sets of trees are processed with the other threads. - { - memset(lastSums[0].data(), 0, lastSums[0].size() * sizeof(F)); - memset(lastSums[1].data(), 0, lastSums[1].size() * sizeof(F)); - - // The number of real trees for this iteration. - min = std::min(8, pprf.mPntCount - treeIdx); - //gTimer.setTimePoint("send.start" + std::to_string(treeIdx)); - - // Populate the zeroth level of the GGM tree with random seeds. - prng.get(getLevel(0, treeIdx)); - - // Allocate space for our sums of each level. - sums[0].resize(pprf.mDepth); - sums[1].resize(pprf.mDepth); - - // For each level perform the following. - for (u64 d = 0; d < pprf.mDepth; ++d) - { - // The previous level of the GGM tree. - auto level0 = getLevel(d, treeIdx); - - // The next level of theGGM tree that we are populating. - auto level1 = getLevel(d + 1, treeIdx); - - // The total number of children in this level. - auto width = static_cast(level1.size()); - - // For each child, populate the child by expanding the parent. - for (u64 childIdx = 0; childIdx < width; ) - { - // Index of the parent in the previous level. - auto parentIdx = childIdx >> 1; - - // The value of the parent. - auto& parent = level0[parentIdx]; - - // The bit that indicates if we are on the left child (0) - // or on the right child (1). - for (u64 keep = 0; keep < 2; ++keep, ++childIdx) - { - // The child that we will write in this iteration. - auto& child = level1[childIdx]; - - // The sum that this child node belongs to. - auto& sum = sums[keep][d]; - - // Each parent is expanded into the left and right children - // using a different AES fixed-key. Therefore our OWF is: - // - // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); - // - // where each half defines one of the children. - aes[keep].template hashBlocks<8>(parent.data(), child.data()); - - if (d < pprf.mDepth - 1) { - // for intermediate levels, same as before - // Update the running sums for this level. We keep - // a left and right totals for each level. - sum[0] = sum[0] ^ child[0]; - sum[1] = sum[1] ^ child[1]; - sum[2] = sum[2] ^ child[2]; - sum[3] = sum[3] ^ child[3]; - sum[4] = sum[4] ^ child[4]; - sum[5] = sum[5] ^ child[5]; - sum[6] = sum[6] ^ child[6]; - sum[7] = sum[7] ^ child[7]; - } else { - if (getLastLevel(pprf.mDepth, treeIdx).size() <= childIdx) { - childIdx = width; - break; - } - auto& realChild = getLastLevel(pprf.mDepth, treeIdx)[childIdx]; - auto& lastSum = lastSums[keep]; - realChild[0] = TypeTrait::fromBlock(child[0]); - lastSum[0] = TypeTrait::plus(lastSum[0], realChild[0]); - realChild[1] = TypeTrait::fromBlock(child[1]); - lastSum[1] = TypeTrait::plus(lastSum[1], realChild[1]); - realChild[2] = TypeTrait::fromBlock(child[2]); - lastSum[2] = TypeTrait::plus(lastSum[2], realChild[2]); - realChild[3] = TypeTrait::fromBlock(child[3]); - lastSum[3] = TypeTrait::plus(lastSum[3], realChild[3]); - realChild[4] = TypeTrait::fromBlock(child[4]); - lastSum[4] = TypeTrait::plus(lastSum[4], realChild[4]); - realChild[5] = TypeTrait::fromBlock(child[5]); - lastSum[5] = TypeTrait::plus(lastSum[5], realChild[5]); - realChild[6] = TypeTrait::fromBlock(child[6]); - lastSum[6] = TypeTrait::plus(lastSum[6], realChild[6]); - realChild[7] = TypeTrait::fromBlock(child[7]); - lastSum[7] = TypeTrait::plus(lastSum[7], realChild[7]); - } - } - } - } - + auto min = std::min(8, mPntCount - treeIdx); + + // the first level should be size 1, the root of the tree. + // we will populate it with random seeds using aesSeed in counter mode + // based on the tree index. + assert(levels[0].size() == 1); + mAesFixedKey.ecbEncCounterMode(aesSeed ^ block(treeIdx), levels[0][0]); + + assert(encSums.size() == mDepth - programActivePath); + assert(encSums.size() < 24); + + // space for our sums of each level. Should always be less then + // 24 levels... If not increase the limit or make it a vector. + std::array, 2>, 24> sums; + memset(&sums, 0, sizeof(sums)); + + // Subfield: lastSums + std::array, 2> lastSums{}; + + // For each level perform the following. + for (u64 d = 0; d < mDepth; ++d) { + // The previous level of the GGM tree. + auto level0 = levels[d]; + + // The next level of theGGM tree that we are populating. + auto level1 = levels[d + 1]; + + // The total number of parents in this level. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); + + // use the optimized approach for intern nodes of the tree + if (d + 1 < mDepth && 0) { +// // For each child, populate the child by expanding the parent. +// for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) { +// // The value of the parent. +// auto &parent = level0.data()[parentIdx]; +// +// auto &child0 = level1.data()[childIdx]; +// auto &child1 = level1.data()[childIdx + 1]; +// mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); +// +// child0[0] = child1[0] ^ parent[0]; +// child0[1] = child1[1] ^ parent[1]; +// child0[2] = child1[2] ^ parent[2]; +// child0[3] = child1[3] ^ parent[3]; +// child0[4] = child1[4] ^ parent[4]; +// child0[5] = child1[5] ^ parent[5]; +// child0[6] = child1[6] ^ parent[6]; +// child0[7] = child1[7] ^ parent[7]; +// +// // Update the running sums for this level. We keep +// // a left and right totals for each level. +// auto &sum = sums[d]; +// sum[0][0] = sum[0][0] ^ child0[0]; +// sum[0][1] = sum[0][1] ^ child0[1]; +// sum[0][2] = sum[0][2] ^ child0[2]; +// sum[0][3] = sum[0][3] ^ child0[3]; +// sum[0][4] = sum[0][4] ^ child0[4]; +// sum[0][5] = sum[0][5] ^ child0[5]; +// sum[0][6] = sum[0][6] ^ child0[6]; +// sum[0][7] = sum[0][7] ^ child0[7]; +// +// child1[0] = child1[0] + parent[0]; +// child1[1] = child1[1] + parent[1]; +// child1[2] = child1[2] + parent[2]; +// child1[3] = child1[3] + parent[3]; +// child1[4] = child1[4] + parent[4]; +// child1[5] = child1[5] + parent[5]; +// child1[6] = child1[6] + parent[6]; +// child1[7] = child1[7] + parent[7]; +// +// sum[1][0] = sum[1][0] ^ child1[0]; +// sum[1][1] = sum[1][1] ^ child1[1]; +// sum[1][2] = sum[1][2] ^ child1[2]; +// sum[1][3] = sum[1][3] ^ child1[3]; +// sum[1][4] = sum[1][4] ^ child1[4]; +// sum[1][5] = sum[1][5] ^ child1[5]; +// sum[1][6] = sum[1][6] ^ child1[6]; +// sum[1][7] = sum[1][7] ^ child1[7]; +// +// } + } else { + // for the leaf nodes we need to hash both children. + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) { + // The value of the parent. + auto &parent = level0.data()[parentIdx]; - #ifdef DEBUG_PRINT_PPRF - // If we are debugging, then send over the full tree - // to make sure its correct on the other side. - chl.asyncSendCopy(tree); - #endif + // The bit that indicates if we are on the left child (0) + // or on the right child (1). + for (u64 keep = 0; keep < 2; ++keep, ++childIdx) { + // The child that we will write in this iteration. + auto &child = level1[childIdx]; - // For all but the last level, mask the sums with the - // OT strings and send them over. - for (u64 d = 0; d < pprf.mDepth - mActiveChildXorDelta; ++d) - { - for (u64 j = 0; j < min; ++j) - { - #ifdef DEBUG_PRINT_PPRF - if (mPrint) - { - std::cout << "c[" << treeIdx + j << "][" << d << "][0] " << sums[0][d][j] << " " << mBaseOTs[treeIdx + j][d][0] << std::endl;; - std::cout << "c[" << treeIdx + j << "][" << d << "][1] " << sums[1][d][j] << " " << mBaseOTs[treeIdx + j][d][1] << std::endl;; - } - #endif - sums[0][d][j] = sums[0][d][j] ^ pprf.mBaseOTs[treeIdx + j][d][0]; - sums[1][d][j] = sums[1][d][j] ^ pprf.mBaseOTs[treeIdx + j][d][1]; - } - } - // pprf.setTimePoint("SilentMultiPprfSender.expand " + std::to_string(treeIdx)); + // The sum that this child node belongs to. + auto &sum = sums[d][keep]; - if (mActiveChildXorDelta) - { - // For the last level, we are going to do something special. - // The other party is currently missing both leaf children of - // the active parent. Since this is the last level, we want - // the inactive child to just be the normal value but the - // active child should be the correct value XOR the delta. - // This will be done by sending the sums and the sums plus - // delta and ensure that they can only decrypt the correct ones. - d = pprf.mDepth - 1; - //std::vector>& lastOts = lastOts; - lastOts.resize(min); - for (u64 j = 0; j < min; ++j) - { - // Construct the sums where we will allow the delta (mValue) - // to either be on the left child or right child depending - // on which has the active path. - lastOts[j][0] = lastSums[0][j]; - lastOts[j][1] = TypeTrait::plus(lastSums[1][j], pprf.mValue[treeIdx + j]); - lastOts[j][2] = lastSums[1][j]; - lastOts[j][3] = TypeTrait::plus(lastSums[0][j], pprf.mValue[treeIdx + j]); - - // We are going to expand the 128 bit OT string - // into a 256 bit OT string using AES. - std::array masks, maskIn; - maskIn[0] = pprf.mBaseOTs[treeIdx + j][d][0]; - maskIn[1] = pprf.mBaseOTs[treeIdx + j][d][0] ^ AllOneBlock; - maskIn[2] = pprf.mBaseOTs[treeIdx + j][d][1]; - maskIn[3] = pprf.mBaseOTs[treeIdx + j][d][1] ^ AllOneBlock; - mAesFixedKey.hashBlocks<4>(maskIn.data(), masks.data()); - - #ifdef DEBUG_PRINT_PPRF - if (mPrint) { - std::cout << "c[" << treeIdx + j << "][" << d << "][0] " << sums[0][d][j] << " " << mBaseOTs[treeIdx + j][d][0] << std::endl;; - std::cout << "c[" << treeIdx + j << "][" << d << "][1] " << sums[1][d][j] << " " << mBaseOTs[treeIdx + j][d][1] << std::endl;; - } - #endif + // Each parent is expanded into the left and right children + // using a different AES fixed-key. Therefore our OWF is: + // + // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); + // + // where each half defines one of the children. + gAes[keep].hashBlocks<8>(parent.data(), child.data()); - // Add the OT masks to the sums and send them over. - lastOts[j][0] = TypeTrait::plus(lastOts[j][0], TypeTrait::fromBlock(masks[0])); - lastOts[j][1] = TypeTrait::plus(lastOts[j][1], TypeTrait::fromBlock(masks[1])); - lastOts[j][2] = TypeTrait::plus(lastOts[j][2], TypeTrait::fromBlock(masks[2])); - lastOts[j][3] = TypeTrait::plus(lastOts[j][3], TypeTrait::fromBlock(masks[3])); + if (d == mDepth - 1) { + // Subfield + auto& realChild = lastLevel[childIdx]; + auto& lastSum = lastSums[keep]; + realChild[0] = TypeTrait::fromBlock(child[0]); + lastSum[0] = TypeTrait::plus(lastSum[0], realChild[0]); + realChild[1] = TypeTrait::fromBlock(child[1]); + lastSum[1] = TypeTrait::plus(lastSum[1], realChild[1]); + realChild[2] = TypeTrait::fromBlock(child[2]); + lastSum[2] = TypeTrait::plus(lastSum[2], realChild[2]); + realChild[3] = TypeTrait::fromBlock(child[3]); + lastSum[3] = TypeTrait::plus(lastSum[3], realChild[3]); + realChild[4] = TypeTrait::fromBlock(child[4]); + lastSum[4] = TypeTrait::plus(lastSum[4], realChild[4]); + realChild[5] = TypeTrait::fromBlock(child[5]); + lastSum[5] = TypeTrait::plus(lastSum[5], realChild[5]); + realChild[6] = TypeTrait::fromBlock(child[6]); + lastSum[6] = TypeTrait::plus(lastSum[6], realChild[6]); + realChild[7] = TypeTrait::fromBlock(child[7]); + lastSum[7] = TypeTrait::plus(lastSum[7], realChild[7]); + } else { + // Update the running sums for this level. We keep + // a left and right totals for each level. + sum[0] = sum[0] ^ child[0]; + sum[1] = sum[1] ^ child[1]; + sum[2] = sum[2] ^ child[2]; + sum[3] = sum[3] ^ child[3]; + sum[4] = sum[4] ^ child[4]; + sum[5] = sum[5] ^ child[5]; + sum[6] = sum[6] ^ child[6]; + sum[7] = sum[7] ^ child[7]; } - - // pprf.setTimePoint("SilentMultiPprfSender.last " + std::to_string(treeIdx)); - - // Resize the sums to that they dont include - // the unmasked sums on the last level! - sums[0].resize(pprf.mDepth - 1); - sums[1].resize(pprf.mDepth - 1); } - - // Send the sums to the other party. - //sendOne(treeGrp); - //chl.asyncSend(std::move(sums[0])); - //chl.asyncSend(std::move(sums[1])); - - MC_AWAIT(chl.send(std::move(sums[0]))); - MC_AWAIT(chl.send(std::move(sums[1]))); - - if (mActiveChildXorDelta) - MC_AWAIT(chl.send(std::move(lastOts))); - - - //// send the special OT messages for the last level. - //chl.asyncSend(std::move(lastOts)); - //gTimer.setTimePoint("send.expand_send"); - - // copy the last level to the output. If desired, this is - // where the transpose is performed. - auto lvl = getLastLevel(pprf.mDepth, treeIdx); - - // s is a checksum that is used for malicious security. - copyOut(lvl, output, pprf.mPntCount, treeIdx, oFormat, pprf.mOutputFn); - - // pprf.setTimePoint("SilentMultiPprfSender.copyOut " + std::to_string(treeIdx)); - } + } + } - //uPtr_ = {}; - //tree = {}; - pprf.mTreeAlloc.del(tree); - // pprf.setTimePoint("SilentMultiPprfSender.delete " + std::to_string(treeIdx)); - - MC_END(); + // For all but the last level, mask the sums with the + // OT strings and send them over. + for (u64 d = 0; d < mDepth - programActivePath; ++d) { + for (u64 j = 0; j < min; ++j) { + encSums[d][0][j] = sums[d][0][j] ^ mBaseOTs[treeIdx + j][d][0]; + encSums[d][1][j] = sums[d][1][j] ^ mBaseOTs[treeIdx + j][d][1]; + } } - }; - std::vector mExps; + if (programActivePath) { + // For the last level, we are going to do something special. + // The other party is currently missing both leaf children of + // the active parent. Since this is the last level, we want + // the inactive child to just be the normal value but the + // active child should be the correct value XOR the delta. + // This will be done by sending the sums and the sums plus + // delta and ensure that they can only decrypt the correct ones. + auto d = mDepth - 1; + assert(lastOts.size() == min); + for (u64 j = 0; j < min; ++j) { + // Construct the sums where we will allow the delta (mValue) + // to either be on the left child or right child depending + // on which has the active path. + lastOts[j][0] = lastSums[0][j]; + lastOts[j][1] = TypeTrait::plus(lastSums[1][j], mValue[treeIdx + j]); + lastOts[j][2] = lastSums[1][j]; + lastOts[j][3] = TypeTrait::plus(lastSums[0][j], mValue[treeIdx + j]); + + // We are going to expand the 128 bit OT string + // into a 256 bit OT string using AES. + std::array masks, maskIn; + maskIn[0] = mBaseOTs[treeIdx + j][d][0]; + maskIn[1] = mBaseOTs[treeIdx + j][d][0] ^ AllOneBlock; + maskIn[2] = mBaseOTs[treeIdx + j][d][1]; + maskIn[3] = mBaseOTs[treeIdx + j][d][1] ^ AllOneBlock; + mAesFixedKey.hashBlocks<4>(maskIn.data(), masks.data()); + + // Add the OT masks to the sums and send them over. + lastOts[j][0] = TypeTrait::plus(lastOts[j][0], TypeTrait::fromBlock(masks[0])); + lastOts[j][1] = TypeTrait::plus(lastOts[j][1], TypeTrait::fromBlock(masks[1])); + lastOts[j][2] = TypeTrait::plus(lastOts[j][2], TypeTrait::fromBlock(masks[2])); + lastOts[j][3] = TypeTrait::plus(lastOts[j][3], TypeTrait::fromBlock(masks[3])); + } + } + } }; @@ -684,10 +574,10 @@ namespace osuCrypto::Subfield } - // For output format Plain or BlockTransposed, the choice bits it + // For output format ByLeafIndex or ByTreeIndex, the choice bits it // samples are in blocks of mDepth, with mPntCount blocks total (one for - // each punctured point). For Plain these blocks encode the punctured - // leaf index in big endian, while for BlockTransposed they are in + // each punctured point). For ByLeafIndex these blocks encode the punctured + // leaf index in big endian, while for ByTreeIndex they are in // little endian. BitVector sampleChoiceBits(u64 modulus, PprfOutputFormat format, PRNG& prng) { @@ -701,39 +591,43 @@ namespace osuCrypto::Subfield u64 idx; switch (format) { - case osuCrypto::PprfOutputFormat::Plain: - case osuCrypto::PprfOutputFormat::BlockTransposed: - do { - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = prng.getBit(); - idx = getActivePath(mBaseChoices[i]); - } while (idx >= modulus); - - break; - case osuCrypto::PprfOutputFormat::Interleaved: - case osuCrypto::PprfOutputFormat::InterleavedTransposed: - case osuCrypto::PprfOutputFormat::Callback: - - // make sure that at least the first element of this tree - // is within the modulus. - idx = interleavedPoint(0, i, mPntCount, mDomain, format); - if (idx >= modulus) - throw RTE_LOC; + case osuCrypto::PprfOutputFormat::ByLeafIndex: + case osuCrypto::PprfOutputFormat::ByTreeIndex: + do { + for (u64 j = 0; j < mDepth; ++j) + mBaseChoices(i, j) = prng.getBit(); + idx = getActivePath(mBaseChoices[i]); + } while (idx >= modulus); + + break; + case osuCrypto::PprfOutputFormat::Interleaved: + case osuCrypto::PprfOutputFormat::Callback: + + if (modulus > mPntCount * mDomain) + throw std::runtime_error("modulus too big. " LOCATION); + if (modulus < mPntCount * mDomain / 2) + throw std::runtime_error("modulus too small. " LOCATION); + + // make sure that at least the first element of this tree + // is within the modulus. + idx = interleavedPoint(0, i, mPntCount, mDomain, format); + if (idx >= modulus) + throw RTE_LOC; - do { - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = prng.getBit(); - idx = getActivePath(mBaseChoices[i]); + do { + for (u64 j = 0; j < mDepth; ++j) + mBaseChoices(i, j) = prng.getBit(); + idx = getActivePath(mBaseChoices[i]); - idx = interleavedPoint(idx, i, mPntCount, mDomain, format); - } while (idx >= modulus); + idx = interleavedPoint(idx, i, mPntCount, mDomain, format); + } while (idx >= modulus); - break; - default: - throw RTE_LOC; - break; + break; + default: + throw RTE_LOC; + break; } } @@ -756,24 +650,30 @@ namespace osuCrypto::Subfield mBaseChoices.resize(roundUpTo(mPntCount, 8), mDepth); for (u64 i = 0; i < mPntCount; ++i) { + for (u64 j = 0; j < mDepth; ++j) + mBaseChoices(i, j) = choices[mDepth * i + j]; + switch (format) { - case osuCrypto::PprfOutputFormat::Plain: - case osuCrypto::PprfOutputFormat::BlockTransposed: - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = choices[mDepth * i + j]; - break; - - // Not sure what ordering would be good for Interleaved or - // InterleavedTransposed. + case osuCrypto::PprfOutputFormat::ByLeafIndex: + case osuCrypto::PprfOutputFormat::ByTreeIndex: + if (getActivePath(mBaseChoices[i]) >= mDomain) + throw RTE_LOC; - default: - throw RTE_LOC; - break; + break; + case osuCrypto::PprfOutputFormat::Interleaved: + case osuCrypto::PprfOutputFormat::Callback: + { + auto idx = getActivePath(mBaseChoices[i]); + auto idx2 = interleavedPoint(idx, i, mPntCount, mDomain, format); + if(idx2 > mPntCount * mDomain) + throw std::runtime_error("the base ot choice bits index outside of the domain. see sampleChoiceBits(...). " LOCATION); + break; + } + default: + throw RTE_LOC; + break; } - - if (getActivePath(mBaseChoices[i]) >= mDomain) - throw RTE_LOC; } } @@ -808,135 +708,115 @@ namespace osuCrypto::Subfield getPoints(pnts, format); return pnts; } - void getPoints(span points, PprfOutputFormat format) + void getPoints(span points, PprfOutputFormat format) { switch (format) { - case PprfOutputFormat::Plain: - case PprfOutputFormat::BlockTransposed: + case PprfOutputFormat::ByLeafIndex: + case PprfOutputFormat::ByTreeIndex: - memset(points.data(), 0, points.size() * sizeof(u64)); - for (u64 j = 0; j < mPntCount; ++j) - { - points[j] = getActivePath(mBaseChoices[j]); - } + memset(points.data(), 0, points.size() * sizeof(u64)); + for (u64 j = 0; j < mPntCount; ++j) + { + points[j] = getActivePath(mBaseChoices[j]); + } - break; - case PprfOutputFormat::InterleavedTransposed: - case PprfOutputFormat::Interleaved: - case PprfOutputFormat::Callback: + break; + case PprfOutputFormat::Interleaved: + case PprfOutputFormat::Callback: - if ((u64)points.size() != mPntCount) - throw RTE_LOC; - if (points.size() % 8) + if ((u64)points.size() != mPntCount) throw RTE_LOC; + if (points.size() % 8) + throw RTE_LOC; - getPoints(points, PprfOutputFormat::Plain); - interleavedPoints(points, mDomain, format); + getPoints(points, PprfOutputFormat::ByLeafIndex); + interleavedPoints(points, mDomain, format); - break; - default: - throw RTE_LOC; - break; + break; + default: + throw RTE_LOC; + break; } } - task<> expand(Socket& chl, PRNG& prng, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) + task<> expand(Socket& chl, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) { MatrixView o(output.data(), output.size(), 1); - return expand(chl, prng, o, oFormat, activeChildXorDelta, numThreads); + return expand(chl, o, oFormat, activeChildXorDelta, numThreads); } // activeChildXorDelta says whether the sender is trying to program the // active child to be its correct value XOR delta. If it is not, the // active child will just take a random value. - task<> expand(Socket& chl, PRNG& prng, MatrixView output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) + task<> expand(Socket& chl, MatrixView output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) { - setTimePoint("SilentMultiPprfReceiver.start"); - - //lout << " d " << mDomain << " p " << mPntCount << " do " << mDepth << std::endl; - - if (oFormat == PprfOutputFormat::Plain) - { - if (output.rows() != mDomain) - throw RTE_LOC; - - if (output.cols() != mPntCount) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::BlockTransposed) - { - if (output.cols() != mDomain) - throw RTE_LOC; + validateExpandFormat(oFormat, output, mDomain, mPntCount); + + MC_BEGIN(task<>, this, oFormat, output, &chl, activeChildXorDelta, + i = u64{}, + mTreeAllocDepth = u64{}, + tree = span>{}, + levels = std::vector>>{}, + lastLevel = span < AlignedArray>{}, + buff = std::vector{}, + sums = span, 2>>{}, + last = span>{} + ); - if (output.rows() != mPntCount) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::InterleavedTransposed) - { - if (output.rows() != 128) - throw RTE_LOC; + setTimePoint("SilentMultiPprfReceiver.start"); + mPoints.resize(roundUpTo(mPntCount, 8)); + getPoints(mPoints, PprfOutputFormat::ByLeafIndex); - //if (output.cols() > (mDomain * mPntCount + 127) / 128) - // throw RTE_LOC; + mTreeAllocDepth = mDepth + 1; // Subfield + mTreeAlloc.reserve(1, (1ull << mTreeAllocDepth) + 2); + setTimePoint("SilentMultiPprfSender.reserve"); - if (mPntCount & 7) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::Interleaved) - { - if (output.cols() != 1) - throw RTE_LOC; - if (mDomain & 1) - throw RTE_LOC; - auto rows = output.rows(); - if (rows > (mDomain * mPntCount) || - rows / 128 != (mDomain * mPntCount) / 128) - throw RTE_LOC; - if (mPntCount & 7) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::Callback) - { - if (mDomain & 1) - throw RTE_LOC; - if (mPntCount & 7) - throw RTE_LOC; - } - else - { - throw RTE_LOC; - } + levels.resize(mDepth + 1); + allocateExpandTree(mTreeAllocDepth, mTreeAlloc, tree, levels); - mPoints.resize(roundUpTo(mPntCount, 8)); - getPoints(mPoints, PprfOutputFormat::Plain); + for (i = 0; i < mPntCount; i += 8) + { + // for interleaved format, the last level of the tree + // is simply the output. + if (oFormat == PprfOutputFormat::Interleaved) + { + // Subfield + auto b = (AlignedArray *) output.data(); + auto forest = i / 8; + b += forest * mDomain; + lastLevel = span < AlignedArray>(b, mDomain); + +// auto b = (AlignedArray*)output.data(); +// auto forest = i / 8; +// b += forest * mDomain; +// levels.back() = span>(b, mDomain); + } + // allocate the send buffer and partition it. + allocateExpandBuffer(mDepth, activeChildXorDelta, buff, sums, last); - MC_BEGIN(task<>, this, numThreads, oFormat, output, &chl, activeChildXorDelta, - i = u64{}, - dd = u64{} - ); + MC_AWAIT(chl.recv(buff)); + // exapnd the tree + expandOne(i, activeChildXorDelta, levels, lastLevel, sums, last); - dd = mDepth + (oFormat == PprfOutputFormat::Interleaved ? 0 : 1); - mTreeAlloc.reserve(numThreads, (1ull << (dd+1)) + (32 * (dd+1))); - setTimePoint("SilentMultiPprfReceiver.reserve"); + // if we aren't interleaved, we need to copy the + // last layer to the output. + if (oFormat != PprfOutputFormat::Interleaved) { + // Subfield: no need to copyOut + throw RTE_LOC; +// copyOut(levels.back(), output, mPntCount, i, oFormat, mOutputFn); + } + } - mExps.clear(); mExps.reserve(divCeil(mPntCount, 8)); - for (i = 0; i < mPntCount; i += 8) - { - mExps.emplace_back(*this, chl.fork(), oFormat, output, activeChildXorDelta, i); - mExps.back().mFuture = macoro::make_eager(mExps.back().run()); + setTimePoint("SilentMultiPprfReceiver.join"); - //MC_AWAIT(mExps.back().run()); - } + mBaseOTs = {}; + mTreeAlloc.del(tree); + mTreeAlloc.clear(); - for (i = 0; i < mExps.size(); ++i) - MC_AWAIT(mExps[i].mFuture); - setTimePoint("SilentMultiPprfReceiver.join"); - - mBaseOTs = {}; - setTimePoint("SilentMultiPprfReceiver.de-alloc"); + setTimePoint("SilentMultiPprfReceiver.de-alloc"); MC_END(); } @@ -950,310 +830,136 @@ namespace osuCrypto::Subfield mPntCount = 0; } - - - struct Expander + void expandOne( + u64 treeIdx, + bool programActivePath, + span>> levels, + span> lastLevel, + span, 2>> theirSums, + span> lastOts) { - SilentSubfieldPprfReceiver& pprf; - Socket chl; - - bool mActiveChildXorDelta = false; - std::array aes; - - PprfOutputFormat oFormat; - MatrixView output; - - macoro::eager_task mFuture; - - std::vector>> mLevels; - - // mySums will hold the left and right GGM tree sums - // for each level. For example mySums[5][0] will - // hold the sum of the left children for the 5th tree. This - // sum will be "missing" the children of the active parent. - // The sender will give of one of the full somes so we can - // compute the missing inactive child. - std::array, 2> mySums; - - // sums for the last level - std::array, 2> lastSums; - std::vector> lastOts; - - // A buffer for receiving the sums from the other party. - // These will be masked by the OT strings. - std::array>, 2> theirSums; - - u64 dd, treeIdx; - // tree will hold the full GGM tree. Not that there are 8 - // indepenendent trees that are being processed together. - // The trees are flattenned to that the children of j are - // located at 2*j and 2*j+1. - //std::unique_ptr uPtr_; - span> tree; - - // Returns the i'th level of the current 8 trees. The - // children of node j on level i are located at 2*j and - // 2*j+1 on level i+1. - span> getLevel(u64 i, u64 g, bool f = false) - { - //auto size = (1ull << i); - #ifdef DEBUG_PRINT_PPRF - //auto offset = (size - 1); - //auto b = (f ? ftree.begin() : tree.begin()) + offset; - #else - return mLevels[i]; - #endif - //return span>(b,e); - }; - - span> getLastLevel(u64 i, u64 g, bool f = false) - { - //auto size = (1ull << i); - #ifdef DEBUG_PRINT_PPRF - //auto offset = (size - 1); - //auto b = (f ? ftree.begin() : tree.begin()) + offset; - #else - if (oFormat == PprfOutputFormat::Interleaved && i == pprf.mDepth) - { - auto b = (AlignedArray*)output.data(); - auto forest = g / 8; - assert(g % 8 == 0); - b += forest * pprf.mDomain; - auto zone = span>(b, pprf.mDomain); - return zone; - } - - //assert(tree.size()); - //auto b = tree.begin() + offset; - - throw RTE_LOC; - #endif - //return span>(b,e); - }; - - - Expander(SilentSubfieldPprfReceiver& p, Socket&& s, PprfOutputFormat of, MatrixView o, bool activeChildXorDelta, u64 ti) - : pprf(p) - , chl(std::move(s)) - , mActiveChildXorDelta(activeChildXorDelta) - , oFormat(of) - , output(o) - , treeIdx(ti) - //, threadIdx(tIdx) - { - assert((treeIdx & 7) == 0); - // A public PRF/PRG that we will use for deriving the GGM tree. - aes[0].setKey(toBlock(3242342)); - aes[1].setKey(toBlock(8993849)); - - - theirSums[0].resize(p.mDepth - mActiveChildXorDelta); - theirSums[1].resize(p.mDepth - mActiveChildXorDelta); - - dd = p.mDepth + (oFormat == PprfOutputFormat::Interleaved ? 0 : 1); - - } - task<> run() - { - - MC_BEGIN(task<>, this); - + // This thread will process 8 trees at a time. + // special case for the first level. + auto l1 = levels[1]; + for (u64 i = 0; i < 8; ++i) { - tree = pprf.mTreeAlloc.get(); - assert(tree.size() >= 1ull << (dd)); - mLevels.resize(dd+1); // todo: last level block are kept - mLevels[0] = tree.subspan(0, 1); - auto rem = tree.subspan(1); - for (u64 i = 1; i < dd + 1; i++) - { - while ((u64)rem.data() % 32) - rem = rem.subspan(1); - mLevels[i] = rem.subspan(0, mLevels[i - 1].size() * 2); - rem = rem.subspan(mLevels[i].size()); - } + // For the non-active path, set the child of the root node + // as the OT message XOR'ed with the correction sum. + int notAi = mBaseChoices[i + treeIdx][0]; + l1[notAi][i] = mBaseOTs[i + treeIdx][0] ^ theirSums[0][notAi][i]; + l1[notAi ^ 1][i] = ZeroBlock; } + // space for our sums of each level. + std::array, 2> mySums; - #ifdef DEBUG_PRINT_PPRF - // This will be the full tree and is sent by the receiver to help debug. - std::vector> ftree(1ull << (mDepth + 1)); - - // The delta value on the active path. - //block deltaValue; - chl.recv(mDebugValue); - #endif - - - - #ifdef DEBUG_PRINT_PPRF - // prints out the contents of the d'th level. - auto printLevel = [&](u64 d) - { - - auto level0 = getLevel(d); - auto flevel0 = getLevel(d, true); - - std::cout - << "---------------------\nlevel " << d - << "\n---------------------" << std::endl; - - std::array sums{ ZeroBlock ,ZeroBlock }; - for (i64 i = 0; i < level0.size(); ++i) - { - for (u64 j = 0; j < 8; ++j) - { - - if (neq(level0[i][j], flevel0[i][j])) - std::cout << Color::Red; - - std::cout << "p[" << i << "][" << j << "] " - << level0[i][j] << " " << flevel0[i][j] << std::endl << Color::Default; - - if (i == 0 && j == 0) - sums[i & 1] = sums[i & 1] ^ flevel0[i][j]; - } - } - - std::cout << "sums[0] = " << sums[0] << " " << sums[1] << std::endl; - }; - #endif - + // Subfield: lastSums + std::array, 2> lastSums{}; - // The number of real trees for this iteration. - memset(lastSums[0].data(), 0, lastSums[0].size() * sizeof(F)); - memset(lastSums[1].data(), 0, lastSums[1].size() * sizeof(F)); - memset(mySums[0].data(), 0, mySums[0].size() * sizeof(F)); - memset(mySums[1].data(), 0, mySums[1].size() * sizeof(F)); - lastOts.resize(8); - - // This thread will process 8 trees at a time. It will interlace - // the sets of trees are processed with the other threads. + // For all other levels, expand the GGM tree and add in + // the correction along the active path. + for (u64 d = 1; d < mDepth; ++d) { - #ifdef DEBUG_PRINT_PPRF - chl.recv(ftree); - auto l1f = getLevel(1, true); - #endif + // The already constructed level. Only missing the + // GGM tree node value along the active path. + auto level0 = levels[d]; - //timer.setTimePoint("recv.start" + std::to_string(treeIdx)); - // Receive their full set of sums for these 8 trees. - MC_AWAIT(chl.recv(theirSums[0])); - MC_AWAIT(chl.recv(theirSums[1])); + // The next level that we want to construct. + auto level1 = levels[d + 1]; - if (mActiveChildXorDelta) - MC_AWAIT(chl.recv(lastOts)); - // pprf.setTimePoint("SilentMultiPprfReceiver.recv " + std::to_string(treeIdx)); + // Zero out the previous sums. + memset(mySums.data(), 0, sizeof(mySums)); - tree = pprf.mTreeAlloc.get(); - assert(tree.size() >= 1ull << (dd)); - assert((u64)tree.data() % 32 == 0); + // We will iterate over each node on this level and + // expand it into it's two children. Note that the + // active node will also be expanded. Later we will just + // overwrite whatever the value was. This is an optimization. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); - // pprf.setTimePoint("SilentMultiPprfReceiver.alloc " + std::to_string(treeIdx)); - - auto l1 = getLevel(1, treeIdx); - - for (u64 i = 0; i < 8; ++i) + // for internal nodes we the optimized approach. + if (d + 1 < mDepth && 0) { - // For the non-active path, set the child of the root node - // as the OT message XOR'ed with the correction sum. - int notAi = pprf.mBaseChoices[i + treeIdx][0]; - l1[notAi][i] = pprf.mBaseOTs[i + treeIdx][0] ^ theirSums[notAi][0][i]; - l1[notAi ^ 1][i] = ZeroBlock; - - #ifdef DEBUG_PRINT_PPRF - if (neq(l1[notAi][i], l1f[notAi][i])) { - std::cout << "l1[" << notAi << "][" << i << "] " << l1[notAi][i] << " = " - << (mBaseOTs[i + treeIdx][0]) << " ^ " - << theirSums[notAi][0][i] << " vs " << l1f[notAi][i] << std::endl; - } - #endif +// for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) +// { +// // The value of the parent. +// auto parent = level0[parentIdx]; +// +// auto& child0 = level1.data()[childIdx]; +// auto& child1 = level1.data()[childIdx + 1]; +// mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); +// +// child0[0] = child1[0] ^ parent[0]; +// child0[1] = child1[1] ^ parent[1]; +// child0[2] = child1[2] ^ parent[2]; +// child0[3] = child1[3] ^ parent[3]; +// child0[4] = child1[4] ^ parent[4]; +// child0[5] = child1[5] ^ parent[5]; +// child0[6] = child1[6] ^ parent[6]; +// child0[7] = child1[7] ^ parent[7]; +// +// // Update the running sums for this level. We keep +// // a left and right totals for each level. Note that +// // we are actually XOR in the incorrect value of the +// // children of the active parent (assuming !DEBUG_PRINT_PPRF). +// // This is ok since we will later XOR off these incorrect values. +// mySums[0][0] = mySums[0][0] ^ child0[0]; +// mySums[0][1] = mySums[0][1] ^ child0[1]; +// mySums[0][2] = mySums[0][2] ^ child0[2]; +// mySums[0][3] = mySums[0][3] ^ child0[3]; +// mySums[0][4] = mySums[0][4] ^ child0[4]; +// mySums[0][5] = mySums[0][5] ^ child0[5]; +// mySums[0][6] = mySums[0][6] ^ child0[6]; +// mySums[0][7] = mySums[0][7] ^ child0[7]; +// +// child1[0] = child1[0] + parent[0]; +// child1[1] = child1[1] + parent[1]; +// child1[2] = child1[2] + parent[2]; +// child1[3] = child1[3] + parent[3]; +// child1[4] = child1[4] + parent[4]; +// child1[5] = child1[5] + parent[5]; +// child1[6] = child1[6] + parent[6]; +// child1[7] = child1[7] + parent[7]; +// +// mySums[1][0] = mySums[1][0] ^ child1[0]; +// mySums[1][1] = mySums[1][1] ^ child1[1]; +// mySums[1][2] = mySums[1][2] ^ child1[2]; +// mySums[1][3] = mySums[1][3] ^ child1[3]; +// mySums[1][4] = mySums[1][4] ^ child1[4]; +// mySums[1][5] = mySums[1][5] ^ child1[5]; +// mySums[1][6] = mySums[1][6] ^ child1[6]; +// mySums[1][7] = mySums[1][7] ^ child1[7]; +// } } - - #ifdef DEBUG_PRINT_PPRF - if (mPrint) - printLevel(1); - #endif - - // For all other levels, expand the GGM tree and add in - // the correction along the active path. - for (u64 d = 1; d < pprf.mDepth; ++d) + else { - // The already constructed level. Only missing the - // GGM tree node value along the active path. - auto level0 = getLevel(d, treeIdx); - - // The next level that we want to construct. - auto level1 = getLevel(d + 1, treeIdx); - - // Zero out the previous sums. - memset(mySums[0].data(), 0, mySums[0].size() * sizeof(block)); - memset(mySums[1].data(), 0, mySums[1].size() * sizeof(block)); - - // We will iterate over each node on this level and - // expand it into it's two children. Note that the - // active node will also be expanded. Later we will just - // overwrite whatever the value was. This is an optimization. - auto width = static_cast(level1.size()); - for (u64 childIdx = 0; childIdx < width; ) + // for leaf nodes both children should be hashed. + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) { - - // Index of the parent in the previous level. - auto parentIdx = childIdx >> 1; - // The value of the parent. auto parent = level0[parentIdx]; for (u64 keep = 0; keep < 2; ++keep, ++childIdx) { - - //// The bit that indicates if we are on the left child (0) - //// or on the right child (1). - //u8 keep = childIdx & 1; - - // The child that we will write in this iteration. auto& child = level1[childIdx]; - // Each parent is expanded into the left and right children + // Each parent is expanded into the left and right children // using a different AES fixed-key. Therefore our OWF is: // // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); // // where each half defines one of the children. - aes[keep].template hashBlocks<8>(parent.data(), child.data()); - - - - #ifdef DEBUG_PRINT_PPRF - // For debugging, set the active path to zero. - for (u64 i = 0; i < 8; ++i) - if (eq(parent[i], ZeroBlock)) - child[i] = ZeroBlock; - #endif + gAes[keep].hashBlocks<8>(parent.data(), child.data()); - if (d < pprf.mDepth - 1) { - // Same as before - // Update the running sums for this level. We keep - // a left and right totals for each level. Note that - // we are actually XOR in the incorrect value of the - // children of the active parent (assuming !DEBUG_PRINT_PPRF). - // This is ok since we will later XOR off these incorrect values. - auto& sum = mySums[keep]; - sum[0] = sum[0] ^ child[0]; - sum[1] = sum[1] ^ child[1]; - sum[2] = sum[2] ^ child[2]; - sum[3] = sum[3] ^ child[3]; - sum[4] = sum[4] ^ child[4]; - sum[5] = sum[5] ^ child[5]; - sum[6] = sum[6] ^ child[6]; - sum[7] = sum[7] ^ child[7]; - } else { - if (getLastLevel(pprf.mDepth, treeIdx).size() <= childIdx) { - childIdx = width; - break; + // Subfield: + if (d == mDepth - 1) { + if (lastLevel.size() <= childIdx) { + // todo + throw RTE_LOC; } - auto& realChild = getLastLevel(pprf.mDepth, treeIdx)[childIdx]; + auto& realChild = lastLevel[childIdx]; auto& lastSum = lastSums[keep]; realChild[0] = TypeTrait::fromBlock(child[0]); lastSum[0] = TypeTrait::plus(lastSum[0], realChild[0]); @@ -1271,174 +977,126 @@ namespace osuCrypto::Subfield lastSum[6] = TypeTrait::plus(lastSum[6], realChild[6]); realChild[7] = TypeTrait::fromBlock(child[7]); lastSum[7] = TypeTrait::plus(lastSum[7], realChild[7]); + } else { + // Update the running sums for this level. We keep + // a left and right totals for each level. Note that + // we are actually XOR in the incorrect value of the + // children of the active parent (assuming !DEBUG_PRINT_PPRF). + // This is ok since we will later XOR off these incorrect values. + auto& sum = mySums[keep]; + sum[0] = sum[0] ^ child[0]; + sum[1] = sum[1] ^ child[1]; + sum[2] = sum[2] ^ child[2]; + sum[3] = sum[3] ^ child[3]; + sum[4] = sum[4] ^ child[4]; + sum[5] = sum[5] ^ child[5]; + sum[6] = sum[6] ^ child[6]; + sum[7] = sum[7] ^ child[7]; } } } + } - // For everything but the last level we have to - // 1) fix our sums so they dont include the incorrect - // values that are the children of the active parent - // 2) Update the non-active child of the active parent. - if (!mActiveChildXorDelta || d != pprf.mDepth - 1) - { - - for (u64 i = 0; i < 8; ++i) - { - // the index of the leaf node that is active. - auto leafIdx = pprf.mPoints[i + treeIdx]; - - // The index of the active child node. - auto activeChildIdx = leafIdx >> (pprf.mDepth - 1 - d); - - // The index of the active child node sibling. - auto inactiveChildIdx = activeChildIdx ^ 1; - - // The indicator as to the left or right child is inactive - auto notAi = inactiveChildIdx & 1; - #ifdef DEBUG_PRINT_PPRF - auto prev = level1[inactiveChildIdx][i]; - #endif - - auto& inactiveChild = level1[inactiveChildIdx][i]; - - - // correct the sum value by XORing off the incorrect - auto correctSum = - inactiveChild ^ - theirSums[notAi][d][i]; - - inactiveChild = - correctSum ^ - mySums[notAi][i] ^ - pprf.mBaseOTs[i + treeIdx][d]; - - #ifdef DEBUG_PRINT_PPRF - if (mPrint) - std::cout << "up[" << i << "] = level1[" << inactiveChildIdx << "][" << i << "] " - << prev << " -> " << level1[inactiveChildIdx][i] << " " << activeChildIdx << " " << inactiveChildIdx << " ~~ " - << mBaseOTs[i + treeIdx][d] << " " << theirSums[notAi][d][i] << " @ " << (i + treeIdx) << " " << d << std::endl; - - auto fLevel1 = getLevel(d + 1, true); - if (neq(fLevel1[inactiveChildIdx][i], inactiveChild)) - throw RTE_LOC; - #endif - } - } - #ifdef DEBUG_PRINT_PPRF - if (mPrint) - printLevel(d + 1); - #endif - - } - - // pprf.setTimePoint("SilentMultiPprfReceiver.expand " + std::to_string(treeIdx)); - - //timer.setTimePoint("recv.expanded"); - - - // copy the last level to the output. If desired, this is - // where the transpose is performed. - auto lvl = getLastLevel(pprf.mDepth, treeIdx); - - if (mActiveChildXorDelta) + // For everything but the last level we have to + // 1) fix our sums so they dont include the incorrect + // values that are the children of the active parent + // 2) Update the non-active child of the active parent. + if (!programActivePath || d != mDepth - 1) { - // Now processes the last level. This one is special - // because we must XOR in the correction value as - // before but we must also fixed the child value for - // the active child. To do this, we will receive 4 - // values. Two for each case (left active or right active). - //timer.setTimePoint("recv.recvLast"); - - auto d = pprf.mDepth - 1; - for (u64 j = 0; j < 8; ++j) + for (u64 i = 0; i < 8; ++i) { - // The index of the child on the active path. - auto activeChildIdx = pprf.mPoints[j + treeIdx]; + // the index of the leaf node that is active. + auto leafIdx = mPoints[i + treeIdx]; + + // The index of the active child node. + auto activeChildIdx = leafIdx >> (mDepth - 1 - d); - // The index of the other (inactive) child. + // The index of the active child node sibling. auto inactiveChildIdx = activeChildIdx ^ 1; // The indicator as to the left or right child is inactive auto notAi = inactiveChildIdx & 1; - std::array masks, maskIn; - - // We are going to expand the 128 bit OT string - // into a 256 bit OT string using AES. - maskIn[0] = pprf.mBaseOTs[j + treeIdx][d]; - maskIn[1] = pprf.mBaseOTs[j + treeIdx][d] ^ AllOneBlock; - mAesFixedKey.template hashBlocks<2>(maskIn.data(), masks.data()); - - // now get the chosen message OT strings by XORing - // the expended (random) OT strings with the lastOts values. - auto& ot0 = lastOts[j][2 * notAi + 0]; - auto& ot1 = lastOts[j][2 * notAi + 1]; - ot0 = TypeTrait::minus(ot0, TypeTrait::fromBlock(masks[0])); - ot1 = TypeTrait::minus(ot1, TypeTrait::fromBlock(masks[1])); - - #ifdef DEBUG_PRINT_PPRF - auto prev = level[inactiveChildIdx][j]; - #endif - - auto& inactiveChild = lvl[inactiveChildIdx][j]; - auto& activeChild = lvl[activeChildIdx][j]; - - // Fix the sums we computed previously to not include the - // incorrect child values. - auto inactiveSum = TypeTrait::minus(lastSums[notAi][j], inactiveChild); - auto activeSum = TypeTrait::minus(lastSums[notAi ^ 1][j], activeChild); - - // Update the inactive and active child to have to correct - // value by XORing their full sum with out partial sum, which - // gives us exactly the value we are missing. - inactiveChild = TypeTrait::minus(ot0, inactiveSum); - activeChild = TypeTrait::minus(ot1, activeSum); - - #ifdef DEBUG_PRINT_PPRF - auto fLevel1 = getLevel(d + 1, true); - if (neq(fLevel1[inactiveChildIdx][j], inactiveChild)) - throw RTE_LOC; - if (neq(fLevel1[activeChildIdx][j], activeChild ^ mDebugValue)) - throw RTE_LOC; + auto& inactiveChild = level1[inactiveChildIdx][i]; - if (mPrint) - std::cout << "up[" << d << "] = level1[" << (inactiveChildIdx / mPntCount) << "][" << (inactiveChildIdx % mPntCount) << " " - << prev << " -> " << level[inactiveChildIdx][j] << " ~~ " - << mBaseOTs[j + treeIdx][d] << " " << ot0 << " @ " << (j + treeIdx) << " " << d << std::endl; - #endif - } - // pprf.setTimePoint("SilentMultiPprfReceiver.last " + std::to_string(treeIdx)); + // correct the sum value by XORing off the incorrect + auto correctSum = + inactiveChild ^ + theirSums[d][notAi][i]; - //timer.setTimePoint("recv.expandLast"); - } - else - { - for (auto j : rng(std::min(8, pprf.mPntCount - treeIdx))) - { + inactiveChild = + correctSum ^ + mySums[notAi][i] ^ + mBaseOTs[i + treeIdx][d]; - // The index of the child on the active path. - auto activeChildIdx = pprf.mPoints[j + treeIdx]; - lvl[activeChildIdx][j] = F{}; } } + } - // s is a checksum that is used for malicious security. - copyOut(lvl, output, pprf.mPntCount, treeIdx, oFormat, pprf.mOutputFn); - - // pprf.setTimePoint("SilentMultiPprfReceiver.copy " + std::to_string(treeIdx)); - - //uPtr_ = {}; - //tree = {}; - pprf.mTreeAlloc.del(tree); - - // pprf.setTimePoint("SilentMultiPprfReceiver.delete " + std::to_string(treeIdx)); - + // last level. + if (programActivePath) + { + // Now processes the last level. This one is special + // because we must XOR in the correction value as + // before but we must also fixed the child value for + // the active child. To do this, we will receive 4 + // values. Two for each case (left active or right active). + //timer.setTimePoint("recv.recvLast"); + + auto d = mDepth - 1; + for (u64 j = 0; j < 8; ++j) + { + // The index of the child on the active path. + auto activeChildIdx = mPoints[j + treeIdx]; + + // The index of the other (inactive) child. + auto inactiveChildIdx = activeChildIdx ^ 1; + + // The indicator as to the left or right child is inactive + auto notAi = inactiveChildIdx & 1; + + std::array masks, maskIn; + + // We are going to expand the 128 bit OT string + // into a 256 bit OT string using AES. + maskIn[0] = mBaseOTs[j + treeIdx][d]; + maskIn[1] = mBaseOTs[j + treeIdx][d] ^ AllOneBlock; + mAesFixedKey.hashBlocks<2>(maskIn.data(), masks.data()); + + // now get the chosen message OT strings by XORing + // the expended (random) OT strings with the lastOts values. + auto& ot0 = lastOts[j][2 * notAi + 0]; + auto& ot1 = lastOts[j][2 * notAi + 1]; + ot0 = TypeTrait::minus(ot0, TypeTrait::fromBlock(masks[0])); + ot1 = TypeTrait::minus(ot1, TypeTrait::fromBlock(masks[1])); + + auto& inactiveChild = lastLevel[inactiveChildIdx][j]; + auto& activeChild = lastLevel[activeChildIdx][j]; + + // Fix the sums we computed previously to not include the + // incorrect child values. + auto inactiveSum = TypeTrait::minus(lastSums[notAi][j], inactiveChild); + auto activeSum = TypeTrait::minus(lastSums[notAi ^ 1][j], activeChild); + + // Update the inactive and active child to have to correct + // value by XORing their full sum with out partial sum, which + // gives us exactly the value we are missing. + inactiveChild = TypeTrait::minus(ot0, inactiveSum); + activeChild = TypeTrait::minus(ot1, activeSum); } + // pprf.setTimePoint("SilentMultiPprfReceiver.last " + std::to_string(treeIdx)); - MC_END(); + //timer.setTimePoint("recv.expandLast"); + } + else + { + for (auto j : rng(std::min(8, mPntCount - treeIdx))) + { + // The index of the child on the active path. + auto activeChildIdx = mPoints[j + treeIdx]; + lastLevel[activeChildIdx][j] = F{}; } - }; - - std::vector mExps; + } + } }; } \ No newline at end of file diff --git a/libOTe/Vole/Subfield/SilentVoleReceiver.h b/libOTe/Vole/Subfield/SilentVoleReceiver.h index 3a4e25fc..27e810a9 100644 --- a/libOTe/Vole/Subfield/SilentVoleReceiver.h +++ b/libOTe/Vole/Subfield/SilentVoleReceiver.h @@ -486,7 +486,7 @@ namespace osuCrypto::Subfield if (mTimer) mGen.setTimer(*mTimer); // expand the seeds into mA - MC_AWAIT(mGen.expand(chl, prng, mA.subspan(0, mNumPartitions * mSizePer), PprfOutputFormat::Interleaved, true, mNumThreads)); + MC_AWAIT(mGen.expand(chl, mA.subspan(0, mNumPartitions * mSizePer), PprfOutputFormat::Interleaved, true, mNumThreads)); setTimePoint("SilentVoleReceiver.expand.pprf_transpose"); diff --git a/libOTe/Vole/Subfield/SilentVoleSender.h b/libOTe/Vole/Subfield/SilentVoleSender.h index fa4428a8..9dca20cb 100644 --- a/libOTe/Vole/Subfield/SilentVoleSender.h +++ b/libOTe/Vole/Subfield/SilentVoleSender.h @@ -389,7 +389,7 @@ namespace osuCrypto::Subfield // output of the PPRF at the correct locations. noiseShares = span(mNoiseDeltaShares.data(), mNumPartitions); mbb = mB.subspan(0, mNumPartitions * mSizePer); - MC_AWAIT(mGen.expand(chl, noiseShares, prng, mbb, + MC_AWAIT(mGen.expand(chl, noiseShares, prng.get(), mbb, PprfOutputFormat::Interleaved, true, mNumThreads)); setTimePoint("SilentVoleSender.expand.pprf_transpose"); From 6ddd384494f8a938f2a24fca72256c04ce1b24c5 Mon Sep 17 00:00:00 2001 From: Halulu Date: Thu, 11 Jan 2024 13:05:06 +0800 Subject: [PATCH 03/23] clear --- libOTe/Tools/Subfield/Subfield.h | 21 --------------------- libOTe/Tools/Subfield/SubfieldPprf.h | 2 +- 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/libOTe/Tools/Subfield/Subfield.h b/libOTe/Tools/Subfield/Subfield.h index 4d7878d8..1d145292 100644 --- a/libOTe/Tools/Subfield/Subfield.h +++ b/libOTe/Tools/Subfield/Subfield.h @@ -8,27 +8,6 @@ namespace osuCrypto::Subfield { block b; F128() = default; explicit F128(const block& b) : b(b) {} -// OC_FORCEINLINE F128 operator+(const F128& rhs) const { -// F128 ret; -// ret.b = b ^ rhs.b; -// return ret; -// } -// OC_FORCEINLINE F128 operator-(const F128& rhs) const { -// F128 ret; -// ret.b = b ^ rhs.b; -// return ret; -// } -// OC_FORCEINLINE F128 operator*(const F128& rhs) const { -// F128 ret; -// ret.b = b.gf128Mul(rhs.b); -// return ret; -// } -// OC_FORCEINLINE bool operator==(const F128& rhs) const { -// return b == rhs.b; -// } -// OC_FORCEINLINE bool operator!=(const F128& rhs) const { -// return b != rhs.b; -// } }; /* diff --git a/libOTe/Tools/Subfield/SubfieldPprf.h b/libOTe/Tools/Subfield/SubfieldPprf.h index 24549b86..a3d4a46e 100644 --- a/libOTe/Tools/Subfield/SubfieldPprf.h +++ b/libOTe/Tools/Subfield/SubfieldPprf.h @@ -956,7 +956,7 @@ namespace osuCrypto::Subfield // Subfield: if (d == mDepth - 1) { if (lastLevel.size() <= childIdx) { - // todo + // todo: I have fix in my old code, not sure we need this for the new pprf throw RTE_LOC; } auto& realChild = lastLevel[childIdx]; From 18ae3af2f79d698372a70180bc804b2b20afb0f1 Mon Sep 17 00:00:00 2001 From: Halulu Date: Fri, 12 Jan 2024 10:56:39 +0800 Subject: [PATCH 04/23] add fromBlockG --- libOTe/Tools/Subfield/Subfield.h | 15 +++++++++++++++ libOTe/Vole/Subfield/SilentVoleReceiver.h | 11 +++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/libOTe/Tools/Subfield/Subfield.h b/libOTe/Tools/Subfield/Subfield.h index 1d145292..399d360b 100644 --- a/libOTe/Tools/Subfield/Subfield.h +++ b/libOTe/Tools/Subfield/Subfield.h @@ -43,6 +43,11 @@ namespace osuCrypto::Subfield { static OC_FORCEINLINE F fromBlock(const block& b) { return b.get()[0]; } + + static OC_FORCEINLINE G fromBlockG(const block& b) { + return b.get()[0]; + } + static OC_FORCEINLINE F pow(u64 power) { F ret = 1; ret <<= power; @@ -84,6 +89,11 @@ namespace osuCrypto::Subfield { static OC_FORCEINLINE F fromBlock(const block& b) { return b; } + + static OC_FORCEINLINE G fromBlockG(const block& b) { + return b; + } + static OC_FORCEINLINE F pow(u64 power) { F ret = ZeroBlock; *BitIterator((u8*)&ret, power) = 1; @@ -199,6 +209,11 @@ namespace osuCrypto::Subfield { } } + // assume primitive type for G now + static OC_FORCEINLINE G fromBlockG(const block& b) { + return b.get()[0]; + } + static OC_FORCEINLINE F pow(u64 power) { F ret; memset(&ret, 0, sizeof(ret)); diff --git a/libOTe/Vole/Subfield/SilentVoleReceiver.h b/libOTe/Vole/Subfield/SilentVoleReceiver.h index 27e810a9..ca662a74 100644 --- a/libOTe/Vole/Subfield/SilentVoleReceiver.h +++ b/libOTe/Vole/Subfield/SilentVoleReceiver.h @@ -301,17 +301,16 @@ namespace osuCrypto::Subfield // sample the values of the noisy coordinate of c // and perform a noicy vole to get x+y = mD * c auto w = mNumPartitions + mGapOts.size(); - //std::vector y(w); + std::vector seeds(w); mNoiseValues.resize(w); - prng.get(mNoiseValues.data(), mNoiseValues.size()); + prng.get(seeds.data(), seeds.size()); + for (size_t i = 0; i < w; i++) { + mNoiseValues[i] = TypeTrait::fromBlockG(seeds[i]); + } mS.resize(mNumPartitions); mGen.getPoints(mS, getPprfFormat()); - // todo - std::vector tmp = mS; - std::sort(tmp.begin(), tmp.end()); - auto j = mNumPartitions * mSizePer; for (u64 i = 0; i < (u64)mGapBaseChoice.size(); ++i) From 4d74ce1ced6b376eeaa57bcc00c9643936fbb494 Mon Sep 17 00:00:00 2001 From: Halulu Date: Fri, 12 Jan 2024 12:34:43 +0800 Subject: [PATCH 05/23] ExConv use member function TypeTrait --- libOTe/Tools/Subfield/ExConvCode.h | 49 ++++++++-------- libOTe/Tools/Subfield/Expander.h | 71 +++++++++++------------ libOTe/Vole/Subfield/SilentVoleReceiver.h | 4 +- libOTe/Vole/Subfield/SilentVoleSender.h | 7 +-- libOTe_Tests/Subfield_Tests.cpp | 21 ++++--- 5 files changed, 76 insertions(+), 76 deletions(-) diff --git a/libOTe/Tools/Subfield/ExConvCode.h b/libOTe/Tools/Subfield/ExConvCode.h index 98ba3140..203a9351 100644 --- a/libOTe/Tools/Subfield/ExConvCode.h +++ b/libOTe/Tools/Subfield/ExConvCode.h @@ -31,11 +31,10 @@ namespace osuCrypto::Subfield // // https://eprint.iacr.org/2023/882 - template class ExConvCode : public TimerAdapter { public: - ExpanderCode mExpander; + ExpanderCode mExpander; // configure the code. The default parameters are choses to balance security and performance. // For additional parameter choices see the paper. @@ -89,7 +88,7 @@ namespace osuCrypto::Subfield u64 generatorCols() const { return mCodeSize; } // Compute w = G * e. e will be modified in the computation. - template + template void dualEncode(span e, span w) { if (e.size() != mCodeSize) @@ -100,7 +99,7 @@ namespace osuCrypto::Subfield if (mSystematic) { - dualEncode(e); + dualEncode(e); memcpy(w.data(), e.data(), w.size() * sizeof(T)); setTimePoint("ExConv.encode.memcpy"); } @@ -109,17 +108,17 @@ namespace osuCrypto::Subfield setTimePoint("ExConv.encode.begin"); - accumulate(e); + accumulate(e); setTimePoint("ExConv.encode.accumulate"); - mExpander.template expand(e, w); + mExpander.expand(e, w); setTimePoint("ExConv.encode.expand"); } } // Compute e[0,...,k-1] = G * e. - template + template void dualEncode(span e) { if (e.size() != mCodeSize) @@ -129,15 +128,15 @@ namespace osuCrypto::Subfield { auto d = e.subspan(mMessageSize); setTimePoint("ExConv.encode.begin"); - accumulate(d); + accumulate(d); setTimePoint("ExConv.encode.accumulate"); - mExpander.template expand(d, e.subspan(0, mMessageSize)); + mExpander.expand(d, e.subspan(0, mMessageSize)); setTimePoint("ExConv.encode.expand"); } else { oc::AlignedUnVector w(mMessageSize); - dualEncode(e, w); + dualEncode(e, w); memcpy(e.data(), w.data(), w.size() * sizeof(T)); setTimePoint("ExConv.encode.memcpy"); @@ -146,7 +145,7 @@ namespace osuCrypto::Subfield // Compute e[0,...,k-1] = G * e. - template + template void dualEncode2(span e0, span e1) { if (e0.size() != mCodeSize) @@ -159,9 +158,9 @@ namespace osuCrypto::Subfield auto d0 = e0.subspan(mMessageSize); auto d1 = e1.subspan(mMessageSize); setTimePoint("ExConv.encode.begin"); - accumulate(d0, d1); + accumulate(d0, d1); setTimePoint("ExConv.encode.accumulate"); - mExpander.template expand( + mExpander.expand( d0, d1, e0.subspan(0, mMessageSize), e1.subspan(0, mMessageSize)); @@ -375,7 +374,7 @@ namespace osuCrypto::Subfield inline My__m128 _mm_setzero_ps() { return ZeroBlock; } #endif - template + template OC_FORCEINLINE void accOneHelper( T* __restrict xx, My__m128 xii, @@ -433,7 +432,7 @@ namespace osuCrypto::Subfield } // accumulating row i. - template + template OC_FORCEINLINE void accOne( T* __restrict xx, u64 i, @@ -472,7 +471,7 @@ namespace osuCrypto::Subfield // else { My__m128 xii;// = ::_mm_set_ps(0.0f, 0.0f, 0.0f, 0.0f); memset(&xii, 0, sizeof(My__m128)); - accOneHelper(xx, xii, j, i, size, b); + accOneHelper(xx, xii, j, i, size, b); // } } } @@ -485,7 +484,7 @@ namespace osuCrypto::Subfield // accumulating row i. - template + template OC_FORCEINLINE void accOne( T0* __restrict xx0, T1* __restrict xx1, @@ -530,14 +529,14 @@ namespace osuCrypto::Subfield // accOneHelper(xx0, xii0, j, i, size, b); // } // else { - accOneHelper(xx0, _mm_setzero_ps(), j, i, size, b); + accOneHelper(xx0, _mm_setzero_ps(), j, i, size, b); // } // if constexpr (std::is_same::value) { // auto xii1 = _mm_load_ps((float*)(xx1 + i)); // accOneHelper(xx1, xii1, j, i, size, b); // } // else { - accOneHelper(xx1, _mm_setzero_ps(), j, i, size, b); + accOneHelper(xx1, _mm_setzero_ps(), j, i, size, b); // } } } @@ -551,7 +550,7 @@ namespace osuCrypto::Subfield // accumulate x onto itself. - template + template void accumulate(span x) { PRNG prng(mSeed ^ OneBlock); @@ -568,9 +567,9 @@ namespace osuCrypto::Subfield #define CASE(I) case I:\ for (; i < main; ++i)\ - accOne(xx, i, ptr, prng, q, qe, size);\ + accOne(xx, i, ptr, prng, q, qe, size);\ for (; i < size; ++i)\ - accOne(xx, i, ptr, prng, q, qe, size);\ + accOne(xx, i, ptr, prng, q, qe, size);\ break switch (mAccumulatorSize / 8) @@ -590,7 +589,7 @@ namespace osuCrypto::Subfield // accumulate x onto itself. - template + template void accumulate(span x0, span x1) { PRNG prng(mSeed ^ OneBlock); @@ -608,9 +607,9 @@ namespace osuCrypto::Subfield #define CASE(I) case I:\ for (; i < main; ++i)\ - accOne(xx0,xx1, i, ptr, prng, q, qe, size);\ + accOne(xx0,xx1, i, ptr, prng, q, qe, size);\ for (; i < size; ++i)\ - accOne(xx0, xx1, i, ptr, prng, q, qe, size);\ + accOne(xx0, xx1, i, ptr, prng, q, qe, size);\ break switch (mAccumulatorSize / 8) diff --git a/libOTe/Tools/Subfield/Expander.h b/libOTe/Tools/Subfield/Expander.h index 4f64d559..d79ff948 100644 --- a/libOTe/Tools/Subfield/Expander.h +++ b/libOTe/Tools/Subfield/Expander.h @@ -17,7 +17,6 @@ namespace osuCrypto::Subfield // The encoder for the expander matrix B. // B has mMessageSize rows and mCodeSize columns. It is sampled uniformly // with fixed row weight mExpanderWeight. - template class ExpanderCode { public: @@ -55,7 +54,7 @@ namespace osuCrypto::Subfield - template + template typename std::enable_if::type expandOne(const T* __restrict ee, detail::ExpanderModd& prng) const { @@ -63,7 +62,7 @@ namespace osuCrypto::Subfield return ee[r]; } - template + template typename std::enable_if<(count == 1)>::type expandOne( const T* __restrict ee1, @@ -87,7 +86,7 @@ namespace osuCrypto::Subfield } } - template + template OC_FORCEINLINE typename std::enable_if<(count > 1), T>::type expandOne(const T* __restrict ee, detail::ExpanderModd& prng) const { @@ -131,19 +130,19 @@ namespace osuCrypto::Subfield w[7]); if constexpr (count > 8) - ww = TypeTrait::plus(ww, expandOne(ee, prng)); + ww = TypeTrait::plus(ww, expandOne(ee, prng)); return ww; } else { auto r = prng.get(); - auto ww = expandOne(ee, prng); + auto ww = expandOne(ee, prng); return TypeTrait::plus(ww, ee[r]); } } - template + template OC_FORCEINLINE typename std::enable_if<(count > 1)>::type expandOne( const T* __restrict ee1, @@ -221,7 +220,7 @@ namespace osuCrypto::Subfield { T yy1; T2 yy2; - expandOne(ee1, ee2, &yy1, &yy2, prng); + expandOne(ee1, ee2, &yy1, &yy2, prng); ww1 = TypeTrait::plus(ww1, yy1); ww2 = TypeTrait::plus(ww2, yy2); } @@ -246,7 +245,7 @@ namespace osuCrypto::Subfield { auto w1 = ee1[r]; auto w2 = ee2[r]; - expandOne(ee1, ee2, y1, y2, prng); + expandOne(ee1, ee2, y1, y2, prng); *y1 = TypeTrait::plus(*y1, w1); *y2 = TypeTrait::plus(*y2, w2); @@ -256,14 +255,14 @@ namespace osuCrypto::Subfield T yy1; T2 yy2; - expandOne(ee1, ee2, &yy1, &yy2, prng); + expandOne(ee1, ee2, &yy1, &yy2, prng); *y1 = TypeTrait::plus(yy1, ee1[r]); *y2 = TypeTrait::plus(yy2, ee2[r]); } } } - template + template void expand( span e, span w) const @@ -284,25 +283,25 @@ namespace osuCrypto::Subfield case I:\ if constexpr(Add)\ {\ - ww[i + 0] = TypeTrait::plus(ww[i + 0], expandOne(ee, prng));\ - ww[i + 1] = TypeTrait::plus(ww[i + 1], expandOne(ee, prng));\ - ww[i + 2] = TypeTrait::plus(ww[i + 2], expandOne(ee, prng));\ - ww[i + 3] = TypeTrait::plus(ww[i + 3], expandOne(ee, prng));\ - ww[i + 4] = TypeTrait::plus(ww[i + 4], expandOne(ee, prng));\ - ww[i + 5] = TypeTrait::plus(ww[i + 5], expandOne(ee, prng));\ - ww[i + 6] = TypeTrait::plus(ww[i + 6], expandOne(ee, prng));\ - ww[i + 7] = TypeTrait::plus(ww[i + 7], expandOne(ee, prng));\ + ww[i + 0] = TypeTrait::plus(ww[i + 0], expandOne(ee, prng));\ + ww[i + 1] = TypeTrait::plus(ww[i + 1], expandOne(ee, prng));\ + ww[i + 2] = TypeTrait::plus(ww[i + 2], expandOne(ee, prng));\ + ww[i + 3] = TypeTrait::plus(ww[i + 3], expandOne(ee, prng));\ + ww[i + 4] = TypeTrait::plus(ww[i + 4], expandOne(ee, prng));\ + ww[i + 5] = TypeTrait::plus(ww[i + 5], expandOne(ee, prng));\ + ww[i + 6] = TypeTrait::plus(ww[i + 6], expandOne(ee, prng));\ + ww[i + 7] = TypeTrait::plus(ww[i + 7], expandOne(ee, prng));\ }\ else\ {\ - ww[i + 0] = expandOne(ee, prng);\ - ww[i + 1] = expandOne(ee, prng);\ - ww[i + 2] = expandOne(ee, prng);\ - ww[i + 3] = expandOne(ee, prng);\ - ww[i + 4] = expandOne(ee, prng);\ - ww[i + 5] = expandOne(ee, prng);\ - ww[i + 6] = expandOne(ee, prng);\ - ww[i + 7] = expandOne(ee, prng);\ + ww[i + 0] = expandOne(ee, prng);\ + ww[i + 1] = expandOne(ee, prng);\ + ww[i + 2] = expandOne(ee, prng);\ + ww[i + 3] = expandOne(ee, prng);\ + ww[i + 4] = expandOne(ee, prng);\ + ww[i + 5] = expandOne(ee, prng);\ + ww[i + 6] = expandOne(ee, prng);\ + ww[i + 7] = expandOne(ee, prng);\ }\ break @@ -348,7 +347,7 @@ namespace osuCrypto::Subfield } } - template + template void expand( span e1, span e2, @@ -374,14 +373,14 @@ namespace osuCrypto::Subfield { #define CASE(I) \ case I:\ - expandOne(ee1, ee2, &ww1[i + 0], &ww2[i + 0], prng);\ - expandOne(ee1, ee2, &ww1[i + 1], &ww2[i + 1], prng);\ - expandOne(ee1, ee2, &ww1[i + 2], &ww2[i + 2], prng);\ - expandOne(ee1, ee2, &ww1[i + 3], &ww2[i + 3], prng);\ - expandOne(ee1, ee2, &ww1[i + 4], &ww2[i + 4], prng);\ - expandOne(ee1, ee2, &ww1[i + 5], &ww2[i + 5], prng);\ - expandOne(ee1, ee2, &ww1[i + 6], &ww2[i + 6], prng);\ - expandOne(ee1, ee2, &ww1[i + 7], &ww2[i + 7], prng);\ + expandOne(ee1, ee2, &ww1[i + 0], &ww2[i + 0], prng);\ + expandOne(ee1, ee2, &ww1[i + 1], &ww2[i + 1], prng);\ + expandOne(ee1, ee2, &ww1[i + 2], &ww2[i + 2], prng);\ + expandOne(ee1, ee2, &ww1[i + 3], &ww2[i + 3], prng);\ + expandOne(ee1, ee2, &ww1[i + 4], &ww2[i + 4], prng);\ + expandOne(ee1, ee2, &ww1[i + 5], &ww2[i + 5], prng);\ + expandOne(ee1, ee2, &ww1[i + 6], &ww2[i + 6], prng);\ + expandOne(ee1, ee2, &ww1[i + 7], &ww2[i + 7], prng);\ break switch (mExpanderWeight) diff --git a/libOTe/Vole/Subfield/SilentVoleReceiver.h b/libOTe/Vole/Subfield/SilentVoleReceiver.h index ca662a74..03c2e97d 100644 --- a/libOTe/Vole/Subfield/SilentVoleReceiver.h +++ b/libOTe/Vole/Subfield/SilentVoleReceiver.h @@ -73,7 +73,7 @@ namespace osuCrypto::Subfield // the sparse vector. MultType mMultType = DefaultMultType; - ExConvCode mExConvEncoder; + ExConvCode mExConvEncoder; // The multi-point punctured PRF for generating // the sparse vectors. @@ -526,7 +526,7 @@ namespace osuCrypto::Subfield mExConvEncoder.setTimer(getTimer()); } - mExConvEncoder.template dualEncode2( + mExConvEncoder.dualEncode2( mA.subspan(0, mExConvEncoder.mCodeSize), mC.subspan(0, mExConvEncoder.mCodeSize) ); diff --git a/libOTe/Vole/Subfield/SilentVoleSender.h b/libOTe/Vole/Subfield/SilentVoleSender.h index 9dca20cb..064cb583 100644 --- a/libOTe/Vole/Subfield/SilentVoleSender.h +++ b/libOTe/Vole/Subfield/SilentVoleSender.h @@ -29,7 +29,6 @@ namespace osuCrypto::Subfield { - template inline void SubfieldExConvConfigure( u64 numOTs, u64 secParam, MultType mMultType, @@ -38,7 +37,7 @@ namespace osuCrypto::Subfield u64& mSizePer, u64& mN2, u64& mN, - ExConvCode& mEncoder + ExConvCode& mEncoder ) { u64 a = 24; @@ -112,7 +111,7 @@ namespace osuCrypto::Subfield #ifdef ENABLE_INSECURE_SILVER SilverEncoder mEncoder; #endif - ExConvCode mExConvEncoder; + ExConvCode mExConvEncoder; AlignedUnVector mB; @@ -414,7 +413,7 @@ namespace osuCrypto::Subfield if (mTimer) { mExConvEncoder.setTimer(getTimer()); } - mExConvEncoder.template dualEncode(mB.subspan(0, mExConvEncoder.mCodeSize)); + mExConvEncoder.dualEncode(mB.subspan(0, mExConvEncoder.mCodeSize)); break; default: throw RTE_LOC; diff --git a/libOTe_Tests/Subfield_Tests.cpp b/libOTe_Tests/Subfield_Tests.cpp index 963e03bc..b4dbe35e 100644 --- a/libOTe_Tests/Subfield_Tests.cpp +++ b/libOTe_Tests/Subfield_Tests.cpp @@ -16,8 +16,9 @@ namespace osuCrypto::Subfield void Subfield_ExConvCode_encode_test(const oc::CLP& cmd) { { + using TypeTrait = TypeTraitF128; u64 n = 1024; - ExConvCode code; + ExConvCode code; code.config(n / 2, n, 7, 24, true); PRNG prng(ZeroBlock); @@ -30,10 +31,10 @@ namespace osuCrypto::Subfield z1[i] = z0[i] ^ delta.gf128Mul(y[i]); } - code.dualEncode(z1); + code.dualEncode(z1); // code.dualEncode(z0); // code.dualEncode(y); - code.dualEncode2(z0, y); + code.dualEncode2(z0, y); for (u64 i = 0; i < n; ++i) { @@ -45,8 +46,9 @@ namespace osuCrypto::Subfield } { + using TypeTrait = TypeTraitPrimitive; u64 n = 1024; - ExConvCode> code; + ExConvCode code; code.config(n / 2, n, 7, 24, true); PRNG prng(ZeroBlock); @@ -59,8 +61,8 @@ namespace osuCrypto::Subfield z1[i] = z0[i] + delta * y[i]; } - code.dualEncode(z1); - code.dualEncode2(z0, y); + code.dualEncode(z1); + code.dualEncode2(z0, y); for (u64 i = 0; i < n; ++i) { @@ -72,8 +74,9 @@ namespace osuCrypto::Subfield } { + using TypeTrait = TypeTrait64; u64 n = 1024; - ExConvCode code; + ExConvCode code; code.config(n / 2, n, 7, 24, true); PRNG prng(ZeroBlock); @@ -86,8 +89,8 @@ namespace osuCrypto::Subfield z1[i] = z0[i] + delta * y[i]; } - code.dualEncode(z1); - code.dualEncode2(z0, y); + code.dualEncode(z1); + code.dualEncode2(z0, y); for (u64 i = 0; i < n; ++i) { From c147db44331863c77480a83894a9c900368b8c2d Mon Sep 17 00:00:00 2001 From: Halulu Date: Fri, 12 Jan 2024 19:45:59 +0800 Subject: [PATCH 06/23] extern gAes --- libOTe/Tools/Subfield/Subfield.cpp | 11 +++++++++++ libOTe/Tools/Subfield/SubfieldPprf.h | 12 ++---------- 2 files changed, 13 insertions(+), 10 deletions(-) create mode 100644 libOTe/Tools/Subfield/Subfield.cpp diff --git a/libOTe/Tools/Subfield/Subfield.cpp b/libOTe/Tools/Subfield/Subfield.cpp new file mode 100644 index 00000000..ccefd3a5 --- /dev/null +++ b/libOTe/Tools/Subfield/Subfield.cpp @@ -0,0 +1,11 @@ +#include "cryptoTools/Crypto/AES.h" + +namespace osuCrypto::Subfield { + // A public PRF/PRG that we will use for deriving the GGM tree. + extern const std::array gAes = []() { + std::array aes; + aes[0].setKey(toBlock(3242342)); + aes[1].setKey(toBlock(8993849)); + return aes; + }(); +} diff --git a/libOTe/Tools/Subfield/SubfieldPprf.h b/libOTe/Tools/Subfield/SubfieldPprf.h index a3d4a46e..db82ff22 100644 --- a/libOTe/Tools/Subfield/SubfieldPprf.h +++ b/libOTe/Tools/Subfield/SubfieldPprf.h @@ -13,16 +13,8 @@ namespace osuCrypto::Subfield { - namespace - { - // A public PRF/PRG that we will use for deriving the GGM tree. - const std::array gAes = []() { - std::array aes; - aes[0].setKey(toBlock(3242342)); - aes[1].setKey(toBlock(8993849)); - return aes; - }(); - } + + extern const std::array gAes; template void copyOut( From baa1c71e8f0f9989246c4927b86031c333353014 Mon Sep 17 00:00:00 2001 From: Halulu Date: Sat, 13 Jan 2024 12:11:16 +0800 Subject: [PATCH 07/23] add DefaultTrait --- libOTe/Tools/Subfield/Subfield.h | 13 ++++++------- libOTe/Tools/Subfield/SubfieldPprf.h | 8 +++----- libOTe/Vole/Subfield/SilentVoleReceiver.h | 7 ++----- libOTe/Vole/Subfield/SilentVoleSender.h | 7 ++----- libOTe_Tests/Subfield_Tests.cpp | 18 +++++++++--------- 5 files changed, 22 insertions(+), 31 deletions(-) diff --git a/libOTe/Tools/Subfield/Subfield.h b/libOTe/Tools/Subfield/Subfield.h index 399d360b..389fcd11 100644 --- a/libOTe/Tools/Subfield/Subfield.h +++ b/libOTe/Tools/Subfield/Subfield.h @@ -3,13 +3,6 @@ #include "cryptoTools/Common/BitVector.h" namespace osuCrypto::Subfield { - - struct F128 { - block b; - F128() = default; - explicit F128(const block& b) : b(b) {} - }; - /* * Primitive TypeTrait for integers */ @@ -55,6 +48,12 @@ namespace osuCrypto::Subfield { } }; + + template + struct DefaultTrait: TypeTraitPrimitive { + static_assert(std::is_same::value, "F and G must be the same type"); + }; + using TypeTrait64 = TypeTraitPrimitive; /* diff --git a/libOTe/Tools/Subfield/SubfieldPprf.h b/libOTe/Tools/Subfield/SubfieldPprf.h index db82ff22..f2ba75e4 100644 --- a/libOTe/Tools/Subfield/SubfieldPprf.h +++ b/libOTe/Tools/Subfield/SubfieldPprf.h @@ -171,12 +171,11 @@ namespace osuCrypto::Subfield } } - template + template> class SilentSubfieldPprfSender : public TimerAdapter { public: - using F = typename TypeTrait::F; u64 mDomain = 0, mDepth = 0, mPntCount = 0; - std::vector mValue; + std::vector mValue; bool mPrint = false; TreeAllocator mTreeAlloc; Matrix> mBaseOTs; @@ -535,11 +534,10 @@ namespace osuCrypto::Subfield }; - template + template> class SilentSubfieldPprfReceiver : public TimerAdapter { public: - using F = typename TypeTrait::F; u64 mDomain = 0, mDepth = 0, mPntCount = 0; std::vector mPoints; diff --git a/libOTe/Vole/Subfield/SilentVoleReceiver.h b/libOTe/Vole/Subfield/SilentVoleReceiver.h index 03c2e97d..c6497b3f 100644 --- a/libOTe/Vole/Subfield/SilentVoleReceiver.h +++ b/libOTe/Vole/Subfield/SilentVoleReceiver.h @@ -29,13 +29,10 @@ namespace osuCrypto::Subfield { - template + template> class SilentSubfieldVoleReceiver : public TimerAdapter { public: - using F = typename TypeTrait::F; - using G = typename TypeTrait::G; - static constexpr u64 mScaler = 2; enum class State @@ -77,7 +74,7 @@ namespace osuCrypto::Subfield // The multi-point punctured PRF for generating // the sparse vectors. - SilentSubfieldPprfReceiver mGen; + SilentSubfieldPprfReceiver mGen; // The internal buffers for holding the expanded vectors. // mA + mB = mC * delta diff --git a/libOTe/Vole/Subfield/SilentVoleSender.h b/libOTe/Vole/Subfield/SilentVoleSender.h index 064cb583..d84eef18 100644 --- a/libOTe/Vole/Subfield/SilentVoleSender.h +++ b/libOTe/Vole/Subfield/SilentVoleSender.h @@ -69,13 +69,10 @@ namespace osuCrypto::Subfield } - template + template> class SilentSubfieldVoleSender : public TimerAdapter { public: - using F = typename TypeTrait::F; - using G = typename TypeTrait::G; - static constexpr u64 mScaler = 2; enum class State @@ -88,7 +85,7 @@ namespace osuCrypto::Subfield State mState = State::Default; - SilentSubfieldPprfSender mGen; + SilentSubfieldPprfSender mGen; u64 mRequestedNumOTs = 0; u64 mN2 = 0; diff --git a/libOTe_Tests/Subfield_Tests.cpp b/libOTe_Tests/Subfield_Tests.cpp index b4dbe35e..fcbd4c9a 100644 --- a/libOTe_Tests/Subfield_Tests.cpp +++ b/libOTe_Tests/Subfield_Tests.cpp @@ -327,11 +327,11 @@ namespace osuCrypto::Subfield { PRNG prng(seed); - u64 x = TypeTrait64::fromBlock(prng.get()); + u64 x = prng.get(); std::vector c(n), z0(n), z1(n); - SilentSubfieldVoleReceiver recv; - SilentSubfieldVoleSender send; + SilentSubfieldVoleReceiver recv; + SilentSubfieldVoleSender send; recv.mMultType = MultType::ExConv7x24; send.mMultType = MultType::ExConv7x24; @@ -368,15 +368,15 @@ namespace osuCrypto::Subfield { PRNG prng(seed); constexpr size_t N = 10; + using F = Vec; + using G = u32; using TypeTrait = TypeTraitVec; - using F = TypeTrait::F; - using G = TypeTrait::G; F x = TypeTrait::fromBlock(prng.get()); std::vector c(n); std::vector z0(n), z1(n); - SilentSubfieldVoleReceiver recv; - SilentSubfieldVoleSender send; + SilentSubfieldVoleReceiver recv; + SilentSubfieldVoleSender send; recv.mMultType = MultType::ExConv7x24; send.mMultType = MultType::ExConv7x24; @@ -420,8 +420,8 @@ namespace osuCrypto::Subfield block x = prng.get(); std::vector c(n), z0(n), z1(n); - SilentSubfieldVoleReceiver recv; - SilentSubfieldVoleSender send; + SilentSubfieldVoleReceiver recv; + SilentSubfieldVoleSender send; recv.mMultType = MultType::ExConv7x24; send.mMultType = MultType::ExConv7x24; From 9edf262b690f3b8b89685a6a946d1849c4cd9010 Mon Sep 17 00:00:00 2001 From: Halulu Date: Sat, 13 Jan 2024 12:21:53 +0800 Subject: [PATCH 08/23] DefaultTrait --- libOTe/Tools/Subfield/Subfield.h | 12 ++++++------ libOTe_Tests/Subfield_Tests.cpp | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/libOTe/Tools/Subfield/Subfield.h b/libOTe/Tools/Subfield/Subfield.h index 389fcd11..18a13b40 100644 --- a/libOTe/Tools/Subfield/Subfield.h +++ b/libOTe/Tools/Subfield/Subfield.h @@ -48,12 +48,6 @@ namespace osuCrypto::Subfield { } }; - - template - struct DefaultTrait: TypeTraitPrimitive { - static_assert(std::is_same::value, "F and G must be the same type"); - }; - using TypeTrait64 = TypeTraitPrimitive; /* @@ -221,4 +215,10 @@ namespace osuCrypto::Subfield { } }; + template + struct DefaultTrait: TypeTraitPrimitive { + static_assert(std::is_same::value, "F and G must be the same type"); + }; + + template<> struct DefaultTrait: TypeTraitF128 {}; } diff --git a/libOTe_Tests/Subfield_Tests.cpp b/libOTe_Tests/Subfield_Tests.cpp index fcbd4c9a..ab99fcc7 100644 --- a/libOTe_Tests/Subfield_Tests.cpp +++ b/libOTe_Tests/Subfield_Tests.cpp @@ -420,8 +420,8 @@ namespace osuCrypto::Subfield block x = prng.get(); std::vector c(n), z0(n), z1(n); - SilentSubfieldVoleReceiver recv; - SilentSubfieldVoleSender send; + SilentSubfieldVoleReceiver recv; + SilentSubfieldVoleSender send; recv.mMultType = MultType::ExConv7x24; send.mMultType = MultType::ExConv7x24; From 7b1c23f00cbbde74596172e6350d94ea3a7e3ca8 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Tue, 16 Jan 2024 22:30:25 -0800 Subject: [PATCH 09/23] new exconv working --- frontend/main.cpp | 1 - libOTe/Tools/EACode/Util.h | 1 - libOTe/Tools/ExConvCode/ExConvCode2.cpp | 117 ++ libOTe/Tools/ExConvCode/ExConvCode2.h | 620 +++++++ libOTe/Tools/ExConvCode/ExConvCode2Impl.h | 3 + libOTe/Tools/ExConvCode/Expander2.h | 261 +++ libOTe/Tools/SilentPprf.cpp | 4 +- libOTe/Tools/SilentPprf.h | 20 +- libOTe/Tools/Subfield/ExConvCode.h | 630 -------- libOTe/Tools/Subfield/Expander.h | 498 ------ libOTe/Tools/Subfield/Subfield.h | 344 ++-- .../{Subfield.cpp => SubfieldPprf.cpp} | 4 +- libOTe/Tools/Subfield/SubfieldPprf.h | 1422 +++++++++-------- libOTe/TwoChooseOne/ConfigureCode.cpp | 42 + libOTe/TwoChooseOne/ConfigureCode.h | 13 + .../Silent/SilentOtExtReceiver.cpp | 9 +- libOTe/Vole/Silent/SilentVoleSender.cpp | 39 +- libOTe/Vole/Subfield/NoisyVoleReceiver.h | 98 +- libOTe/Vole/Subfield/NoisyVoleSender.h | 70 +- libOTe/Vole/Subfield/SilentVoleReceiver.h | 422 ++--- libOTe/Vole/Subfield/SilentVoleSender.h | 175 +- libOTe_Tests/ExConvCode_Tests.cpp | 334 ++-- libOTe_Tests/SilentOT_Tests.cpp | 1 - libOTe_Tests/Subfield_Tests.cpp | 120 +- libOTe_Tests/UnitTests.cpp | 12 +- 25 files changed, 2565 insertions(+), 2695 deletions(-) create mode 100644 libOTe/Tools/ExConvCode/ExConvCode2.cpp create mode 100644 libOTe/Tools/ExConvCode/ExConvCode2.h create mode 100644 libOTe/Tools/ExConvCode/ExConvCode2Impl.h create mode 100644 libOTe/Tools/ExConvCode/Expander2.h delete mode 100644 libOTe/Tools/Subfield/ExConvCode.h delete mode 100644 libOTe/Tools/Subfield/Expander.h rename libOTe/Tools/Subfield/{Subfield.cpp => SubfieldPprf.cpp} (75%) diff --git a/frontend/main.cpp b/frontend/main.cpp index cfed28d5..bdbb8550 100644 --- a/frontend/main.cpp +++ b/frontend/main.cpp @@ -107,7 +107,6 @@ void minimal() #include "cryptoTools/Crypto/RandomOracle.h" int main(int argc, char** argv) { - CLP cmd; cmd.parse(argc, argv); bool flagSet = false; diff --git a/libOTe/Tools/EACode/Util.h b/libOTe/Tools/EACode/Util.h index e53a7ff1..c92d8b4b 100644 --- a/libOTe/Tools/EACode/Util.h +++ b/libOTe/Tools/EACode/Util.h @@ -121,7 +121,6 @@ namespace osuCrypto else { memcpy(dst, src, vals.size() * sizeof(value_type)); - //throw RTE_LOC; //assert(vals.size() % 32 == 0); for (u64 i = 0; i < vals.size(); i += 32) doMod32(vals.data() + i, &mod, modVal); diff --git a/libOTe/Tools/ExConvCode/ExConvCode2.cpp b/libOTe/Tools/ExConvCode/ExConvCode2.cpp new file mode 100644 index 00000000..aa4521c8 --- /dev/null +++ b/libOTe/Tools/ExConvCode/ExConvCode2.cpp @@ -0,0 +1,117 @@ +#include "ExConvCode2.h" +//#include "ExConvCode2Impl.h" +#include "libOTe/Tools/Subfield/Subfield.h" + +namespace osuCrypto +{ + + + //template void ExConvCode2::dualEncode(span e); + //template void ExConvCode2::dualEncode(span e, span w); + // + //template void ExConvCode2::dualEncode(span e); + //template void ExConvCode2::dualEncode(span e, span w); + + //template void ExConvCode2::dualEncode2(span, span e); + //template void ExConvCode2::accumulate(span, span e); + + //template void ExConvCode2::dualEncode2(span, span e); + //template void ExConvCode2::accumulate(span, span e); + + + // configure the code. The default parameters are choses to balance security and performance. + // For additional parameter choices see the paper. + void ExConvCode2::config( + u64 messageSize, + u64 codeSize, + u64 expanderWeight, + u64 accumulatorSize, + bool systematic, + block seed) + { + if (codeSize == 0) + codeSize = 2 * messageSize; + + mSeed = seed; + mMessageSize = messageSize; + mCodeSize = codeSize; + mAccumulatorSize = accumulatorSize; + mSystematic = systematic; + mExpander.config(messageSize, codeSize - messageSize * systematic, expanderWeight, seed ^ CCBlock); + } + + //// get the expander matrix + //SparseMtx ExConvCode2::getB() const + //{ + // throw RTE_LOC; + // //if (mSystematic) + // //{ + // // PointList R(mMessageSize, mCodeSize); + // // auto B = mExpander.getB().points(); + + // // for (auto p : B) + // // { + // // R.push_back(p.mRow, mMessageSize + p.mCol); + // // } + // // for (u64 i = 0; i < mMessageSize; ++i) + // // R.push_back(i, i); + + // // return R; + // //} + // //else + // //{ + // // return mExpander.getB(); + // //} + //} + + + // Get the parity check version of the accumulator + //SparseMtx ExConvCode2::getAPar() const + //{ + // throw RTE_LOC; + // //PRNG prng(mSeed ^ OneBlock); + + // //auto n = mCodeSize - mSystematic * mMessageSize; + + // //PointList AP(n, n);; + // //DenseMtx A = DenseMtx::Identity(n); + + // //block rnd; + // //u8* __restrict ptr = (u8*)prng.mBuffer.data(); + // //auto qe = prng.mBuffer.size() * 128; + // //u64 q = 0; + + // //for (u64 i = 0; i < n; ++i) + // //{ + // // accOne(AP, i, ptr, prng, rnd, q, qe, n); + // //} + // //return AP; + //} + + //// get the accumulator matrix + //SparseMtx ExConvCode2::getA() const + //{ + // auto APar = getAPar(); + + // auto A = DenseMtx::Identity(mCodeSize); + + // u64 offset = mSystematic ? mMessageSize : 0ull; + + // for (u64 i = 0; i < APar.rows(); ++i) + // { + // for (auto y : APar.col(i)) + // { + // if (y != i) + // { + // auto ay = A.row(y + offset); + // auto ai = A.row(i + offset); + // ay ^= ai; + // } + // } + // } + + // return A.sparse(); + //} + + +} \ No newline at end of file diff --git a/libOTe/Tools/ExConvCode/ExConvCode2.h b/libOTe/Tools/ExConvCode/ExConvCode2.h new file mode 100644 index 00000000..9f7517d2 --- /dev/null +++ b/libOTe/Tools/ExConvCode/ExConvCode2.h @@ -0,0 +1,620 @@ +// © 2023 Visa. +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +#pragma once + +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/Timer.h" +#include "libOTe/Tools/ExConvCode/Expander2.h" +#include "libOTe/Tools/EACode/Util.h" +#include "libOTe/Tools/Subfield/Subfield.h" + +namespace osuCrypto +{ + + template + struct has_operator_star : std::false_type + {}; + + template + struct has_operator_star < T, std::void_t< + // must have a operator*() member fn + decltype(std::declval().operator*()) + >> + : std::true_type{}; + + + template + struct is_iterator : std::false_type + {}; + + template + struct is_iterator < T, std::void_t< + // must have a operator*() member fn + // or be a pointer + std::enable_if_t< + has_operator_star::value || + std::is_pointer_v> + > + >> + : std::true_type{}; + + // The encoder for the generator matrix G = B * A. dualEncode(...) is the main function + // config(...) should be called first. + // + // B is the expander while A is the convolution. + // + // B has mMessageSize rows and mCodeSize columns. It is sampled uniformly + // with fixed row weight mExpanderWeight. + // + // A is a lower triangular n by n matrix with ones on the diagonal. The + // mAccumulatorSize diagonals left of the main diagonal are uniformly random. + // If mStickyAccumulator, then the first diagonal left of the main is always ones. + // + // See ExConvCode2Instantiations.cpp for how to instantiate new types that + // dualEncode can be called on. + // + // https://eprint.iacr.org/2023/882 + class ExConvCode2 : public TimerAdapter + { + public: + ExpanderCode2 mExpander; + + // configure the code. The default parameters are choses to balance security and performance. + // For additional parameter choices see the paper. + void config( + u64 messageSize, + u64 codeSize, + u64 expanderWeight = 7, + u64 accumulatorWeight = 16, + bool systematic = true, + block seed = block(9996754675674599, 56756745976768754)); + + // the seed that generates the code. + block mSeed = ZeroBlock; + + // The message size of the code. K. + u64 mMessageSize = 0; + + // The codeword size of the code. n. + u64 mCodeSize = 0; + + // The size of the accumulator. + u64 mAccumulatorSize = 0; + + // is the code systematic (true=faster) + bool mSystematic = true; + + // return n-k. code size n, message size k. + u64 parityRows() const { return mCodeSize - mMessageSize; } + + // return code size n. + u64 parityCols() const { return mCodeSize; } + + // return message size k. + u64 generatorRows() const { return mMessageSize; } + + // return code size n. + u64 generatorCols() const { return mCodeSize; } + + // Compute w = G * e. e will be modified in the computation. + // the computation will be done over F using CoeffCtx::plus + template< + typename F, + typename CoeffCtx, + typename SrcIter, + typename DstIter + > + void dualEncode(SrcIter&& e, DstIter&& w); + + // Compute e[0,...,k-1] = G * e. + // the computation will be done over F using CoeffCtx::plus + template< + typename F, + typename CoeffCtx, + typename Iter + > + void dualEncode(Iter&& e); + + // Compute e[0,...,k-1] = G * e. + template< + typename F, + typename G, + typename CoeffCtx, + typename IterF, + typename IterG + > + void dualEncode2(IterF&& e0, IterG&& e1) + { + dualEncode(e0); + dualEncode(e1); + } + + // Private functions ------------------------------------ + + static void refill(PRNG& prng) + { + assert(prng.mBuffer.size() == 256); + //block b[8]; + for (u64 i = 0; i < 256; i += 8) + { + block* __restrict b = prng.mBuffer.data() + i; + block* __restrict k = prng.mBuffer.data() + (u8)(i - 8); + + b[0] = AES::roundEnc(b[0], k[0]); + b[1] = AES::roundEnc(b[1], k[1]); + b[2] = AES::roundEnc(b[2], k[2]); + b[3] = AES::roundEnc(b[3], k[3]); + b[4] = AES::roundEnc(b[4], k[4]); + b[5] = AES::roundEnc(b[5], k[5]); + b[6] = AES::roundEnc(b[6], k[6]); + b[7] = AES::roundEnc(b[7], k[7]); + } + } + + // take x[i] and add it to the next 8 positions if the flag b is 1. + // + // xx[j] += b[j] * x[i] + // + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + typename Iter + > + OC_FORCEINLINE void accOne8( + Iter&& xi, + Iter&& xj, + Iter&& end, + u8 b); + + // accumulating row i. + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + int AccumulatorSize, + typename Iter + > + OC_FORCEINLINE void accOne( + Iter&& xi, + Iter&& end, + u8* matrixCoeff, + std::integral_constant); + + // accumulating row i. generic version + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + typename Iter + > + OC_FORCEINLINE void accOne( + Iter&& xi, + Iter&& end, + u8* matrixCoeff, + std::integral_constant); + + + // accumulate x onto itself. + template< + typename F, + typename CoeffCtx, + typename Iter + > + void accumulate(Iter x) + { + switch (mAccumulatorSize) + { + case 16: + accumulateFixed(std::forward(x)); + break; + case 24: + accumulateFixed(std::forward(x)); + break; + default: + // generic case + accumulateFixed(std::forward(x)); + } + } + + // accumulate x onto itself. + template< + typename F, + typename CoeffCtx, + u64 AccumulatorSize, + typename Iter + > + void accumulateFixed(Iter x); + + }; +} + + +//#include "ExConvCode2Impl.h" + +namespace osuCrypto +{ + + // Compute w = G * e. e will be modified in the computation. + template< + typename F, + typename CoeffCtx, + typename SrcIter, + typename DstIter + > + void ExConvCode2::dualEncode(SrcIter&& e, DstIter&& w) + { + + static_assert(is_iterator::value, "must pass in an iterator to the data, " __FUNCTION__); + static_assert(is_iterator::value, "must pass in an iterator to the data"); + + // try to deref the back. might bounds check. + (void)*(e + mCodeSize - 1); + (void)*(w + mMessageSize - 1); + + if (mSystematic) + { + dualEncode(e); + CoeffCtx::copy(w, w + mMessageSize, e); + setTimePoint("ExConv.encode.memcpy"); + } + else + { + + setTimePoint("ExConv.encode.begin"); + + accumulate(e); + + setTimePoint("ExConv.encode.accumulate"); + + mExpander.expand(e, w); + setTimePoint("ExConv.encode.expand"); + } + } + + // Compute e[0,...,k-1] = G * e. + template + void ExConvCode2::dualEncode(Iter&& e) + { + static_assert(is_iterator::value, "must pass in an iterator to the data"); + + (void)*(e + mCodeSize - 1); + + if (mSystematic) + { + auto d = e + mMessageSize; + setTimePoint("ExConv.encode.begin"); + accumulate(d); + setTimePoint("ExConv.encode.accumulate"); + mExpander.expand(d, e); + setTimePoint("ExConv.encode.expand"); + } + else + { + CoeffCtx::template Vec w; + CoeffCtx::resize(w, mMessageSize); + dualEncode(e, w.begin()); + CoeffCtx::copy(w.begin(), w.end(), e); + //memcpy(e.data(), w.data(), w.size() * sizeof(T)); + setTimePoint("ExConv.encode.memcpy"); + + } + } + + + //// Compute e[0,...,k-1] = G * e. + //template + //void ExConvCode2::dualEncode2(span e0, span e1) + //{ + // if (e0.size() != mCodeSize) + // throw RTE_LOC; + // if (e1.size() != mCodeSize) + // throw RTE_LOC; + + // if (mSystematic) + // { + // auto d0 = e0.subspan(mMessageSize); + // auto d1 = e1.subspan(mMessageSize); + // setTimePoint("ExConv.encode.begin"); + // accumulate(d0, d1); + // setTimePoint("ExConv.encode.accumulate"); + // mExpander.expand( + // d0, d1, + // e0.subspan(0, mMessageSize), + // e1.subspan(0, mMessageSize)); + // setTimePoint("ExConv.encode.expand"); + // } + // else + // { + // //oc::AlignedUnVector w0(mMessageSize); + // //dualEncode(e, w); + // //memcpy(e.data(), w.data(), w.size() * sizeof(T)); + // //setTimePoint("ExConv.encode.memcpy"); + + // // not impl. + // throw RTE_LOC; + // } + //} + + +#ifdef ENABLE_SSE + using My__m128 = __m128; +#else + using My__m128 = block; + + inline My__m128 _mm_load_ps(float* b) { return *(block*)b; } + + // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_blendv_ps&ig_expand=557 + inline My__m128 _mm_blendv_ps(My__m128 a, My__m128 b, My__m128 mask) + { + My__m128 dst; + for (u64 j = 0; j < 4; ++j) + { + if (mask.get(j) < 0) + dst.set(j, b.get(j)); + else + dst.set(j, a.get(j)); + } + return dst; + } + + + inline My__m128 _mm_setzero_ps() { return ZeroBlock; } +#endif + + // take x[i] and add it to the next 8 positions if the flag b is 1. + // + // xx[j] += b[j] * x[i] + // + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + typename Iter + > + OC_FORCEINLINE void ExConvCode2::accOne8( + Iter&& xi, + Iter&& xj, + Iter&& end, + u8 b) + { + if constexpr (std::is_same::value) + { + block rnd = block::allSame(b); + + block bshift[8]; + bshift[0] = rnd.slli_epi32<7>(); + bshift[1] = rnd.slli_epi32<6>(); + bshift[2] = rnd.slli_epi32<5>(); + bshift[3] = rnd.slli_epi32<4>(); + bshift[4] = rnd.slli_epi32<3>(); + bshift[5] = rnd.slli_epi32<2>(); + bshift[6] = rnd.slli_epi32<1>(); + bshift[7] = rnd; + + My__m128 bb[8]; + auto xii = _mm_load_ps((float*)(&*xi)); + My__m128 Zero = _mm_setzero_ps(); + + // bbj = bj + bb[0] = _mm_load_ps((float*)&bshift[0]); + bb[1] = _mm_load_ps((float*)&bshift[1]); + bb[2] = _mm_load_ps((float*)&bshift[2]); + bb[3] = _mm_load_ps((float*)&bshift[3]); + bb[4] = _mm_load_ps((float*)&bshift[4]); + bb[5] = _mm_load_ps((float*)&bshift[5]); + bb[6] = _mm_load_ps((float*)&bshift[6]); + bb[7] = _mm_load_ps((float*)&bshift[7]); + + // bbj = bj * xi + bb[0] = _mm_blendv_ps(Zero, xii, bb[0]); + bb[1] = _mm_blendv_ps(Zero, xii, bb[1]); + bb[2] = _mm_blendv_ps(Zero, xii, bb[2]); + bb[3] = _mm_blendv_ps(Zero, xii, bb[3]); + bb[4] = _mm_blendv_ps(Zero, xii, bb[4]); + bb[5] = _mm_blendv_ps(Zero, xii, bb[5]); + bb[6] = _mm_blendv_ps(Zero, xii, bb[6]); + bb[7] = _mm_blendv_ps(Zero, xii, bb[7]); + + block tt[8]; + memcpy(tt, bb, 8 * 16); + + assert((((b >> 0) & 1) ? *xi : ZeroBlock) == tt[0]); + assert((((b >> 1) & 1) ? *xi : ZeroBlock) == tt[1]); + assert((((b >> 2) & 1) ? *xi : ZeroBlock) == tt[2]); + assert((((b >> 3) & 1) ? *xi : ZeroBlock) == tt[3]); + assert((((b >> 4) & 1) ? *xi : ZeroBlock) == tt[4]); + assert((((b >> 5) & 1) ? *xi : ZeroBlock) == tt[5]); + assert((((b >> 6) & 1) ? *xi : ZeroBlock) == tt[6]); + assert((((b >> 7) & 1) ? *xi : ZeroBlock) == tt[7]); + + // xj += bj * xi + if (rangeCheck && xj + 0 == end) return; CoeffCtx::plus(*(xj + 0), *(xj + 0), tt[0]); + if (rangeCheck && xj + 1 == end) return; CoeffCtx::plus(*(xj + 1), *(xj + 1), tt[1]); + if (rangeCheck && xj + 2 == end) return; CoeffCtx::plus(*(xj + 2), *(xj + 2), tt[2]); + if (rangeCheck && xj + 3 == end) return; CoeffCtx::plus(*(xj + 3), *(xj + 3), tt[3]); + if (rangeCheck && xj + 4 == end) return; CoeffCtx::plus(*(xj + 4), *(xj + 4), tt[4]); + if (rangeCheck && xj + 5 == end) return; CoeffCtx::plus(*(xj + 5), *(xj + 5), tt[5]); + if (rangeCheck && xj + 6 == end) return; CoeffCtx::plus(*(xj + 6), *(xj + 6), tt[6]); + if (rangeCheck && xj + 7 == end) return; CoeffCtx::plus(*(xj + 7), *(xj + 7), tt[7]); + } + else + { + auto b0 = b & 1; + auto b1 = b & 2; + auto b2 = b & 4; + auto b3 = b & 8; + auto b4 = b & 16; + auto b5 = b & 32; + auto b6 = b & 64; + auto b7 = b & 128; + + if (rangeCheck && xj + 0 == end) return; if (b0) CoeffCtx::plus(*(xj + 0), *(xj + 0), *xi); + if (rangeCheck && xj + 1 == end) return; if (b1) CoeffCtx::plus(*(xj + 1), *(xj + 1), *xi); + if (rangeCheck && xj + 2 == end) return; if (b2) CoeffCtx::plus(*(xj + 2), *(xj + 2), *xi); + if (rangeCheck && xj + 3 == end) return; if (b3) CoeffCtx::plus(*(xj + 3), *(xj + 3), *xi); + if (rangeCheck && xj + 4 == end) return; if (b4) CoeffCtx::plus(*(xj + 4), *(xj + 4), *xi); + if (rangeCheck && xj + 5 == end) return; if (b5) CoeffCtx::plus(*(xj + 5), *(xj + 5), *xi); + if (rangeCheck && xj + 6 == end) return; if (b6) CoeffCtx::plus(*(xj + 6), *(xj + 6), *xi); + if (rangeCheck && xj + 7 == end) return; if (b7) CoeffCtx::plus(*(xj + 7), *(xj + 7), *xi); + } + } + + + + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + typename Iter + > + OC_FORCEINLINE void ExConvCode2::accOne( + Iter&& xi, + Iter&& end, + u8* matrixCoeff, + std::integral_constant _) + { + + // xj += xi + std::remove_reference_t xj = xi + 1; + if (!rangeCheck || xj < end) + { + CoeffCtx::plus(*xj, *xj, *xi); + ++xj; + } + + // xj += bj * xi + u64 k = 0; + for (; k < mAccumulatorSize - 7; k += 8) + { + accOne8(xi, xj, end, *matrixCoeff++); + + if constexpr (rangeCheck) + { + auto r = end - xj; + xj += std::min(r, 8); + } + else + { + xj += 8; + } + } + for (; k < mAccumulatorSize; ) + { + auto b = *matrixCoeff++; + + for (u64 j = 0; j < 8 && k < mAccumulatorSize; ++j, ++k) + { + + if (rangeCheck == false || (xj != end)) + { + if (b & 1) + CoeffCtx::plus(*xj, *xj, *xi); + + ++xj; + b >>= 1; + } + + } + } + } + + + // add xi to all of the future locations + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + int AccumulatorSize, + typename Iter + > + OC_FORCEINLINE void ExConvCode2::accOne( + Iter&& xi, + Iter&& end, + u8* matrixCoeff, + std::integral_constant) + { + static_assert(AccumulatorSize, "should have called the other overload"); + static_assert(AccumulatorSize % 8 == 0, "must be a multiple of 8"); + + // xj += xi + std::remove_reference_t xj = xi + 1; + if (!rangeCheck || xj < end) + { + CoeffCtx::plus(*xj, *xj, *xi); + ++xj; + } + + // xj += bj * xi + for (u64 k = 0; k < AccumulatorSize; k += 8) + { + accOne8(xi, xj, end, *matrixCoeff++); + + if constexpr (rangeCheck) + { + auto r = end - xj; + xj += std::min(r, 8); + } + else + { + xj += 8; + } + } + } + + // accumulate x onto itself. + template< + typename F, + typename CoeffCtx, + u64 AccumulatorSize, + typename Iter + > + void ExConvCode2::accumulateFixed(Iter xi) + { + auto end = xi + (mCodeSize - mSystematic * mMessageSize); + auto main = end - 1 - mAccumulatorSize; + + PRNG prng(mSeed ^ OneBlock); + u8* mtxCoeffIter = (u8*)prng.mBuffer.data(); + auto mtxCoeffEnd = mtxCoeffIter + prng.mBuffer.size() * sizeof(block) - divCeil(mAccumulatorSize, 8); + + // AccumulatorSize == 0 is the generic case, otherwise + // AccumulatorSize should be equal to mAccumulatorSize. + static_assert(AccumulatorSize % 8 == 0); + if (AccumulatorSize && mAccumulatorSize != AccumulatorSize) + throw RTE_LOC; + + while (xi < main) + { + if (mtxCoeffIter > mtxCoeffEnd) + { + // generate more mtx coefficients + refill(prng); + mtxCoeffIter = (u8*)prng.mBuffer.data(); + } + + // add xi to the next positions + accOne(xi, end, mtxCoeffIter++, std::integral_constant{}); + ++xi; + } + + while (xi < end) + { + if (mtxCoeffIter > mtxCoeffEnd) + { + // generate more mtx coefficients + refill(prng); + mtxCoeffIter = (u8*)prng.mBuffer.data(); + } + + // add xi to the next positions + accOne(xi, end, mtxCoeffIter++, std::integral_constant{}); + ++xi; + } + } + +} \ No newline at end of file diff --git a/libOTe/Tools/ExConvCode/ExConvCode2Impl.h b/libOTe/Tools/ExConvCode/ExConvCode2Impl.h new file mode 100644 index 00000000..da722947 --- /dev/null +++ b/libOTe/Tools/ExConvCode/ExConvCode2Impl.h @@ -0,0 +1,3 @@ +#pragma once +#include "ExConvCode2.h" + diff --git a/libOTe/Tools/ExConvCode/Expander2.h b/libOTe/Tools/ExConvCode/Expander2.h new file mode 100644 index 00000000..2dd6420b --- /dev/null +++ b/libOTe/Tools/ExConvCode/Expander2.h @@ -0,0 +1,261 @@ +// � 2023 Peter Rindal. +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +#pragma once + +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/Range.h" +#include "libOTe/Tools/LDPC/Mtx.h" +#include "libOTe/Tools/EACode/Util.h" + +namespace osuCrypto +{ + + template + auto getRestrictPtr(Iter& c) + { + //if constexpr (coproto::has_data_member_func::value) + //{ + // return (decltype(c.data())__restrict) c.data(); + //} + //else + { + return c; + } + } + + // The encoder for the expander matrix B. + // B has mMessageSize rows and mCodeSize columns. It is sampled uniformly + // with fixed row weight mExpanderWeight. + class ExpanderCode2 + { + public: + + void config( + u64 messageSize, + u64 codeSize, + u64 expanderWeight, + block seed = block(33333, 33333)) + { + mMessageSize = messageSize; + mCodeSize = codeSize; + mExpanderWeight = expanderWeight; + mSeed = seed; + } + + // the seed that generates the code. + block mSeed = block(0, 0); + + // The message size of the code. K. + u64 mMessageSize = 0; + + // The codeword size of the code. n. + u64 mCodeSize = 0; + + // The row weight of the B matrix. + u64 mExpanderWeight = 0; + + u64 parityRows() const { return mCodeSize - mMessageSize; } + u64 parityCols() const { return mCodeSize; } + + u64 generatorRows() const { return mMessageSize; } + u64 generatorCols() const { return mCodeSize; } + + // compute a eight output. + // the result is written to the dst iterator/ptr. + // + template< + typename CoeffCtx, + typename DstIter, + typename SrcIter + > + void expandEight( + DstIter&& dst, + SrcIter&& ee, + detail::ExpanderModd& prng, + CoeffCtx ctx) const; + + + template< + typename F, + typename CoeffCtx, + bool add, + typename SrcIter, + typename DstIter + > + void expand( + SrcIter&& input, + DstIter&& output, + CoeffCtx ctx = {} + ) const; + + + template< + typename F, + typename CoeffCtx + > + typename CoeffCtx::template Vec getB(CoeffCtx ctx = {}) const; + + }; + + template< + typename CoeffCtx, + typename DstIter, + typename SrcIter + > + OC_FORCEINLINE void + ExpanderCode2::expandEight( + DstIter&& dst, + SrcIter&& ee, + detail::ExpanderModd& prng, + CoeffCtx ctx) const + { + u64 rr[8]; + rr[0] = prng.get(); + rr[1] = prng.get(); + rr[2] = prng.get(); + rr[3] = prng.get(); + rr[4] = prng.get(); + rr[5] = prng.get(); + rr[6] = prng.get(); + rr[7] = prng.get(); + + ctx.plus(*(dst + 0), *(dst + 0), *(ee + rr[0])); + ctx.plus(*(dst + 1), *(dst + 1), *(ee + rr[1])); + ctx.plus(*(dst + 2), *(dst + 2), *(ee + rr[2])); + ctx.plus(*(dst + 3), *(dst + 3), *(ee + rr[3])); + ctx.plus(*(dst + 4), *(dst + 4), *(ee + rr[4])); + ctx.plus(*(dst + 5), *(dst + 5), *(ee + rr[5])); + ctx.plus(*(dst + 6), *(dst + 6), *(ee + rr[6])); + ctx.plus(*(dst + 7), *(dst + 7), *(ee + rr[7])); + + } + + + template< + typename F, + typename CoeffCtx, + bool Add, + typename SrcIter, + typename DstIter + > + void ExpanderCode2::expand( + SrcIter&& input, + DstIter&& output, + CoeffCtx ctx) const + { + (void)*(input + (mCodeSize - 1)); + (void)*(output + (mMessageSize - 1)); + + detail::ExpanderModd prng(mSeed, mCodeSize); + + auto main = mMessageSize / 8 * 8; + u64 i = 0; + + for (; i < main; i += 8, output += 8) + { + if constexpr (Add == false) + { + ctx.zero(output, output + 8); + } + + for (auto j = 0ull; j < mExpanderWeight; ++j) + { + // temp[0...7] = expand(input) + expandEight( + output, input, + prng, ctx); + } + } + + if constexpr (Add == false) + { + ctx.zero(output, output + (mMessageSize-i)); + } + + for (; i < mMessageSize; ++i, ++output) + { + for (auto j = 0ull; j < mExpanderWeight; ++j) + { + ctx.plus(*output, *output, *(input + prng.get())); + } + } + } + + + template< + typename F, + typename CoeffCtx + > + inline typename CoeffCtx::template Vec ExpanderCode2::getB(CoeffCtx ctx) const + { + + typename CoeffCtx::template Vec e, x; + ctx.resize(e, mCodeSize); + ctx.resize(x, mMessageSize * mCodeSize); + + for (u64 i = 0; i < e.size(); ++i) + { + // construct the i'th unit vector as input. + ctx.zero(e.begin(), e.end()); + ctx.one(e.begin() + i, e.begin() + i + 1); + + // expand it to geth the i'th row of the matrix + expand(e.begin(), x.begin() + i * mCodeSize); + } + + return x; + } + + + // //detail::ExpanderModd prng(mSeed, mCodeSize); + // //PointList points(mMessageSize, mCodeSize); + + // //u64 i = 0; + // //auto main = mMessageSize / 8 * 8; + + // //// for the main phase we process 8 expands in parallel. + // //Matrix rows(8, mExpanderWeight); + // //for (; i < main; i += 8) + // //{ + // // for (auto j = 0ull; j < mExpanderWeight; ++j) + // // { + // // for (u64 k = 0; k < 8; ++k) + // // rows(k, j) = prng.get(); + // // } + + // // for (auto j = 0ull; j < mExpanderWeight; ++j) + // // { + // // for (u64 k = 0; k < 8; ++k) + // // { + // // auto rk = rows[k]; + // // // we could have duplicates that cancel. + // // auto count = std::count(rk.begin(), rk.end(), rk[j]); + // // if (count == 1 || (count > 1 && std::find(rk.begin(), rk.end(), rk[j]) == rk.begin() + j)) + // // points.push_back(i + k, rk[j]); + // // } + // // } + // //} + + // //for (; i < mMessageSize; ++i) + // //{ + // // auto rk = rows[0]; + // // for (auto j = 0ull; j < mExpanderWeight; ++j) + // // rk[j] = prng.get(); + + // // for (auto j = 0ull; j < mExpanderWeight; ++j) + // // { + // // // we could have duplicates that cancel. + // // auto count = std::count(rk.begin(), rk.end(), rk[j]); + // // if (count == 1 || (count > 1 && std::find(rk.begin(), rk.end(), rk[j]) == rk.begin() + j)) + // // points.push_back(i, rk[j]); + // // } + // //} + + // //return points; + //} + +} diff --git a/libOTe/Tools/SilentPprf.cpp b/libOTe/Tools/SilentPprf.cpp index 28bd2850..d9d5e968 100644 --- a/libOTe/Tools/SilentPprf.cpp +++ b/libOTe/Tools/SilentPprf.cpp @@ -515,7 +515,7 @@ namespace osuCrypto } void allocateExpandTree( - u64 dpeth, + u64 depth, TreeAllocator& alloc, span>& tree, std::vector>>& levels) @@ -524,7 +524,7 @@ namespace osuCrypto assert((u64)tree.data() % 32 == 0); levels[0] = tree.subspan(0, 1); auto rem = tree.subspan(2); - for (auto i : rng(1ull, dpeth)) + for (auto i : rng(1ull, depth)) { levels[i] = rem.subspan(0, levels[i - 1].size() * 2); assert((u64)levels[i].data() % 32 == 0); diff --git a/libOTe/Tools/SilentPprf.h b/libOTe/Tools/SilentPprf.h index 6a619171..44600021 100644 --- a/libOTe/Tools/SilentPprf.h +++ b/libOTe/Tools/SilentPprf.h @@ -30,11 +30,11 @@ namespace osuCrypto { // The i'th row holds the i'th leaf for all trees. // The j'th tree is in the j'th column. - ByLeafIndex, + ByLeafIndex, // The i'th row holds the i'th tree. // The j'th leaf is in the j'th column. - ByTreeIndex, + ByTreeIndex, // The native output mode. The output will be // a single row with all leaf values. @@ -48,11 +48,11 @@ namespace osuCrypto // ... // // These are all flattened into a single row. - Interleaved, + Interleaved, // call the user's callback. The leaves will be in // Interleaved format. - Callback + Callback }; enum class OTType @@ -147,7 +147,7 @@ namespace osuCrypto span< std::array>& last); void allocateExpandTree( - u64 dpeth, + u64 depth, TreeAllocator& alloc, span>& tree, std::vector>>& levels); @@ -160,7 +160,7 @@ namespace osuCrypto bool mPrint = false; TreeAllocator mTreeAlloc; Matrix> mBaseOTs; - + std::function>)> mOutputFn; SilentMultiPprfSender() = default; @@ -197,7 +197,7 @@ namespace osuCrypto void setBase(span> baseMessages); - + task<> expand(Socket& chls, span value, block seed, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) { MatrixView o(output.data(), output.size(), 1); @@ -205,10 +205,10 @@ namespace osuCrypto } task<> expand( - Socket& chl, - span value, + Socket& chl, + span value, block seed, - MatrixView output, + MatrixView output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads); diff --git a/libOTe/Tools/Subfield/ExConvCode.h b/libOTe/Tools/Subfield/ExConvCode.h deleted file mode 100644 index 203a9351..00000000 --- a/libOTe/Tools/Subfield/ExConvCode.h +++ /dev/null @@ -1,630 +0,0 @@ -// � 2023 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -#pragma once - -#include "cryptoTools/Common/Defines.h" -#include "cryptoTools/Common/Timer.h" -#include "libOTe/Tools/Subfield/Expander.h" -#include "libOTe/Tools/EACode/Util.h" - -namespace osuCrypto::Subfield -{ - - // The encoder for the generator matrix G = B * A. dualEncode(...) is the main function - // config(...) should be called first. - // - // B is the expander while A is the convolution. - // - // B has mMessageSize rows and mCodeSize columns. It is sampled uniformly - // with fixed row weight mExpanderWeight. - // - // A is a lower triangular n by n matrix with ones on the diagonal. The - // mAccumulatorSize diagonals left of the main diagonal are uniformly random. - // If mStickyAccumulator, then the first diagonal left of the main is always ones. - // - // See ExConvCodeInstantiations.cpp for how to instantiate new types that - // dualEncode can be called on. - // - // https://eprint.iacr.org/2023/882 - - class ExConvCode : public TimerAdapter - { - public: - ExpanderCode mExpander; - - // configure the code. The default parameters are choses to balance security and performance. - // For additional parameter choices see the paper. - void config( - u64 messageSize, - u64 codeSize = 0 /*2 * messageSize is default */, - u64 expanderWeight = 7, - u64 accumulatorSize = 16, - bool systematic = true, - block seed = block(99999, 88888)) - { - if (codeSize == 0) - codeSize = 2 * messageSize; - - if (accumulatorSize % 8) - throw std::runtime_error("ExConvCode accumulator size must be a multiple of 8." LOCATION); - - mSeed = seed; - mMessageSize = messageSize; - mCodeSize = codeSize; - mAccumulatorSize = accumulatorSize; - mSystematic = systematic; - mExpander.config(messageSize, codeSize - messageSize * systematic, expanderWeight, seed ^ CCBlock); - } - - // the seed that generates the code. - block mSeed = ZeroBlock; - - // The message size of the code. K. - u64 mMessageSize = 0; - - // The codeword size of the code. n. - u64 mCodeSize = 0; - - // The size of the accumulator. - u64 mAccumulatorSize = 0; - - // is the code systematic (true=faster) - bool mSystematic = true; - - // return n-k. code size n, message size k. - u64 parityRows() const { return mCodeSize - mMessageSize; } - - // return code size n. - u64 parityCols() const { return mCodeSize; } - - // return message size k. - u64 generatorRows() const { return mMessageSize; } - - // return code size n. - u64 generatorCols() const { return mCodeSize; } - - // Compute w = G * e. e will be modified in the computation. - template - void dualEncode(span e, span w) - { - if (e.size() != mCodeSize) - throw RTE_LOC; - - if (w.size() != mMessageSize) - throw RTE_LOC; - - if (mSystematic) - { - dualEncode(e); - memcpy(w.data(), e.data(), w.size() * sizeof(T)); - setTimePoint("ExConv.encode.memcpy"); - } - else - { - - setTimePoint("ExConv.encode.begin"); - - accumulate(e); - - setTimePoint("ExConv.encode.accumulate"); - - mExpander.expand(e, w); - setTimePoint("ExConv.encode.expand"); - } - } - - // Compute e[0,...,k-1] = G * e. - template - void dualEncode(span e) - { - if (e.size() != mCodeSize) - throw RTE_LOC; - - if (mSystematic) - { - auto d = e.subspan(mMessageSize); - setTimePoint("ExConv.encode.begin"); - accumulate(d); - setTimePoint("ExConv.encode.accumulate"); - mExpander.expand(d, e.subspan(0, mMessageSize)); - setTimePoint("ExConv.encode.expand"); - } - else - { - oc::AlignedUnVector w(mMessageSize); - dualEncode(e, w); - memcpy(e.data(), w.data(), w.size() * sizeof(T)); - setTimePoint("ExConv.encode.memcpy"); - - } - } - - - // Compute e[0,...,k-1] = G * e. - template - void dualEncode2(span e0, span e1) - { - if (e0.size() != mCodeSize) - throw RTE_LOC; - if (e1.size() != mCodeSize) - throw RTE_LOC; - - if (mSystematic) - { - auto d0 = e0.subspan(mMessageSize); - auto d1 = e1.subspan(mMessageSize); - setTimePoint("ExConv.encode.begin"); - accumulate(d0, d1); - setTimePoint("ExConv.encode.accumulate"); - mExpander.expand( - d0, d1, - e0.subspan(0, mMessageSize), - e1.subspan(0, mMessageSize)); - setTimePoint("ExConv.encode.expand"); - } - else - { - //oc::AlignedUnVector w0(mMessageSize); - //dualEncode(e, w); - //memcpy(e.data(), w.data(), w.size() * sizeof(T)); - //setTimePoint("ExConv.encode.memcpy"); - - // not impl. - throw RTE_LOC; - - } - } - - // get the expander matrix - SparseMtx getB() const - { - if (mSystematic) - { - PointList R(mMessageSize, mCodeSize); - auto B = mExpander.getB().points(); - - for (auto p : B) - { - R.push_back(p.mRow, mMessageSize + p.mCol); - } - for (u64 i = 0; i < mMessageSize; ++i) - R.push_back(i, i); - - return R; - } - else - { - return mExpander.getB(); - } - - } - - // Get the parity check version of the accumulator - SparseMtx getAPar() const - { - PRNG prng(mSeed ^ OneBlock); - - auto n = mCodeSize - mSystematic * mMessageSize; - - PointList AP(n, n);; - DenseMtx A = DenseMtx::Identity(n); - - block rnd; - u8* __restrict ptr = (u8*)prng.mBuffer.data(); - auto qe = prng.mBuffer.size() * 128; - u64 q = 0; - - for (u64 i = 0; i < n; ++i) - { - accOne(AP, i, ptr, prng, rnd, q, qe, n); - } - return AP; - } - - // get the accumulator matrix - SparseMtx getA() const - { - auto APar = getAPar(); - - auto A = DenseMtx::Identity(mCodeSize); - - u64 offset = mSystematic ? mMessageSize : 0ull; - - for (u64 i = 0; i < APar.rows(); ++i) - { - for (auto y : APar.col(i)) - { - //std::cout << y << " "; - if (y != i) - { - auto ay = A.row(y + offset); - auto ai = A.row(i + offset); - ay ^= ai; - } - } - - //std::cout << "\n" << A << std::endl; - } - - return A.sparse(); - } - - // Private functions ------------------------------------ - - inline static void refill(PRNG& prng) - { - assert(prng.mBuffer.size() == 256); - //block b[8]; - for (u64 i = 0; i < 256; i += 8) - { - //auto idx = mPrng.mBuffer[i].get(); - block* __restrict b = prng.mBuffer.data() + i; - block* __restrict k = prng.mBuffer.data() + (u8)(i - 8); - //for (u64 j = 0; j < 8; ++j) - //{ - // b = b ^ mPrng.mBuffer.data()[idx[j]]; - //} - b[0] = AES::roundEnc(b[0], k[0]); - b[1] = AES::roundEnc(b[1], k[1]); - b[2] = AES::roundEnc(b[2], k[2]); - b[3] = AES::roundEnc(b[3], k[3]); - b[4] = AES::roundEnc(b[4], k[4]); - b[5] = AES::roundEnc(b[5], k[5]); - b[6] = AES::roundEnc(b[6], k[6]); - b[7] = AES::roundEnc(b[7], k[7]); - - b[0] = b[0] ^ k[0]; - b[1] = b[1] ^ k[1]; - b[2] = b[2] ^ k[2]; - b[3] = b[3] ^ k[3]; - b[4] = b[4] ^ k[4]; - b[5] = b[5] ^ k[5]; - b[6] = b[6] ^ k[6]; - b[7] = b[7] ^ k[7]; - } - } - - // generate the point list for accumulating row i. - void accOne( - PointList& pl, - u64 i, - u8* __restrict& ptr, - PRNG& prng, - block& rnd, - u64& q, - u64 qe, - u64 size) const - { - u64 j = i + 1; - pl.push_back(i, i); - - if (q + mAccumulatorSize > qe) - { - refill(prng); - ptr = (u8*)prng.mBuffer.data(); - q = 0; - } - - - for (u64 k = 0; k < mAccumulatorSize; k += 8, q += 8, j += 8) - { - assert(ptr < (u8*)(prng.mBuffer.data() + prng.mBuffer.size())); - rnd = block::allSame(*ptr); - ++ptr; - - //std::cout << "r " << rnd << std::endl; - auto b0 = rnd; - auto b1 = rnd.slli_epi32<1>(); - auto b2 = rnd.slli_epi32<2>(); - auto b3 = rnd.slli_epi32<3>(); - auto b4 = rnd.slli_epi32<4>(); - auto b5 = rnd.slli_epi32<5>(); - auto b6 = rnd.slli_epi32<6>(); - auto b7 = rnd.slli_epi32<7>(); - //rnd = rnd.mm_slli_epi32<8>(); - - if (j + 0 < size && b0.get(0) < 0) pl.push_back(j + 0, i); - if (j + 1 < size && b1.get(0) < 0) pl.push_back(j + 1, i); - if (j + 2 < size && b2.get(0) < 0) pl.push_back(j + 2, i); - if (j + 3 < size && b3.get(0) < 0) pl.push_back(j + 3, i); - if (j + 4 < size && b4.get(0) < 0) pl.push_back(j + 4, i); - if (j + 5 < size && b5.get(0) < 0) pl.push_back(j + 5, i); - if (j + 6 < size && b6.get(0) < 0) pl.push_back(j + 6, i); - if (j + 7 < size && b7.get(0) < 0) pl.push_back(j + 7, i); - } - - - //if (mWrapping) - { - if (j < size) - pl.push_back(j, i); - ++j; - } - - } - -#ifdef ENABLE_SSE - - using My__m128 = __m128; - -#else - using My__m128 = block; - - inline My__m128 _mm_load_ps(float* b) { return *(block*)b; } - - // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_blendv_ps&ig_expand=557 - inline My__m128 _mm_blendv_ps(My__m128 a, My__m128 b, My__m128 mask) - { - My__m128 dst; - for (u64 j = 0; j < 4; ++j) - { - if (mask.get(j) < 0) - dst.set(j, b.get(j)); - else - dst.set(j, a.get(j)); - } - return dst; - } - - - inline My__m128 _mm_setzero_ps() { return ZeroBlock; } -#endif - - template - OC_FORCEINLINE void accOneHelper( - T* __restrict xx, - My__m128 xii, - u64 j, u64 i, u64 size, - block* b - ) - { - My__m128 Zero = _mm_setzero_ps(); - -// if constexpr (std::is_same::value) -// { -// My__m128 bb[8]; -// bb[0] = _mm_load_ps((float*)&b[0]); -// bb[1] = _mm_load_ps((float*)&b[1]); -// bb[2] = _mm_load_ps((float*)&b[2]); -// bb[3] = _mm_load_ps((float*)&b[3]); -// bb[4] = _mm_load_ps((float*)&b[4]); -// bb[5] = _mm_load_ps((float*)&b[5]); -// bb[6] = _mm_load_ps((float*)&b[6]); -// bb[7] = _mm_load_ps((float*)&b[7]); -// -// -// bb[0] = _mm_blendv_ps(Zero, xii, bb[0]); -// bb[1] = _mm_blendv_ps(Zero, xii, bb[1]); -// bb[2] = _mm_blendv_ps(Zero, xii, bb[2]); -// bb[3] = _mm_blendv_ps(Zero, xii, bb[3]); -// bb[4] = _mm_blendv_ps(Zero, xii, bb[4]); -// bb[5] = _mm_blendv_ps(Zero, xii, bb[5]); -// bb[6] = _mm_blendv_ps(Zero, xii, bb[6]); -// bb[7] = _mm_blendv_ps(Zero, xii, bb[7]); -// -// block tt[8]; -// memcpy(tt, bb, 8 * 16); -// -// if (!rangeCheck || j + 0 < size) xx[j + 0] = TypeTrait::plus(xx[j + 0], tt[0]); -// if (!rangeCheck || j + 1 < size) xx[j + 1] = TypeTrait::plus(xx[j + 1], tt[1]); -// if (!rangeCheck || j + 2 < size) xx[j + 2] = TypeTrait::plus(xx[j + 2], tt[2]); -// if (!rangeCheck || j + 3 < size) xx[j + 3] = TypeTrait::plus(xx[j + 3], tt[3]); -// if (!rangeCheck || j + 4 < size) xx[j + 4] = TypeTrait::plus(xx[j + 4], tt[4]); -// if (!rangeCheck || j + 5 < size) xx[j + 5] = TypeTrait::plus(xx[j + 5], tt[5]); -// if (!rangeCheck || j + 6 < size) xx[j + 6] = TypeTrait::plus(xx[j + 6], tt[6]); -// if (!rangeCheck || j + 7 < size) xx[j + 7] = TypeTrait::plus(xx[j + 7], tt[7]); -// } -// else -// { - if ((!rangeCheck || j + 0 < size) && b[0].get(0) < 0) xx[j + 0] = TypeTrait::plus(xx[j + 0], xx[i]); - if ((!rangeCheck || j + 1 < size) && b[1].get(0) < 0) xx[j + 1] = TypeTrait::plus(xx[j + 1], xx[i]); - if ((!rangeCheck || j + 2 < size) && b[2].get(0) < 0) xx[j + 2] = TypeTrait::plus(xx[j + 2], xx[i]); - if ((!rangeCheck || j + 3 < size) && b[3].get(0) < 0) xx[j + 3] = TypeTrait::plus(xx[j + 3], xx[i]); - if ((!rangeCheck || j + 4 < size) && b[4].get(0) < 0) xx[j + 4] = TypeTrait::plus(xx[j + 4], xx[i]); - if ((!rangeCheck || j + 5 < size) && b[5].get(0) < 0) xx[j + 5] = TypeTrait::plus(xx[j + 5], xx[i]); - if ((!rangeCheck || j + 6 < size) && b[6].get(0) < 0) xx[j + 6] = TypeTrait::plus(xx[j + 6], xx[i]); - if ((!rangeCheck || j + 7 < size) && b[7].get(0) < 0) xx[j + 7] = TypeTrait::plus(xx[j + 7], xx[i]); -// } - } - - // accumulating row i. - template - OC_FORCEINLINE void accOne( - T* __restrict xx, - u64 i, - u8*& ptr, - PRNG& prng, - u64& q, - u64 qe, - u64 size) { - u64 j = i + 1; - if (width) { - if (q + width > qe) { - refill(prng); - ptr = (u8*)prng.mBuffer.data(); - q = 0; - - } - q += width; - - for (u64 k = 0; k < width; ++k, j += 8) { - assert(ptr < (u8*)(prng.mBuffer.data() + prng.mBuffer.size())); - block rnd = block::allSame(*(u8*)ptr++); - - block b[8]; - b[0] = rnd; - b[1] = rnd.slli_epi32<1>(); - b[2] = rnd.slli_epi32<2>(); - b[3] = rnd.slli_epi32<3>(); - b[4] = rnd.slli_epi32<4>(); - b[5] = rnd.slli_epi32<5>(); - b[6] = rnd.slli_epi32<6>(); - b[7] = rnd.slli_epi32<7>(); - -// if constexpr (std::is_same::value) { -// accOneHelper(xx, _mm_setzero_ps(), j, i, size, b); -// } -// else { - My__m128 xii;// = ::_mm_set_ps(0.0f, 0.0f, 0.0f, 0.0f); - memset(&xii, 0, sizeof(My__m128)); - accOneHelper(xx, xii, j, i, size, b); -// } - } - } - - if (!rangeCheck || j < size) { - auto xj = TypeTrait::plus(xx[j], xx[i]); - xx[j] = xj; - } - } - - - // accumulating row i. - template - OC_FORCEINLINE void accOne( - T0* __restrict xx0, - T1* __restrict xx1, - u64 i, - u8*& ptr, - PRNG& prng, - u64& q, - u64 qe, - u64 size) - { - u64 j = i + 1; - if (width) - { - - - if (q + width > qe) - { - refill(prng); - ptr = (u8*)prng.mBuffer.data(); - q = 0; - - } - q += width; - - for (u64 k = 0; k < width; ++k, j += 8) - { - assert(ptr < (u8*)(prng.mBuffer.data() + prng.mBuffer.size())); - block rnd = block::allSame(*(u8*)ptr++); - - block b[8]; - b[0] = rnd; - b[1] = rnd.slli_epi32<1>(); - b[2] = rnd.slli_epi32<2>(); - b[3] = rnd.slli_epi32<3>(); - b[4] = rnd.slli_epi32<4>(); - b[5] = rnd.slli_epi32<5>(); - b[6] = rnd.slli_epi32<6>(); - b[7] = rnd.slli_epi32<7>(); - -// if constexpr (std::is_same::value) { -// auto xii0 = _mm_load_ps((float*)(xx0 + i)); -// accOneHelper(xx0, xii0, j, i, size, b); -// } -// else { - accOneHelper(xx0, _mm_setzero_ps(), j, i, size, b); -// } -// if constexpr (std::is_same::value) { -// auto xii1 = _mm_load_ps((float*)(xx1 + i)); -// accOneHelper(xx1, xii1, j, i, size, b); -// } -// else { - accOneHelper(xx1, _mm_setzero_ps(), j, i, size, b); -// } - } - } - - if (!rangeCheck || j < size) - { - xx0[j] = TypeTrait::plus(xx0[j], xx0[i]); - xx1[j] = TypeTrait::plus(xx1[j], xx1[i]); - } - } - - - // accumulate x onto itself. - template - void accumulate(span x) - { - PRNG prng(mSeed ^ OneBlock); - - u64 i = 0; - auto size = x.size(); - auto main = (u64)std::max(0, size - 1 - mAccumulatorSize); - u8* ptr = (u8*)prng.mBuffer.data(); - auto qe = prng.mBuffer.size() * 128 / 8; - u64 q = 0; - T* __restrict xx = x.data(); - - { - -#define CASE(I) case I:\ - for (; i < main; ++i)\ - accOne(xx, i, ptr, prng, q, qe, size);\ - for (; i < size; ++i)\ - accOne(xx, i, ptr, prng, q, qe, size);\ - break - - switch (mAccumulatorSize / 8) - { - CASE(0); - CASE(1); - CASE(2); - CASE(3); - CASE(4); - default: - throw RTE_LOC; - break; - } -#undef CASE - } - } - - - // accumulate x onto itself. - template - void accumulate(span x0, span x1) - { - PRNG prng(mSeed ^ OneBlock); - - u64 i = 0; - auto size = x0.size(); - auto main = (u64)std::max(0, size - 1 - mAccumulatorSize); - u8* ptr = (u8*)prng.mBuffer.data(); - auto qe = prng.mBuffer.size() * 128 / 8; - u64 q = 0; - T0* __restrict xx0 = x0.data(); - T1* __restrict xx1 = x1.data(); - - { - -#define CASE(I) case I:\ - for (; i < main; ++i)\ - accOne(xx0,xx1, i, ptr, prng, q, qe, size);\ - for (; i < size; ++i)\ - accOne(xx0, xx1, i, ptr, prng, q, qe, size);\ - break - - switch (mAccumulatorSize / 8) - { - CASE(0); - CASE(1); - CASE(2); - CASE(3); - CASE(4); - default: - throw RTE_LOC; - break; - } -#undef CASE - } - } - }; -} diff --git a/libOTe/Tools/Subfield/Expander.h b/libOTe/Tools/Subfield/Expander.h deleted file mode 100644 index d79ff948..00000000 --- a/libOTe/Tools/Subfield/Expander.h +++ /dev/null @@ -1,498 +0,0 @@ -// � 2023 Peter Rindal. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -#pragma once - -#include "cryptoTools/Common/Defines.h" -#include "cryptoTools/Common/Range.h" -#include "libOTe/Tools/LDPC/Mtx.h" -#include "libOTe/Tools/EACode/Util.h" - -namespace osuCrypto::Subfield -{ - - // The encoder for the expander matrix B. - // B has mMessageSize rows and mCodeSize columns. It is sampled uniformly - // with fixed row weight mExpanderWeight. - class ExpanderCode - { - public: - - void config( - u64 messageSize, - u64 codeSize = 0 /* default is 5* messageSize */, - u64 expanderWeight = 21, - block seed = block(33333, 33333)) - { - mMessageSize = messageSize; - mCodeSize = codeSize; - mExpanderWeight = expanderWeight; - mSeed = seed; - - } - - // the seed that generates the code. - block mSeed = block(0, 0); - - // The message size of the code. K. - u64 mMessageSize = 0; - - // The codeword size of the code. n. - u64 mCodeSize = 0; - - // The row weight of the B matrix. - u64 mExpanderWeight = 0; - - u64 parityRows() const { return mCodeSize - mMessageSize; } - u64 parityCols() const { return mCodeSize; } - - u64 generatorRows() const { return mMessageSize; } - u64 generatorCols() const { return mCodeSize; } - - - - template - typename std::enable_if::type - expandOne(const T* __restrict ee, detail::ExpanderModd& prng) const - { - auto r = prng.get(); - return ee[r]; - } - - template - typename std::enable_if<(count == 1)>::type - expandOne( - const T* __restrict ee1, - const T2* __restrict ee2, - T* __restrict y1, - T2* __restrict y2, - detail::ExpanderModd& prng)const - { - auto r = prng.get(); - - if (Add) - { - *y1 = TypeTrait::plus(*y1, ee1[r]); - *y2 = TypeTrait::plus(*y2, ee2[r]); - } - else - { - - *y1 = ee1[r]; - *y2 = ee2[r]; - } - } - - template - OC_FORCEINLINE typename std::enable_if<(count > 1), T>::type - expandOne(const T* __restrict ee, detail::ExpanderModd& prng) const - { - if constexpr (count >= 8) - { - u64 rr[8]; - T w[8]; - rr[0] = prng.get(); - rr[1] = prng.get(); - rr[2] = prng.get(); - rr[3] = prng.get(); - rr[4] = prng.get(); - rr[5] = prng.get(); - rr[6] = prng.get(); - rr[7] = prng.get(); - - w[0] = ee[rr[0]]; - w[1] = ee[rr[1]]; - w[2] = ee[rr[2]]; - w[3] = ee[rr[3]]; - w[4] = ee[rr[4]]; - w[5] = ee[rr[5]]; - w[6] = ee[rr[6]]; - w[7] = ee[rr[7]]; - - auto ww = - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - w[0], - w[1]), - w[2]), - w[3]), - w[4]), - w[5]), - w[6]), - w[7]); - - if constexpr (count > 8) - ww = TypeTrait::plus(ww, expandOne(ee, prng)); - return ww; - } - else - { - - auto r = prng.get(); - auto ww = expandOne(ee, prng); - return TypeTrait::plus(ww, ee[r]); - } - } - - template - OC_FORCEINLINE typename std::enable_if<(count > 1)>::type - expandOne( - const T* __restrict ee1, - const T2* __restrict ee2, - T* __restrict y1, - T2* __restrict y2, - detail::ExpanderModd& prng) const - { - if constexpr (count >= 8) - { - u64 rr[8]; - T w1[8]; - T2 w2[8]; - rr[0] = prng.get(); - rr[1] = prng.get(); - rr[2] = prng.get(); - rr[3] = prng.get(); - rr[4] = prng.get(); - rr[5] = prng.get(); - rr[6] = prng.get(); - rr[7] = prng.get(); - - w1[0] = ee1[rr[0]]; - w1[1] = ee1[rr[1]]; - w1[2] = ee1[rr[2]]; - w1[3] = ee1[rr[3]]; - w1[4] = ee1[rr[4]]; - w1[5] = ee1[rr[5]]; - w1[6] = ee1[rr[6]]; - w1[7] = ee1[rr[7]]; - - w2[0] = ee2[rr[0]]; - w2[1] = ee2[rr[1]]; - w2[2] = ee2[rr[2]]; - w2[3] = ee2[rr[3]]; - w2[4] = ee2[rr[4]]; - w2[5] = ee2[rr[5]]; - w2[6] = ee2[rr[6]]; - w2[7] = ee2[rr[7]]; - - auto ww1 = - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - w1[0], - w1[1]), - w1[2]), - w1[3]), - w1[4]), - w1[5]), - w1[6]), - w1[7]); - auto ww2 = - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - TypeTrait::plus( - w2[0], - w2[1]), - w2[2]), - w2[3]), - w2[4]), - w2[5]), - w2[6]), - w2[7]); - - if constexpr (count > 8) - { - T yy1; - T2 yy2; - expandOne(ee1, ee2, &yy1, &yy2, prng); - ww1 = TypeTrait::plus(ww1, yy1); - ww2 = TypeTrait::plus(ww2, yy2); - } - - if constexpr (Add) - { - *y1 = TypeTrait::plus(*y1, ww1); - *y2 = TypeTrait::plus(*y2, ww2); - } - else - { - *y1 = ww1; - *y2 = ww2; - } - - } - else - { - - auto r = prng.get(); - if constexpr (Add) - { - auto w1 = ee1[r]; - auto w2 = ee2[r]; - expandOne(ee1, ee2, y1, y2, prng); - *y1 = TypeTrait::plus(*y1, w1); - *y2 = TypeTrait::plus(*y2, w2); - - } - else - { - - T yy1; - T2 yy2; - expandOne(ee1, ee2, &yy1, &yy2, prng); - *y1 = TypeTrait::plus(yy1, ee1[r]); - *y2 = TypeTrait::plus(yy2, ee2[r]); - } - } - } - - template - void expand( - span e, - span w) const - { - assert(w.size() == mMessageSize); - assert(e.size() == mCodeSize); - detail::ExpanderModd prng(mSeed, mCodeSize); - - const T* __restrict ee = e.data(); - T* __restrict ww = w.data(); - - auto main = mMessageSize / 8 * 8; - u64 i = 0; - - for (; i < main; i += 8) - { -#define CASE(I) \ - case I:\ - if constexpr(Add)\ - {\ - ww[i + 0] = TypeTrait::plus(ww[i + 0], expandOne(ee, prng));\ - ww[i + 1] = TypeTrait::plus(ww[i + 1], expandOne(ee, prng));\ - ww[i + 2] = TypeTrait::plus(ww[i + 2], expandOne(ee, prng));\ - ww[i + 3] = TypeTrait::plus(ww[i + 3], expandOne(ee, prng));\ - ww[i + 4] = TypeTrait::plus(ww[i + 4], expandOne(ee, prng));\ - ww[i + 5] = TypeTrait::plus(ww[i + 5], expandOne(ee, prng));\ - ww[i + 6] = TypeTrait::plus(ww[i + 6], expandOne(ee, prng));\ - ww[i + 7] = TypeTrait::plus(ww[i + 7], expandOne(ee, prng));\ - }\ - else\ - {\ - ww[i + 0] = expandOne(ee, prng);\ - ww[i + 1] = expandOne(ee, prng);\ - ww[i + 2] = expandOne(ee, prng);\ - ww[i + 3] = expandOne(ee, prng);\ - ww[i + 4] = expandOne(ee, prng);\ - ww[i + 5] = expandOne(ee, prng);\ - ww[i + 6] = expandOne(ee, prng);\ - ww[i + 7] = expandOne(ee, prng);\ - }\ - break - - switch (mExpanderWeight) - { - CASE(5); - CASE(7); - CASE(9); - CASE(11); - CASE(21); - CASE(40); - default: - for (u64 jj = 0; jj < 8; ++jj) - { - auto r = prng.get(); - auto wv = ee[r]; - - for (auto j = 1ull; j < mExpanderWeight; ++j) - { - r = prng.get(); - wv = TypeTrait::plus(wv, ee[r]); - } - if constexpr (Add) - ww[i + jj] = TypeTrait::plus(ww[i + jj], wv); - else - ww[i + jj] = wv; - - } - } -#undef CASE - } - - for (; i < mMessageSize; ++i) - { - auto wv = ee[prng.get()]; - for (auto j = 1ull; j < mExpanderWeight; ++j) - wv = TypeTrait::plus(wv, ee[prng.get()]); - - if constexpr (Add) - ww[i] = TypeTrait::plus(ww[i], wv); - else - ww[i] = wv; - } - } - - template - void expand( - span e1, - span e2, - span w1, - span w2 - ) const - { - assert(w1.size() == mMessageSize); - assert(w2.size() == mMessageSize); - assert(e1.size() == mCodeSize); - assert(e2.size() == mCodeSize); - detail::ExpanderModd prng(mSeed, mCodeSize); - - const T* __restrict ee1 = e1.data(); - const T2* __restrict ee2 = e2.data(); - T* __restrict ww1 = w1.data(); - T2* __restrict ww2 = w2.data(); - - auto main = mMessageSize / 8 * 8; - u64 i = 0; - - for (; i < main; i += 8) - { -#define CASE(I) \ - case I:\ - expandOne(ee1, ee2, &ww1[i + 0], &ww2[i + 0], prng);\ - expandOne(ee1, ee2, &ww1[i + 1], &ww2[i + 1], prng);\ - expandOne(ee1, ee2, &ww1[i + 2], &ww2[i + 2], prng);\ - expandOne(ee1, ee2, &ww1[i + 3], &ww2[i + 3], prng);\ - expandOne(ee1, ee2, &ww1[i + 4], &ww2[i + 4], prng);\ - expandOne(ee1, ee2, &ww1[i + 5], &ww2[i + 5], prng);\ - expandOne(ee1, ee2, &ww1[i + 6], &ww2[i + 6], prng);\ - expandOne(ee1, ee2, &ww1[i + 7], &ww2[i + 7], prng);\ - break - - switch (mExpanderWeight) - { - CASE(5); - CASE(7); - CASE(9); - CASE(11); - CASE(21); - CASE(40); - default: - for (u64 jj = 0; jj < 8; ++jj) - { - auto r = prng.get(); - auto wv1 = ee1[r]; - auto wv2 = ee2[r]; - - for (auto j = 1ull; j < mExpanderWeight; ++j) - { - r = prng.get(); - wv1 = TypeTrait::plus(wv1, ee1[r]); - wv2 = TypeTrait::plus(wv2, ee2[r]); - } - if constexpr (Add) - { - ww1[i + jj] = TypeTrait::plus(ww1[i + jj], wv1); - ww2[i + jj] = TypeTrait::plus(ww2[i + jj], wv2); - } - else - { - - ww1[i + jj] = wv1; - ww2[i + jj] = wv2; - } - } - } -#undef CASE - } - - for (; i < mMessageSize; ++i) - { - auto r = prng.get(); - auto wv1 = ee1[r]; - auto wv2 = ee2[r]; - for (auto j = 1ull; j < mExpanderWeight; ++j) - { - r = prng.get(); - wv1 = TypeTrait::plus(wv1, ee1[r]); - wv2 = TypeTrait::plus(wv2, ee2[r]); - - } - if constexpr (Add) - { - ww1[i] = TypeTrait::plus(ww1[i], wv1); - ww2[i] = TypeTrait::plus(ww2[i], wv2); - } - else - { - ww1[i] = wv1; - ww2[i] = wv2; - } - } - } - - - SparseMtx getB() const - { - //PRNG prng(mSeed); - detail::ExpanderModd prng(mSeed, mCodeSize); - PointList points(mMessageSize, mCodeSize); - - std::vector row(mExpanderWeight); - - { - - for (auto i : rng(mMessageSize)) - { - row[0] = prng.get(); - //points.push_back(i, row[0]); - for (auto j : rng(1, mExpanderWeight)) - { - //do { - row[j] = prng.get(); - //} while - auto iter = std::find(row.data(), row.data() + j, row[j]); - if (iter != row.data() + j) - { - row[j] = ~0ull; - *iter = ~0ull; - } - //throw RTE_LOC; - - } - for (auto j : rng(mExpanderWeight)) - { - - if (row[j] != ~0ull) - { - //std::cout << row[j] << " "; - points.push_back(i, row[j]); - } - else - { - //std::cout << "* "; - } - } - //std::cout << std::endl; - } - } - - return points; - } - - }; -} diff --git a/libOTe/Tools/Subfield/Subfield.h b/libOTe/Tools/Subfield/Subfield.h index 18a13b40..890f54fe 100644 --- a/libOTe/Tools/Subfield/Subfield.h +++ b/libOTe/Tools/Subfield/Subfield.h @@ -1,224 +1,246 @@ +#pragma once #include "libOTe/Vole/Noisy/NoisyVoleSender.h" #include "cryptoTools/Common/BitIterator.h" #include "cryptoTools/Common/BitVector.h" -namespace osuCrypto::Subfield { +namespace osuCrypto { + /* - * Primitive TypeTrait for integers + * Primitive CoeffCtx for integers-like types */ - template - struct TypeTraitPrimitive { - using G = T; - using F = T; - - static constexpr size_t bitsG = sizeof(G) * 8; - static constexpr size_t bitsF = sizeof(F) * 8; - static constexpr size_t bytesG = sizeof(G); - static constexpr size_t bytesF = sizeof(F); + struct CoeffCtxInteger + { - static OC_FORCEINLINE F plus(const F& lhs, const F& rhs) { - return lhs + rhs; + template + static OC_FORCEINLINE void plus(R&& ret, F1&& lhs, F2&& rhs) { + ret = lhs + rhs; } - static OC_FORCEINLINE F minus(const F& lhs, const F& rhs) { - return lhs - rhs; + + template + static OC_FORCEINLINE void minus(R&& ret, F1&& lhs, F2&& rhs) { + ret = lhs - rhs; } - static OC_FORCEINLINE F mul(const F& lhs, const F& rhs) { - return lhs * rhs; + template + static OC_FORCEINLINE void mul(R&& ret, F1&& lhs, F2&& rhs) { + ret = lhs * rhs; } - static OC_FORCEINLINE bool eq(const F& lhs, const F& rhs) { + + template + static OC_FORCEINLINE bool eq(F&& lhs, F&& rhs) { return lhs == rhs; } - static OC_FORCEINLINE BitVector BitVectorF(F& x) { - return {(u8*)&x, bitsF}; - } - static OC_FORCEINLINE F fromBlock(const block& b) { - return b.get()[0]; + // the bit size require to prepresent F + // the protocol will perform binary decomposition + // of F using this many bits + template + static u64 bitSize() + { + return sizeof(F) * 8; } - static OC_FORCEINLINE G fromBlockG(const block& b) { - return b.get()[0]; - } - static OC_FORCEINLINE F pow(u64 power) { - F ret = 1; - ret <<= power; - return ret; + template + static OC_FORCEINLINE BitVector binaryDecomposition(F& x) { + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + return { (u8*)&x, sizeof(F) * 8 }; } - }; - using TypeTrait64 = TypeTraitPrimitive; + template + static OC_FORCEINLINE void fromBlock(F& ret, const block& b) { + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); - /* - * TypeTrait for GF(2^128) - */ - struct TypeTraitF128 { - using G = block; - using F = block; - - static constexpr size_t bitsG = sizeof(G) * 8; - static constexpr size_t bitsF = sizeof(F) * 8; - static constexpr size_t bytesG = sizeof(G); - static constexpr size_t bytesF = sizeof(F); - - static OC_FORCEINLINE F plus(const F& lhs, const F& rhs) { - return lhs ^ rhs; - } - static OC_FORCEINLINE F minus(const F& lhs, const F& rhs) { - return lhs ^ rhs; - } - static OC_FORCEINLINE F mul(const F& lhs, const F& rhs) { - return lhs.gf128Mul(rhs); - } - static OC_FORCEINLINE bool eq(const F& lhs, const F& rhs) { - return lhs == rhs; + if constexpr (sizeof(F) <= sizeof(block)) + { + memcpy(&ret, &b, sizeof(F)); + } + else + { + auto constexpr size = (sizeof(F) + sizeof(block) - 1) / sizeof(block); + std::array buffer; + mAesFixedKey.ecbEncCounterMode(b, buffer); + memcpy(&ret, buffer.data(), sizeof(ret)); + } } - static OC_FORCEINLINE BitVector BitVectorF(F& x) { - return {(u8*)&x, bitsF}; + template + static OC_FORCEINLINE void pow(F& ret, u64 power) { + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + memset(&ret, 0, sizeof(F)); + *BitIterator((u8*)&ret, power) = 1; } - static OC_FORCEINLINE F fromBlock(const block& b) { - return b; + + template + static OC_FORCEINLINE void copy(F& dst, const F& src) + { + dst = src; } - static OC_FORCEINLINE G fromBlockG(const block& b) { - return b; + template + static OC_FORCEINLINE void copy( + SrcIter begin, + SrcIter end, + DstIter dstBegin) + { + using F1 = std::remove_reference_t; + using F2 = std::remove_reference_t; + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + static_assert(std::is_same_v, "src and destication types are not the same."); + + std::copy(begin, end, dstBegin); } - static OC_FORCEINLINE F pow(u64 power) { - F ret = ZeroBlock; - *BitIterator((u8*)&ret, power) = 1; - return ret; + // must have + // .size() + // operator[] that returns the element. + // begin() iterator + // end() iterator + template + using Vec = AlignedUnVector; + + // the size of F when serialized. + template + static u64 byteSize() + { + return sizeof(F); } - }; - // array - template - struct Vec { - std::array v; - OC_FORCEINLINE Vec operator+(const Vec& rhs) const { - Vec ret; - for (u64 i = 0; i < N; ++i) { - ret.v[i] = v[i] + rhs.v[i]; + + // deserialize buff into dst + template + static void deserialize(Vec& dst, span buff) + { + if (dst.size() * sizeof(F) != buff.size()) + { + std::cout << "bad buffer size " << LOCATION << std::endl; + std::terminate(); } - return ret; - } - OC_FORCEINLINE Vec operator-(const Vec& rhs) const { - Vec ret; - for (u64 i = 0; i < N; ++i) { - ret.v[i] = v[i] - rhs.v[i]; + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + memcpy(dst.data(), buff.data(), buff.size()); + } + + // serial buff into dst + template + static void serialize(span dst, Vec& buff) + { + if (buff.size() * sizeof(F) != dst.size()) + { + std::cout << "bad buffer size " << LOCATION << std::endl; + std::terminate(); } - return ret; + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + memcpy(dst.data(), buff.data(), dst.size()); } - OC_FORCEINLINE Vec operator*(const T& rhs) const { - Vec ret; - for (u64 i = 0; i < N; ++i) { - ret.v[i] = v[i] * rhs; + + template + static void zero(Iter begin, Iter end) + { + using F = std::remove_reference_t; + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + + if (begin != end) + { + auto n = std::distance(begin, end); + assert(n > 0); + memset(&*begin, 0, n * sizeof(F)); } - return ret; } - OC_FORCEINLINE T operator[](u64 idx) const { - return v[idx]; + + template + static void one(Iter begin, Iter end) + { + std::fill(begin, end, 1); } - OC_FORCEINLINE T& operator[](u64 idx) { - return v[idx]; + + // resize Vec + template + static void resize(FVec&& f, u64 size) + { + f.resize(size); } - OC_FORCEINLINE bool operator==(const Vec& rhs) const { - for (u64 i = 0; i < N; ++i) { - if (v[i] != rhs.v[i]) return false; - } - return true; + }; + + // CoeffCtx for GF fields. + // ^ operator is used for addition. + struct CoeffCtxGF : CoeffCtxInteger + { + + template + static OC_FORCEINLINE void plus(F& ret, const F& lhs, const F& rhs) { + ret = lhs ^ rhs; } - OC_FORCEINLINE bool operator!=(const Vec& rhs) const { - return !(*this == rhs); + template + static OC_FORCEINLINE void minus(F& ret, const F& lhs, const F& rhs) { + ret = lhs ^ rhs; } }; - // TypeTraitVec for array of integers - template - struct TypeTraitVec { - using G = T; - using F = Vec; + // block does not use operator* + struct CoeffCtxGFBlock : CoeffCtxGF + { + static OC_FORCEINLINE void mul(block& ret, const block& lhs, const block& rhs) { + ret = lhs.gf128Mul(rhs); + } + }; - static constexpr size_t bitsG = sizeof(G) * 8; - static constexpr size_t bitsF = sizeof(F) * 8; - static constexpr size_t bytesG = sizeof(G); - static constexpr size_t bytesF = sizeof(F); - static constexpr size_t sizeBlocks = (bytesF + sizeof(block) - 1) / sizeof(block); - static constexpr size_t size = N; + template + struct CoeffCtxArray : CoeffCtxInteger + { + using F = std::array; - static OC_FORCEINLINE F plus(const F& lhs, const F& rhs) { - F ret; - for (u64 i = 0; i < N; ++i) { - ret.v[i] = lhs.v[i] + rhs.v[i]; - } - return ret; - } - static OC_FORCEINLINE F minus(const F& lhs, const F& rhs) { - F ret; - for (u64 i = 0; i < N; ++i) { - ret.v[i] = lhs.v[i] - rhs.v[i]; + static OC_FORCEINLINE void plus(F& ret, const F& lhs, const F& rhs) { + for (u64 i = 0; i < lhs.size(); ++i) { + ret[i] = lhs[i] + rhs[i]; } - return ret; } - static OC_FORCEINLINE F mul(const F& lhs, const G& rhs) { - F ret; - for (u64 i = 0; i < N; ++i) { - ret.v[i] = lhs.v[i] * rhs; - } - return ret; + + static OC_FORCEINLINE void plus(G& ret, const G& lhs, const G& rhs) { + ret = lhs + rhs; } - static OC_FORCEINLINE bool eq(const F& lhs, const F& rhs) { - for (u64 i = 0; i < N; ++i) { - if (lhs.v[i] != rhs.v[i]) return false; + + static OC_FORCEINLINE void minus(F& ret, const F& lhs, const F& rhs) + { + for (u64 i = 0; i < lhs.size(); ++i) { + ret[i] = lhs[i] - rhs[i]; } - return true; - } - static OC_FORCEINLINE G plus(const G& lhs, const G& rhs) { - return lhs + rhs; } - static OC_FORCEINLINE BitVector BitVectorF(F& x) { - return {(u8*)&x, bitsF}; + static OC_FORCEINLINE void minus(G& ret, const G& lhs, const G& rhs) { + ret = lhs - rhs; } - static OC_FORCEINLINE F fromBlock(const block& b) { - F ret; - if (N * sizeof(T) <= sizeof(block)) { - memcpy(ret.v.data(), &b, bytesF); - return ret; - } - else { - std::array buf; - for (u64 i = 0; i < sizeBlocks; ++i) { - buf[i] = b + block(i, i); - } - mAesFixedKey.hashBlocks(buf.data(), buf.data()); - memcpy(&ret, &buf, sizeof(F)); - return ret; + static OC_FORCEINLINE void mul(F& ret, const F& lhs, const G& rhs) + { + for (u64 i = 0; i < lhs.size(); ++i) { + ret[i] = lhs[i] * rhs; } } - // assume primitive type for G now - static OC_FORCEINLINE G fromBlockG(const block& b) { - return b.get()[0]; + static OC_FORCEINLINE bool eq(const F& lhs, const F& rhs) + { + for (u64 i = 0; i < lhs.size(); ++i) { + if (lhs[i] != rhs[i]) + return false; + } + return true; } - static OC_FORCEINLINE F pow(u64 power) { - F ret; - memset(&ret, 0, sizeof(ret)); - *BitIterator((u8*)&ret, power) = 1; - return ret; + static OC_FORCEINLINE bool eq(const G& lhs, const G& rhs) + { + return lhs == rhs; } }; - template - struct DefaultTrait: TypeTraitPrimitive { - static_assert(std::is_same::value, "F and G must be the same type"); + template + struct DefaultCoeffCtx : CoeffCtxInteger { }; - template<> struct DefaultTrait: TypeTraitF128 {}; + // GF128 vole + template<> struct DefaultCoeffCtx : CoeffCtxGFBlock {}; + + // OT + template<> struct DefaultCoeffCtx : CoeffCtxGFBlock {}; } diff --git a/libOTe/Tools/Subfield/Subfield.cpp b/libOTe/Tools/Subfield/SubfieldPprf.cpp similarity index 75% rename from libOTe/Tools/Subfield/Subfield.cpp rename to libOTe/Tools/Subfield/SubfieldPprf.cpp index ccefd3a5..0a6c4455 100644 --- a/libOTe/Tools/Subfield/Subfield.cpp +++ b/libOTe/Tools/Subfield/SubfieldPprf.cpp @@ -1,8 +1,8 @@ #include "cryptoTools/Crypto/AES.h" -namespace osuCrypto::Subfield { +namespace osuCrypto { // A public PRF/PRG that we will use for deriving the GGM tree. - extern const std::array gAes = []() { + extern const std::array gGgmAes = []() { std::array aes; aes[0].setKey(toBlock(3242342)); aes[1].setKey(toBlock(8993849)); diff --git a/libOTe/Tools/Subfield/SubfieldPprf.h b/libOTe/Tools/Subfield/SubfieldPprf.h index f2ba75e4..4166bcbe 100644 --- a/libOTe/Tools/Subfield/SubfieldPprf.h +++ b/libOTe/Tools/Subfield/SubfieldPprf.h @@ -10,50 +10,69 @@ #include "libOTe/Tools/SilentPprf.h" #include "SubfieldPprf.h" #include +#include "libOTe/Tools/Subfield/Subfield.h" -namespace osuCrypto::Subfield +namespace osuCrypto { - extern const std::array gAes; + extern const std::array gGgmAes; - template + inline void allocateExpandTree( + TreeAllocator& alloc, + span>& tree, + std::vector>>& levels) + { + tree = alloc.get(); + assert((u64)tree.data() % 32 == 0); + levels[0] = tree.subspan(0, 1); + auto rem = tree.subspan(2); + for (auto i : rng(1ull, levels.size())) + { + levels[i] = rem.subspan(0, levels[i - 1].size() * 2); + assert((u64)levels[i].data() % 32 == 0); + rem = rem.subspan(levels[i].size()); + } + } + + template void copyOut( - span> lvl, - MatrixView output, + VecF& leaf, + VecF& output, u64 totalTrees, - u64 tIdx, + u64 treeIndex, PprfOutputFormat oFormat, - std::function> lvl)>& callback) + std::function& callback) { + auto curSize = std::min(totalTrees - treeIndex, 8); + auto domain = leaf.size() / 8; if (oFormat == PprfOutputFormat::ByLeafIndex) { - - auto curSize = std::min(totalTrees - tIdx, 8); if (curSize == 8) { - - for (u64 i = 0; i < output.rows(); ++i) + for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) { - auto oi = output[i].subspan(tIdx, 8); - auto& ii = lvl[i]; - oi[0] = ii[0]; - oi[1] = ii[1]; - oi[2] = ii[2]; - oi[3] = ii[3]; - oi[4] = ii[4]; - oi[5] = ii[5]; - oi[6] = ii[6]; - oi[7] = ii[7]; + auto oIdx = totalTrees * leafIndex + treeIndex; + auto iIdx = leafIndex * 8; + output[oIdx + 0] = leaf[iIdx + 0]; + output[oIdx + 1] = leaf[iIdx + 1]; + output[oIdx + 2] = leaf[iIdx + 2]; + output[oIdx + 3] = leaf[iIdx + 3]; + output[oIdx + 4] = leaf[iIdx + 4]; + output[oIdx + 5] = leaf[iIdx + 5]; + output[oIdx + 6] = leaf[iIdx + 6]; + output[oIdx + 7] = leaf[iIdx + 7]; } } else { - for (u64 i = 0; i < output.rows(); ++i) + for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) { - auto oi = output[i].subspan(tIdx, curSize); - auto& ii = lvl[i]; + //auto oi = output[leafIndex].subspan(treeIndex, curSize); + //auto& ii = leaf[leafIndex]; + auto oIdx = totalTrees * leafIndex + treeIndex; + auto iIdx = leafIndex * 8; for (u64 j = 0; j < curSize; ++j) - oi[j] = ii[j]; + output[oIdx + j] = leaf[iIdx + j]; } } @@ -61,136 +80,122 @@ namespace osuCrypto::Subfield else if (oFormat == PprfOutputFormat::ByTreeIndex) { - auto curSize = std::min(totalTrees - tIdx, 8); if (curSize == 8) { - for (u64 i = 0; i < output.cols(); ++i) + for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) { - auto& ii = lvl[i]; - output(tIdx + 0, i) = ii[0]; - output(tIdx + 1, i) = ii[1]; - output(tIdx + 2, i) = ii[2]; - output(tIdx + 3, i) = ii[3]; - output(tIdx + 4, i) = ii[4]; - output(tIdx + 5, i) = ii[5]; - output(tIdx + 6, i) = ii[6]; - output(tIdx + 7, i) = ii[7]; + auto iIdx = leafIndex * 8; + + output[(treeIndex + 0) * domain + leafIndex] = leaf[iIdx + 0]; + output[(treeIndex + 1) * domain + leafIndex] = leaf[iIdx + 1]; + output[(treeIndex + 2) * domain + leafIndex] = leaf[iIdx + 2]; + output[(treeIndex + 3) * domain + leafIndex] = leaf[iIdx + 3]; + output[(treeIndex + 4) * domain + leafIndex] = leaf[iIdx + 4]; + output[(treeIndex + 5) * domain + leafIndex] = leaf[iIdx + 5]; + output[(treeIndex + 6) * domain + leafIndex] = leaf[iIdx + 6]; + output[(treeIndex + 7) * domain + leafIndex] = leaf[iIdx + 7]; } } else { - for (u64 i = 0; i < output.cols(); ++i) + for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) { - auto& ii = lvl[i]; + auto iIdx = leafIndex * 8; for (u64 j = 0; j < curSize; ++j) - output(tIdx + j, i) = ii[j]; + output[(treeIndex + j) * domain + leafIndex] = leaf[iIdx + j]; } } } else if (oFormat == PprfOutputFormat::Callback) - callback(tIdx, lvl); + callback(treeIndex, leaf); else throw RTE_LOC; } - template + template void allocateExpandBuffer( - u64 depth, - u64 activeChildXorDelta, - std::vector& buff, - span< std::array, 2>>& sums, - span< std::array>& last) + u64 depth, + u64 programPuncturedPoint, + std::vector& buff, + span, 2>>& sums, + span& leaf) { + u64 elementSize = CoeffCtx::byteSize(); + using SumType = std::array, 2>; - using LastType = std::array; - u64 numSums = depth - activeChildXorDelta; - u64 numLast = activeChildXorDelta * 8; - u64 numBytes = numSums * 16 * 16 + numLast * 4 * sizeof(F); + + // the number of internal levels. We process 8 trees at a time + u64 numSums = (depth - programPuncturedPoint) * 8; + + // the number of leaf level that we will program + u64 numleaf = programPuncturedPoint * 8; + + // num of bytes they will take up. + u64 numBytes = numSums * 2 * sizeof(block) + numleaf * 4 * elementSize; + + // allocate the buffer and partition them. buff.resize(numBytes); sums = span((SumType*)buff.data(), numSums); - last = span((LastType*)(sums.data() + sums.size()), numLast); + leaf = span((u8*)(sums.data() + sums.size()), numleaf * 4 * elementSize); void* sEnd = sums.data() + sums.size(); - void* lEnd = last.data() + last.size(); + void* lEnd = leaf.data() + leaf.size(); void* end = buff.data() + buff.size(); if (sEnd > end || lEnd > end) throw RTE_LOC; } - template + template void validateExpandFormat( - PprfOutputFormat oFormat, - MatrixView output, - u64 domain, - u64 pntCount - ) + PprfOutputFormat oFormat, + VecF& output, + u64 domain, + u64 pntCount) { - - if (oFormat == PprfOutputFormat::ByLeafIndex) - { - if (output.rows() != domain) - throw RTE_LOC; - - if (output.cols() != pntCount) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::ByTreeIndex) - { - if (output.cols() != domain) - throw RTE_LOC; - - if (output.rows() != pntCount) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::Interleaved) + switch (oFormat) { - if (output.cols() != 1) + case osuCrypto::PprfOutputFormat::ByLeafIndex: + case osuCrypto::PprfOutputFormat::ByTreeIndex: + case osuCrypto::PprfOutputFormat::Interleaved: + if (output.size() != domain * pntCount) throw RTE_LOC; - if (domain & 1) + break; + case osuCrypto::PprfOutputFormat::Callback: + if (output.size()) throw RTE_LOC; - - auto rows = output.rows(); - if (rows > (domain * pntCount) || - rows / 128 != (domain * pntCount) / 128) - throw RTE_LOC; - if (pntCount & 7) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::Callback) - { - if (domain & 1) - throw RTE_LOC; - if (pntCount & 7) - throw RTE_LOC; - } - else - { + break; + default: throw RTE_LOC; + break; } + } - template> + template< + typename F, + typename G = F, + typename CoeffCtx = DefaultCoeffCtx + > class SilentSubfieldPprfSender : public TimerAdapter { public: u64 mDomain = 0, mDepth = 0, mPntCount = 0; std::vector mValue; - bool mPrint = false; TreeAllocator mTreeAlloc; Matrix> mBaseOTs; - std::function< - void(u64 - treeIdx, span >)> - mOutputFn; + using VecF = typename CoeffCtx::template Vec; + using VecG = typename CoeffCtx::template Vec; + + std::function mOutputFn; SilentSubfieldPprfSender() = default; - SilentSubfieldPprfSender(const SilentSubfieldPprfSender &) = delete; + SilentSubfieldPprfSender(const SilentSubfieldPprfSender&) = delete; - SilentSubfieldPprfSender(SilentSubfieldPprfSender &&) = delete; + SilentSubfieldPprfSender(SilentSubfieldPprfSender&&) = delete; SilentSubfieldPprfSender(u64 domainSize, u64 pointCount) { configure(domainSize, pointCount); @@ -226,91 +231,84 @@ namespace osuCrypto::Subfield mBaseOTs(i) = baseMessages[i]; } - task<> expand(Socket &chls, span value, block seed, span output, PprfOutputFormat oFormat, - bool activeChildXorDelta, u64 numThreads) { - MatrixView o(output.data(), output.size(), 1); - return expand(chls, value, seed, o, oFormat, activeChildXorDelta, numThreads); - } + //task<> expand(Socket& chls, span value, block seed, span output, PprfOutputFormat oFormat, + // bool programPuncturedPoint, u64 numThreads) { + // MatrixView o(output.data(), output.size(), 1); + // return expand(chls, value, seed, o, oFormat, programPuncturedPoint, numThreads); + //} task<> expand( - Socket &chl, - span value, - block seed, - MatrixView output, - PprfOutputFormat oFormat, - bool activeChildXorDelta, - u64 numThreads) { - if (activeChildXorDelta) + Socket& chl, + const VecF& value, + block seed, + VecF& output, + PprfOutputFormat oFormat, + bool programPuncturedPoint, + u64 numThreads) { + if (programPuncturedPoint) setValue(value); setTimePoint("SilentMultiPprfSender.start"); validateExpandFormat(oFormat, output, mDomain, mPntCount); - MC_BEGIN(task<>, this, numThreads, oFormat, output, seed, &chl, activeChildXorDelta, - i = u64{}, - mTreeAllocDepth = u64{}, - tree = span < AlignedArray>{}, - levels = std::vector> > {}, - lastLevel = span < AlignedArray>{}, - buff = std::vector{}, - sums = span < std::array, 2>>{}, - last = span < std::array>{} - ); - - //if (oFormat == PprfOutputFormat::Callback && numThreads > 1) - // throw RTE_LOC; + MC_BEGIN(task<>, this, numThreads, oFormat, output, seed, &chl, programPuncturedPoint, + treeIndex = u64{}, + tree = span>{}, + levels = std::vector> >{}, + leafIndex = u64{}, + leafLevelPtr = (VecF*)nullptr, + leafLevel = VecF{}, + buff = std::vector{}, + encSums = span, 2>>{}, + leafMsgs = span{} - mTreeAllocDepth = mDepth + 1; // Subfield - mTreeAlloc.reserve(numThreads, (1ull << mTreeAllocDepth) + 2); - setTimePoint("SilentMultiPprfSender.reserve"); + ); - levels.resize(mDepth + 1); - allocateExpandTree(mTreeAllocDepth, mTreeAlloc, tree, levels); + mTreeAlloc.reserve(numThreads, (1ull << mDepth) + 2); + setTimePoint("SilentMultiPprfSender.reserve"); - for (i = 0; i < mPntCount; i += 8) { - // for interleaved format, the last level of the tree - // is simply the output. - // Subfield: use lastLevel - if (oFormat == PprfOutputFormat::Interleaved) { - auto b = (AlignedArray *) output.data(); - auto forest = i / 8; - b += forest * mDomain; - lastLevel = span < AlignedArray>(b, mDomain); + levels.resize(mDepth); + allocateExpandTree(mTreeAlloc, tree, levels); -// auto b = (AlignedArray *) output.data(); -// auto forest = i / 8; -// b += forest * mDomain; -// -// levels.back() = span < AlignedArray> -// (b, mDomain); - } else { - throw RTE_LOC; - } + for (treeIndex = 0; treeIndex < mPntCount; treeIndex += 8) + { + // for interleaved format, the leaf level of the tree + // is simply the output. + if (oFormat == PprfOutputFormat::Interleaved) + { + leafIndex = treeIndex * mDomain; + leafLevelPtr = &output; + } + else + { + // we will use leaf level as a buffer before + // copying the result to the output. + leafIndex = 0; + CoeffCtx::resize(leafLevel, mDomain * 8); + leafLevelPtr = &leafLevel; + } - // allocate the send buffer and partition it. - allocateExpandBuffer(mDepth, activeChildXorDelta, buff, sums, last); + // allocate the send buffer and partition it. + allocateExpandBuffer(mDepth, programPuncturedPoint, buff, encSums, leafMsgs); - // exapnd the tree - expandOne(seed, i, activeChildXorDelta, levels, lastLevel, sums, last); + // exapnd the tree + expandOne(seed, treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs); - MC_AWAIT(chl.send(std::move(buff))); + MC_AWAIT(chl.send(std::move(buff))); - // if we aren't interleaved, we need to copy the - // last layer to the output. - if (oFormat != PprfOutputFormat::Interleaved) { - // Subfield: no need to copyOut - throw RTE_LOC; -// copyOut(levels.back(), output, mPntCount, i, oFormat, mOutputFn); - } + // if we aren't interleaved, we need to copy the + // leaf layer to the output. + if (oFormat != PprfOutputFormat::Interleaved) + copyOut(leafLevel, output, mPntCount, treeIndex, oFormat, mOutputFn); - } + } - mBaseOTs = {}; - mTreeAlloc.del(tree); - mTreeAlloc.clear(); + mBaseOTs = {}; + mTreeAlloc.del(tree); + mTreeAlloc.clear(); - setTimePoint("SilentMultiPprfSender.de-alloc"); + setTimePoint("SilentMultiPprfSender.de-alloc"); MC_END(); } @@ -321,9 +319,10 @@ namespace osuCrypto::Subfield if (value.size() == 1) { std::fill(mValue.begin(), mValue.end(), value[0]); - } else { + } + else { if ((u64)value.size() != mPntCount) - throw RTE_LOC; + throw RTE_LOC; std::copy(value.begin(), value.end(), mValue.begin()); } @@ -337,15 +336,15 @@ namespace osuCrypto::Subfield } void expandOne( - block aesSeed, - u64 treeIdx, - bool programActivePath, - span >> levels, - span < AlignedArray > lastLevel, - span , 2>> encSums, - span > lastOts) { - // The number of real trees for this iteration. - auto min = std::min(8, mPntCount - treeIdx); + block aesSeed, + u64 treeIdx, + bool programPuncturedPoint, + span>> levels, + VecF& leafLevel, + u64 leafOffset, + span, 2>> encSums, + span leafMsgs) + { // the first level should be size 1, the root of the tree. // we will populate it with random seeds using aesSeed in counter mode @@ -353,202 +352,267 @@ namespace osuCrypto::Subfield assert(levels[0].size() == 1); mAesFixedKey.ecbEncCounterMode(aesSeed ^ block(treeIdx), levels[0][0]); - assert(encSums.size() == mDepth - programActivePath); - assert(encSums.size() < 24); + assert(encSums.size() == mDepth - 1); // space for our sums of each level. Should always be less then // 24 levels... If not increase the limit or make it a vector. - std::array, 2>, 24> sums; - memset(&sums, 0, sizeof(sums)); - - // Subfield: lastSums - std::array, 2> lastSums{}; + std::array, 2> sums; + // use the optimized approach for intern nodes of the tree // For each level perform the following. - for (u64 d = 0; d < mDepth; ++d) { + for (u64 d = 0; d < mDepth - 1; ++d) + { + // clear the sums + memset(&sums, 0, sizeof(sums)); + + // The total number of parents in this level. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); + // The previous level of the GGM tree. - auto level0 = levels[d]; + auto parents = levels[d]; // The next level of theGGM tree that we are populating. - auto level1 = levels[d + 1]; + auto children = levels[d + 1]; - // The total number of parents in this level. - auto width = divCeil(mDomain, 1ull << (mDepth - d)); + // For each child, populate the child by expanding the parent. + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) + { + // The value of the parent. + auto& parent = parents.data()[parentIdx]; + + auto& child0 = children.data()[childIdx]; + auto& child1 = children.data()[childIdx + 1]; + mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); + + // inspired by the Expand Accumualte idea to + // use + // + // child0 = AES(parent) ^ parent + // child1 = AES(parent) + parent + // + // but instead we are a bit more conservative and + // compute + // + // child0 = AES:Round(AES(parent), parent) + // = AES:Round(AES(parent), 0) ^ parent + // child1 = AES(parent) + parent + // + // That is, we applies an additional AES round function + // to the first child before XORing it with parent. + child0[0] = AES::roundEnc(child1[0], parent[0]); + child0[1] = AES::roundEnc(child1[1], parent[1]); + child0[2] = AES::roundEnc(child1[2], parent[2]); + child0[3] = AES::roundEnc(child1[3], parent[3]); + child0[4] = AES::roundEnc(child1[4], parent[4]); + child0[5] = AES::roundEnc(child1[5], parent[5]); + child0[6] = AES::roundEnc(child1[6], parent[6]); + child0[7] = AES::roundEnc(child1[7], parent[7]); + + // Update the running sums for this level. We keep + // a left and right totals for each level. + sums[0][0] = sums[0][0] ^ child0[0]; + sums[0][1] = sums[0][1] ^ child0[1]; + sums[0][2] = sums[0][2] ^ child0[2]; + sums[0][3] = sums[0][3] ^ child0[3]; + sums[0][4] = sums[0][4] ^ child0[4]; + sums[0][5] = sums[0][5] ^ child0[5]; + sums[0][6] = sums[0][6] ^ child0[6]; + sums[0][7] = sums[0][7] ^ child0[7]; + + // child1 = AES(parent) + parent + child1[0] = child1[0] + parent[0]; + child1[1] = child1[1] + parent[1]; + child1[2] = child1[2] + parent[2]; + child1[3] = child1[3] + parent[3]; + child1[4] = child1[4] + parent[4]; + child1[5] = child1[5] + parent[5]; + child1[6] = child1[6] + parent[6]; + child1[7] = child1[7] + parent[7]; + + sums[1][0] = sums[1][0] ^ child1[0]; + sums[1][1] = sums[1][1] ^ child1[1]; + sums[1][2] = sums[1][2] ^ child1[2]; + sums[1][3] = sums[1][3] ^ child1[3]; + sums[1][4] = sums[1][4] ^ child1[4]; + sums[1][5] = sums[1][5] ^ child1[5]; + sums[1][6] = sums[1][6] ^ child1[6]; + sums[1][7] = sums[1][7] ^ child1[7]; - // use the optimized approach for intern nodes of the tree - if (d + 1 < mDepth && 0) { -// // For each child, populate the child by expanding the parent. -// for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) { -// // The value of the parent. -// auto &parent = level0.data()[parentIdx]; -// -// auto &child0 = level1.data()[childIdx]; -// auto &child1 = level1.data()[childIdx + 1]; -// mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); -// -// child0[0] = child1[0] ^ parent[0]; -// child0[1] = child1[1] ^ parent[1]; -// child0[2] = child1[2] ^ parent[2]; -// child0[3] = child1[3] ^ parent[3]; -// child0[4] = child1[4] ^ parent[4]; -// child0[5] = child1[5] ^ parent[5]; -// child0[6] = child1[6] ^ parent[6]; -// child0[7] = child1[7] ^ parent[7]; -// -// // Update the running sums for this level. We keep -// // a left and right totals for each level. -// auto &sum = sums[d]; -// sum[0][0] = sum[0][0] ^ child0[0]; -// sum[0][1] = sum[0][1] ^ child0[1]; -// sum[0][2] = sum[0][2] ^ child0[2]; -// sum[0][3] = sum[0][3] ^ child0[3]; -// sum[0][4] = sum[0][4] ^ child0[4]; -// sum[0][5] = sum[0][5] ^ child0[5]; -// sum[0][6] = sum[0][6] ^ child0[6]; -// sum[0][7] = sum[0][7] ^ child0[7]; -// -// child1[0] = child1[0] + parent[0]; -// child1[1] = child1[1] + parent[1]; -// child1[2] = child1[2] + parent[2]; -// child1[3] = child1[3] + parent[3]; -// child1[4] = child1[4] + parent[4]; -// child1[5] = child1[5] + parent[5]; -// child1[6] = child1[6] + parent[6]; -// child1[7] = child1[7] + parent[7]; -// -// sum[1][0] = sum[1][0] ^ child1[0]; -// sum[1][1] = sum[1][1] ^ child1[1]; -// sum[1][2] = sum[1][2] ^ child1[2]; -// sum[1][3] = sum[1][3] ^ child1[3]; -// sum[1][4] = sum[1][4] ^ child1[4]; -// sum[1][5] = sum[1][5] ^ child1[5]; -// sum[1][6] = sum[1][6] ^ child1[6]; -// sum[1][7] = sum[1][7] ^ child1[7]; -// -// } - } else { - // for the leaf nodes we need to hash both children. - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) { - // The value of the parent. - auto &parent = level0.data()[parentIdx]; - - // The bit that indicates if we are on the left child (0) - // or on the right child (1). - for (u64 keep = 0; keep < 2; ++keep, ++childIdx) { - // The child that we will write in this iteration. - auto &child = level1[childIdx]; - - // The sum that this child node belongs to. - auto &sum = sums[d][keep]; - - // Each parent is expanded into the left and right children - // using a different AES fixed-key. Therefore our OWF is: - // - // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); - // - // where each half defines one of the children. - gAes[keep].hashBlocks<8>(parent.data(), child.data()); - - if (d == mDepth - 1) { - // Subfield - auto& realChild = lastLevel[childIdx]; - auto& lastSum = lastSums[keep]; - realChild[0] = TypeTrait::fromBlock(child[0]); - lastSum[0] = TypeTrait::plus(lastSum[0], realChild[0]); - realChild[1] = TypeTrait::fromBlock(child[1]); - lastSum[1] = TypeTrait::plus(lastSum[1], realChild[1]); - realChild[2] = TypeTrait::fromBlock(child[2]); - lastSum[2] = TypeTrait::plus(lastSum[2], realChild[2]); - realChild[3] = TypeTrait::fromBlock(child[3]); - lastSum[3] = TypeTrait::plus(lastSum[3], realChild[3]); - realChild[4] = TypeTrait::fromBlock(child[4]); - lastSum[4] = TypeTrait::plus(lastSum[4], realChild[4]); - realChild[5] = TypeTrait::fromBlock(child[5]); - lastSum[5] = TypeTrait::plus(lastSum[5], realChild[5]); - realChild[6] = TypeTrait::fromBlock(child[6]); - lastSum[6] = TypeTrait::plus(lastSum[6], realChild[6]); - realChild[7] = TypeTrait::fromBlock(child[7]); - lastSum[7] = TypeTrait::plus(lastSum[7], realChild[7]); - } else { - // Update the running sums for this level. We keep - // a left and right totals for each level. - sum[0] = sum[0] ^ child[0]; - sum[1] = sum[1] ^ child[1]; - sum[2] = sum[2] ^ child[2]; - sum[3] = sum[3] ^ child[3]; - sum[4] = sum[4] ^ child[4]; - sum[5] = sum[5] ^ child[5]; - sum[6] = sum[6] ^ child[6]; - sum[7] = sum[7] ^ child[7]; - } - } - } + } + + // encrypt the sums and write them to the output. + for (u64 j = 0; j < 8; ++j) + { + encSums[d][0][j] = sums[0][j] ^ mBaseOTs[treeIdx + j][d][0]; + encSums[d][1][j] = sums[1][j] ^ mBaseOTs[treeIdx + j][d][1]; } } - // For all but the last level, mask the sums with the - // OT strings and send them over. - for (u64 d = 0; d < mDepth - programActivePath; ++d) { - for (u64 j = 0; j < min; ++j) { - encSums[d][0][j] = sums[d][0][j] ^ mBaseOTs[treeIdx + j][d][0]; - encSums[d][1][j] = sums[d][1][j] ^ mBaseOTs[treeIdx + j][d][1]; + + auto d = mDepth - 1; + + // The previous level of the GGM tree. + auto level0 = levels[d]; + + // The total number of parents in this level. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); + + // The next level of theGGM tree that we are populating. + std::array child; + + // clear the sums + std::array, 2> leafSums; + CoeffCtx::resize(leafSums[0], 8); + CoeffCtx::resize(leafSums[1], 8); + CoeffCtx::zero(leafSums[0].begin(), leafSums[0].end()); + CoeffCtx::zero(leafSums[1].begin(), leafSums[1].end()); + + // for the leaf nodes we need to hash both children. + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) + { + // The value of the parent. + auto& parent = level0.data()[parentIdx]; + + // The bit that indicates if we are on the left child (0) + // or on the right child (1). + for (u64 keep = 0; keep < 2; ++keep, ++childIdx, leafOffset += 8) + { + // The child that we will write in this iteration. + + // Each parent is expanded into the left and right children + // using a different AES fixed-key. Therefore our OWF is: + // + // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); + // + // where each half defines one of the children. + gGgmAes[keep].hashBlocks<8>(parent.data(), child.data()); + + CoeffCtx::fromBlock(leafLevel[leafOffset + 0], child[0]); + CoeffCtx::fromBlock(leafLevel[leafOffset + 1], child[1]); + CoeffCtx::fromBlock(leafLevel[leafOffset + 2], child[2]); + CoeffCtx::fromBlock(leafLevel[leafOffset + 3], child[3]); + CoeffCtx::fromBlock(leafLevel[leafOffset + 4], child[4]); + CoeffCtx::fromBlock(leafLevel[leafOffset + 5], child[5]); + CoeffCtx::fromBlock(leafLevel[leafOffset + 6], child[6]); + CoeffCtx::fromBlock(leafLevel[leafOffset + 7], child[7]); + + // leafSum += child + auto& leafSum = leafSums[keep]; + CoeffCtx::plus(leafSum[0], leafSum[0], leafLevel[leafOffset + 0]); + CoeffCtx::plus(leafSum[1], leafSum[1], leafLevel[leafOffset + 1]); + CoeffCtx::plus(leafSum[2], leafSum[2], leafLevel[leafOffset + 2]); + CoeffCtx::plus(leafSum[3], leafSum[3], leafLevel[leafOffset + 3]); + CoeffCtx::plus(leafSum[4], leafSum[4], leafLevel[leafOffset + 4]); + CoeffCtx::plus(leafSum[5], leafSum[5], leafLevel[leafOffset + 5]); + CoeffCtx::plus(leafSum[6], leafSum[6], leafLevel[leafOffset + 6]); + CoeffCtx::plus(leafSum[7], leafSum[7], leafLevel[leafOffset + 7]); } } - if (programActivePath) { - // For the last level, we are going to do something special. + + if (programPuncturedPoint) + { + // For the leaf level, we are going to do something special. // The other party is currently missing both leaf children of - // the active parent. Since this is the last level, we want + // the active parent. Since this is the leaf level, we want // the inactive child to just be the normal value but the // active child should be the correct value XOR the delta. // This will be done by sending the sums and the sums plus // delta and ensure that they can only decrypt the correct ones. - auto d = mDepth - 1; - assert(lastOts.size() == min); - for (u64 j = 0; j < min; ++j) { - // Construct the sums where we will allow the delta (mValue) - // to either be on the left child or right child depending - // on which has the active path. - lastOts[j][0] = lastSums[0][j]; - lastOts[j][1] = TypeTrait::plus(lastSums[1][j], mValue[treeIdx + j]); - lastOts[j][2] = lastSums[1][j]; - lastOts[j][3] = TypeTrait::plus(lastSums[0][j], mValue[treeIdx + j]); - - // We are going to expand the 128 bit OT string - // into a 256 bit OT string using AES. - std::array masks, maskIn; - maskIn[0] = mBaseOTs[treeIdx + j][d][0]; - maskIn[1] = mBaseOTs[treeIdx + j][d][0] ^ AllOneBlock; - maskIn[2] = mBaseOTs[treeIdx + j][d][1]; - maskIn[3] = mBaseOTs[treeIdx + j][d][1] ^ AllOneBlock; - mAesFixedKey.hashBlocks<4>(maskIn.data(), masks.data()); - - // Add the OT masks to the sums and send them over. - lastOts[j][0] = TypeTrait::plus(lastOts[j][0], TypeTrait::fromBlock(masks[0])); - lastOts[j][1] = TypeTrait::plus(lastOts[j][1], TypeTrait::fromBlock(masks[1])); - lastOts[j][2] = TypeTrait::plus(lastOts[j][2], TypeTrait::fromBlock(masks[2])); - lastOts[j][3] = TypeTrait::plus(lastOts[j][3], TypeTrait::fromBlock(masks[3])); + CoeffCtx::template Vec leafOts; + CoeffCtx::resize(leafOts, 2); + PRNG otMasker; + + for (u64 j = 0; j < 8; ++j) + { + // we will construct two OT strings. Let + // s0, s1 be the left and right child sums. + // + // m0 = (s0 , s1 + val) + // m1 = (s0 + val, s1 ) + // + // these will be encrypted by the OT keys + for (u64 k = 0; k < 2; ++k) + { + if (k == 0) + { + // m0 = (s0, s1 + val) + CoeffCtx::copy(leafOts[0], leafSums[0][j]); + CoeffCtx::plus(leafOts[1], leafSums[1][j], mValue[treeIdx + j]); + } + else + { + // m1 = (s0+val, s1) + CoeffCtx::plus(leafOts[0], leafSums[0][j], mValue[treeIdx + j]); + CoeffCtx::copy(leafOts[1], leafSums[1][j]); + } + + // copy m0 into the output buffer. + span buff = leafMsgs.subspan(0, 2 * CoeffCtx::byteSize()); + leafMsgs = leafMsgs.subspan(buff.size()); + CoeffCtx::serialize(buff, leafOts); + + // encrypt the output buffer. + otMasker.SetSeed(mBaseOTs[treeIdx + j][d][k], divCeil(buff.size(), sizeof(block))); + for (u64 i = 0; i < buff.size(); ++i) + buff[i] ^= otMasker.get(); + + } + } + } + else + { + CoeffCtx::template Vec leafOts; + CoeffCtx::resize(leafOts, 1); + PRNG otMasker; + + for (u64 j = 0; j < 8; ++j) + { + for (u64 k = 0; k < 2; ++k) + { + // copy the sum k into the output buffer. + CoeffCtx::copy(leafOts[0], leafSums[k][j]); + span buff = leafMsgs.subspan(0, CoeffCtx::byteSize()); + leafMsgs = leafMsgs.subspan(buff.size()); + CoeffCtx::serialize(buff, leafOts); + + // encrypt the output buffer. + otMasker.SetSeed(mBaseOTs[treeIdx + j][d][k], divCeil(buff.size(), sizeof(block))); + for (u64 i = 0; i < buff.size(); ++i) + buff[i] ^= otMasker.get(); + + } } } + + assert(leafMsgs.size() == 0); } + + }; - template> + template< + typename F, + typename G = F, + typename CoeffCtx = DefaultCoeffCtx + > class SilentSubfieldPprfReceiver : public TimerAdapter { public: u64 mDomain = 0, mDepth = 0, mPntCount = 0; + using VecF = typename CoeffCtx::template Vec; + using VecG = typename CoeffCtx::template Vec; std::vector mPoints; Matrix mBaseOTs; + Matrix mBaseChoices; - bool mPrint = false; + TreeAllocator mTreeAlloc; - block mDebugValue; - std::function>)> mOutputFn; - std::function fromBlock; + + std::function mOutputFn; SilentSubfieldPprfReceiver() = default; SilentSubfieldPprfReceiver(const SilentSubfieldPprfReceiver&) = delete; @@ -581,43 +645,43 @@ namespace osuCrypto::Subfield u64 idx; switch (format) { - case osuCrypto::PprfOutputFormat::ByLeafIndex: - case osuCrypto::PprfOutputFormat::ByTreeIndex: - do { - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = prng.getBit(); - idx = getActivePath(mBaseChoices[i]); - } while (idx >= modulus); + case osuCrypto::PprfOutputFormat::ByLeafIndex: + case osuCrypto::PprfOutputFormat::ByTreeIndex: + do { + for (u64 j = 0; j < mDepth; ++j) + mBaseChoices(i, j) = prng.getBit(); + idx = getActivePath(mBaseChoices[i]); + } while (idx >= modulus); - break; - case osuCrypto::PprfOutputFormat::Interleaved: - case osuCrypto::PprfOutputFormat::Callback: - - if (modulus > mPntCount * mDomain) - throw std::runtime_error("modulus too big. " LOCATION); - if (modulus < mPntCount * mDomain / 2) - throw std::runtime_error("modulus too small. " LOCATION); - - // make sure that at least the first element of this tree - // is within the modulus. - idx = interleavedPoint(0, i, mPntCount, mDomain, format); - if (idx >= modulus) - throw RTE_LOC; + break; + case osuCrypto::PprfOutputFormat::Interleaved: + case osuCrypto::PprfOutputFormat::Callback: + + if (modulus > mPntCount * mDomain) + throw std::runtime_error("modulus too big. " LOCATION); + if (modulus < mPntCount * mDomain / 2) + throw std::runtime_error("modulus too small. " LOCATION); + + // make sure that at least the first element of this tree + // is within the modulus. + idx = interleavedPoint(0, i, mPntCount, mDomain, format); + if (idx >= modulus) + throw RTE_LOC; - do { - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = prng.getBit(); - idx = getActivePath(mBaseChoices[i]); + do { + for (u64 j = 0; j < mDepth; ++j) + mBaseChoices(i, j) = prng.getBit(); + idx = getActivePath(mBaseChoices[i]); - idx = interleavedPoint(idx, i, mPntCount, mDomain, format); - } while (idx >= modulus); + idx = interleavedPoint(idx, i, mPntCount, mDomain, format); + } while (idx >= modulus); - break; - default: - throw RTE_LOC; - break; + break; + default: + throw RTE_LOC; + break; } } @@ -645,24 +709,24 @@ namespace osuCrypto::Subfield switch (format) { - case osuCrypto::PprfOutputFormat::ByLeafIndex: - case osuCrypto::PprfOutputFormat::ByTreeIndex: - if (getActivePath(mBaseChoices[i]) >= mDomain) - throw RTE_LOC; - - break; - case osuCrypto::PprfOutputFormat::Interleaved: - case osuCrypto::PprfOutputFormat::Callback: - { - auto idx = getActivePath(mBaseChoices[i]); - auto idx2 = interleavedPoint(idx, i, mPntCount, mDomain, format); - if(idx2 > mPntCount * mDomain) - throw std::runtime_error("the base ot choice bits index outside of the domain. see sampleChoiceBits(...). " LOCATION); - break; - } - default: + case osuCrypto::PprfOutputFormat::ByLeafIndex: + case osuCrypto::PprfOutputFormat::ByTreeIndex: + if (getActivePath(mBaseChoices[i]) >= mDomain) throw RTE_LOC; - break; + + break; + case osuCrypto::PprfOutputFormat::Interleaved: + case osuCrypto::PprfOutputFormat::Callback: + { + auto idx = getActivePath(mBaseChoices[i]); + auto idx2 = interleavedPoint(idx, i, mPntCount, mDomain, format); + if (idx2 > mPntCount * mDomain) + throw std::runtime_error("the base ot choice bits index outside of the domain. see sampleChoiceBits(...). " LOCATION); + break; + } + default: + throw RTE_LOC; + break; } } } @@ -702,111 +766,102 @@ namespace osuCrypto::Subfield { switch (format) { - case PprfOutputFormat::ByLeafIndex: - case PprfOutputFormat::ByTreeIndex: + case PprfOutputFormat::ByLeafIndex: + case PprfOutputFormat::ByTreeIndex: - memset(points.data(), 0, points.size() * sizeof(u64)); - for (u64 j = 0; j < mPntCount; ++j) - { - points[j] = getActivePath(mBaseChoices[j]); - } + memset(points.data(), 0, points.size() * sizeof(u64)); + for (u64 j = 0; j < mPntCount; ++j) + { + points[j] = getActivePath(mBaseChoices[j]); + } - break; - case PprfOutputFormat::Interleaved: - case PprfOutputFormat::Callback: + break; + case PprfOutputFormat::Interleaved: + case PprfOutputFormat::Callback: - if ((u64)points.size() != mPntCount) + if ((u64)points.size() != mPntCount) + throw RTE_LOC; + if (points.size() % 8) throw RTE_LOC; - if (points.size() % 8) - throw RTE_LOC; - getPoints(points, PprfOutputFormat::ByLeafIndex); - interleavedPoints(points, mDomain, format); + getPoints(points, PprfOutputFormat::ByLeafIndex); + interleavedPoints(points, mDomain, format); - break; - default: - throw RTE_LOC; - break; + break; + default: + throw RTE_LOC; + break; } } - task<> expand(Socket& chl, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) - { - MatrixView o(output.data(), output.size(), 1); - return expand(chl, o, oFormat, activeChildXorDelta, numThreads); - } - - // activeChildXorDelta says whether the sender is trying to program the + // programPuncturedPoint says whether the sender is trying to program the // active child to be its correct value XOR delta. If it is not, the // active child will just take a random value. - task<> expand(Socket& chl, MatrixView output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) + task<> expand(Socket& chl, VecF& output, PprfOutputFormat oFormat, bool programPuncturedPoint, u64 numThreads) { validateExpandFormat(oFormat, output, mDomain, mPntCount); - MC_BEGIN(task<>, this, oFormat, output, &chl, activeChildXorDelta, - i = u64{}, - mTreeAllocDepth = u64{}, - tree = span>{}, - levels = std::vector>>{}, - lastLevel = span < AlignedArray>{}, - buff = std::vector{}, - sums = span, 2>>{}, - last = span>{} + MC_BEGIN(task<>, this, oFormat, output, &chl, programPuncturedPoint, + treeIndex = u64{}, + tree = span>{}, + levels = std::vector>>{}, + leafIndex = u64{}, + leafLevelPtr = (VecF*)nullptr, + leafLevel = VecF{}, + buff = std::vector{}, + encSums = span, 2>>{}, + leafMsgs = span{} ); - setTimePoint("SilentMultiPprfReceiver.start"); - mPoints.resize(roundUpTo(mPntCount, 8)); - getPoints(mPoints, PprfOutputFormat::ByLeafIndex); - - mTreeAllocDepth = mDepth + 1; // Subfield - mTreeAlloc.reserve(1, (1ull << mTreeAllocDepth) + 2); - setTimePoint("SilentMultiPprfSender.reserve"); + setTimePoint("SilentMultiPprfReceiver.start"); + mPoints.resize(roundUpTo(mPntCount, 8)); + getPoints(mPoints, PprfOutputFormat::ByLeafIndex); - levels.resize(mDepth + 1); - allocateExpandTree(mTreeAllocDepth, mTreeAlloc, tree, levels); + mTreeAlloc.reserve(1, (1ull << mDepth) + 2); + setTimePoint("SilentMultiPprfSender.reserve"); - for (i = 0; i < mPntCount; i += 8) - { - // for interleaved format, the last level of the tree - // is simply the output. - if (oFormat == PprfOutputFormat::Interleaved) - { - // Subfield - auto b = (AlignedArray *) output.data(); - auto forest = i / 8; - b += forest * mDomain; - lastLevel = span < AlignedArray>(b, mDomain); + levels.resize(mDepth); + allocateExpandTree(mDepth, mTreeAlloc, tree, levels); -// auto b = (AlignedArray*)output.data(); -// auto forest = i / 8; -// b += forest * mDomain; -// levels.back() = span>(b, mDomain); - } + for (treeIndex = 0; treeIndex < mPntCount; treeIndex += 8) + { + // for interleaved format, the leaf level of the tree + // is simply the output. + if (oFormat == PprfOutputFormat::Interleaved) + { + leafIndex = treeIndex * mDomain; + leafLevelPtr = &output; + } + else + { + // we will use leaf level as a buffer before + // copying the result to the output. + leafIndex = 0; + CoeffCtx::resize(leafLevel, mDomain * 8); + leafLevelPtr = &leafLevel; + } - // allocate the send buffer and partition it. - allocateExpandBuffer(mDepth, activeChildXorDelta, buff, sums, last); + // allocate the send buffer and partition it. + allocateExpandBuffer(mDepth, programPuncturedPoint, buff, encSums, leafMsgs); - MC_AWAIT(chl.recv(buff)); + MC_AWAIT(chl.recv(buff)); - // exapnd the tree - expandOne(i, activeChildXorDelta, levels, lastLevel, sums, last); + // exapnd the tree + expandOne(treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs); - // if we aren't interleaved, we need to copy the - // last layer to the output. - if (oFormat != PprfOutputFormat::Interleaved) { - // Subfield: no need to copyOut - throw RTE_LOC; -// copyOut(levels.back(), output, mPntCount, i, oFormat, mOutputFn); - } - } + // if we aren't interleaved, we need to copy the + // leaf layer to the output. + if (oFormat != PprfOutputFormat::Interleaved) + copyOut(leafLevel, output, mPntCount, treeIndex, oFormat, mOutputFn); + } - setTimePoint("SilentMultiPprfReceiver.join"); + setTimePoint("SilentMultiPprfReceiver.join"); - mBaseOTs = {}; - mTreeAlloc.del(tree); - mTreeAlloc.clear(); + mBaseOTs = {}; + mTreeAlloc.del(tree); + mTreeAlloc.clear(); - setTimePoint("SilentMultiPprfReceiver.de-alloc"); + setTimePoint("SilentMultiPprfReceiver.de-alloc"); MC_END(); } @@ -820,21 +875,22 @@ namespace osuCrypto::Subfield mPntCount = 0; } + //treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs void expandOne( - u64 treeIdx, - bool programActivePath, - span>> levels, - span> lastLevel, - span, 2>> theirSums, - span> lastOts) + u64 treeIdx, + bool programPuncturedPoint, + span>> levels, + VecF leafLevel, + const u64 outputOffset, + span, 2>> theirSums, + span leafMsg) { - // This thread will process 8 trees at a time. + // We will process 8 trees at a time. // special case for the first level. auto l1 = levels[1]; for (u64 i = 0; i < 8; ++i) { - // For the non-active path, set the child of the root node // as the OT message XOR'ed with the correction sum. int notAi = mBaseChoices[i + treeIdx][0]; @@ -845,22 +901,21 @@ namespace osuCrypto::Subfield // space for our sums of each level. std::array, 2> mySums; - // Subfield: lastSums - std::array, 2> lastSums{}; + // this will be the value of both children of active an parent + // before the active child is updated. We will need to subtract + // this value as the main loop does not distinguish active parents. + std::array inactiveChildValues; + inactiveChildValues[0] = AES::roundEnc(mAesFixedKey.ecbEncBlock(ZeroBlock), ZeroBlock); + inactiveChildValues[1] = mAesFixedKey.ecbEncBlock(ZeroBlock); // For all other levels, expand the GGM tree and add in // the correction along the active path. - for (u64 d = 1; d < mDepth; ++d) + for (u64 d = 1; d < mDepth - 1; ++d) { - // The already constructed level. Only missing the - // GGM tree node value along the active path. - auto level0 = levels[d]; - - // The next level that we want to construct. - auto level1 = levels[d + 1]; - - // Zero out the previous sums. - memset(mySums.data(), 0, sizeof(mySums)); + // initialized the sums with inactiveChildValue so that + // it will cancel when we expand the actual inactive child. + std::fill(mySums[0].begin(), mySums[0].end(), inactiveChildValues[0]); + std::fill(mySums[1].begin(), mySums[1].end(), inactiveChildValues[1]); // We will iterate over each node on this level and // expand it into it's two children. Note that the @@ -868,172 +923,193 @@ namespace osuCrypto::Subfield // overwrite whatever the value was. This is an optimization. auto width = divCeil(mDomain, 1ull << (mDepth - d)); - // for internal nodes we the optimized approach. - if (d + 1 < mDepth && 0) + // The already constructed level. Only missing the + // GGM tree node value along the active path. + auto level0 = levels[d]; + + // The next level that we want to construct. + auto level1 = levels[d + 1]; + + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) { -// for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) -// { -// // The value of the parent. -// auto parent = level0[parentIdx]; -// -// auto& child0 = level1.data()[childIdx]; -// auto& child1 = level1.data()[childIdx + 1]; -// mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); -// -// child0[0] = child1[0] ^ parent[0]; -// child0[1] = child1[1] ^ parent[1]; -// child0[2] = child1[2] ^ parent[2]; -// child0[3] = child1[3] ^ parent[3]; -// child0[4] = child1[4] ^ parent[4]; -// child0[5] = child1[5] ^ parent[5]; -// child0[6] = child1[6] ^ parent[6]; -// child0[7] = child1[7] ^ parent[7]; -// -// // Update the running sums for this level. We keep -// // a left and right totals for each level. Note that -// // we are actually XOR in the incorrect value of the -// // children of the active parent (assuming !DEBUG_PRINT_PPRF). -// // This is ok since we will later XOR off these incorrect values. -// mySums[0][0] = mySums[0][0] ^ child0[0]; -// mySums[0][1] = mySums[0][1] ^ child0[1]; -// mySums[0][2] = mySums[0][2] ^ child0[2]; -// mySums[0][3] = mySums[0][3] ^ child0[3]; -// mySums[0][4] = mySums[0][4] ^ child0[4]; -// mySums[0][5] = mySums[0][5] ^ child0[5]; -// mySums[0][6] = mySums[0][6] ^ child0[6]; -// mySums[0][7] = mySums[0][7] ^ child0[7]; -// -// child1[0] = child1[0] + parent[0]; -// child1[1] = child1[1] + parent[1]; -// child1[2] = child1[2] + parent[2]; -// child1[3] = child1[3] + parent[3]; -// child1[4] = child1[4] + parent[4]; -// child1[5] = child1[5] + parent[5]; -// child1[6] = child1[6] + parent[6]; -// child1[7] = child1[7] + parent[7]; -// -// mySums[1][0] = mySums[1][0] ^ child1[0]; -// mySums[1][1] = mySums[1][1] ^ child1[1]; -// mySums[1][2] = mySums[1][2] ^ child1[2]; -// mySums[1][3] = mySums[1][3] ^ child1[3]; -// mySums[1][4] = mySums[1][4] ^ child1[4]; -// mySums[1][5] = mySums[1][5] ^ child1[5]; -// mySums[1][6] = mySums[1][6] ^ child1[6]; -// mySums[1][7] = mySums[1][7] ^ child1[7]; -// } + // The value of the parent. + auto parent = level0[parentIdx]; + + auto& child0 = level1.data()[childIdx]; + auto& child1 = level1.data()[childIdx + 1]; + mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); + + // inspired by the Expand Accumualte idea to + // use + // + // child0 = AES(parent) ^ parent + // child1 = AES(parent) + parent + // + // but instead we are a bit more conservative and + // compute + // + // child0 = AES:Round(AES(parent), parent) + // = AES:Round(AES(parent), 0) ^ parent + // child1 = AES(parent) + parent + // + // That is, we applies an additional AES round function + // to the first child before XORing it with parent. + child0[0] = AES::roundEnc(child1[0], parent[0]); + child0[1] = AES::roundEnc(child1[1], parent[1]); + child0[2] = AES::roundEnc(child1[2], parent[2]); + child0[3] = AES::roundEnc(child1[3], parent[3]); + child0[4] = AES::roundEnc(child1[4], parent[4]); + child0[5] = AES::roundEnc(child1[5], parent[5]); + child0[6] = AES::roundEnc(child1[6], parent[6]); + child0[7] = AES::roundEnc(child1[7], parent[7]); + + // Update the running sums for this level. We keep + // a left and right totals for each level. Note that + // we are actually XOR in the incorrect value of the + // children of the active parent but this will cancel + // with inactiveChildValue thats already there. + mySums[0][0] = mySums[0][0] ^ child0[0]; + mySums[0][1] = mySums[0][1] ^ child0[1]; + mySums[0][2] = mySums[0][2] ^ child0[2]; + mySums[0][3] = mySums[0][3] ^ child0[3]; + mySums[0][4] = mySums[0][4] ^ child0[4]; + mySums[0][5] = mySums[0][5] ^ child0[5]; + mySums[0][6] = mySums[0][6] ^ child0[6]; + mySums[0][7] = mySums[0][7] ^ child0[7]; + + // child1 = AES(parent) + parent + child1[0] = child1[0] + parent[0]; + child1[1] = child1[1] + parent[1]; + child1[2] = child1[2] + parent[2]; + child1[3] = child1[3] + parent[3]; + child1[4] = child1[4] + parent[4]; + child1[5] = child1[5] + parent[5]; + child1[6] = child1[6] + parent[6]; + child1[7] = child1[7] + parent[7]; + + mySums[1][0] = mySums[1][0] ^ child1[0]; + mySums[1][1] = mySums[1][1] ^ child1[1]; + mySums[1][2] = mySums[1][2] ^ child1[2]; + mySums[1][3] = mySums[1][3] ^ child1[3]; + mySums[1][4] = mySums[1][4] ^ child1[4]; + mySums[1][5] = mySums[1][5] ^ child1[5]; + mySums[1][6] = mySums[1][6] ^ child1[6]; + mySums[1][7] = mySums[1][7] ^ child1[7]; } - else - { - // for leaf nodes both children should be hashed. - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) - { - // The value of the parent. - auto parent = level0[parentIdx]; - for (u64 keep = 0; keep < 2; ++keep, ++childIdx) - { - // The child that we will write in this iteration. - auto& child = level1[childIdx]; - - // Each parent is expanded into the left and right children - // using a different AES fixed-key. Therefore our OWF is: - // - // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); - // - // where each half defines one of the children. - gAes[keep].hashBlocks<8>(parent.data(), child.data()); - - // Subfield: - if (d == mDepth - 1) { - if (lastLevel.size() <= childIdx) { - // todo: I have fix in my old code, not sure we need this for the new pprf - throw RTE_LOC; - } - auto& realChild = lastLevel[childIdx]; - auto& lastSum = lastSums[keep]; - realChild[0] = TypeTrait::fromBlock(child[0]); - lastSum[0] = TypeTrait::plus(lastSum[0], realChild[0]); - realChild[1] = TypeTrait::fromBlock(child[1]); - lastSum[1] = TypeTrait::plus(lastSum[1], realChild[1]); - realChild[2] = TypeTrait::fromBlock(child[2]); - lastSum[2] = TypeTrait::plus(lastSum[2], realChild[2]); - realChild[3] = TypeTrait::fromBlock(child[3]); - lastSum[3] = TypeTrait::plus(lastSum[3], realChild[3]); - realChild[4] = TypeTrait::fromBlock(child[4]); - lastSum[4] = TypeTrait::plus(lastSum[4], realChild[4]); - realChild[5] = TypeTrait::fromBlock(child[5]); - lastSum[5] = TypeTrait::plus(lastSum[5], realChild[5]); - realChild[6] = TypeTrait::fromBlock(child[6]); - lastSum[6] = TypeTrait::plus(lastSum[6], realChild[6]); - realChild[7] = TypeTrait::fromBlock(child[7]); - lastSum[7] = TypeTrait::plus(lastSum[7], realChild[7]); - } else { - // Update the running sums for this level. We keep - // a left and right totals for each level. Note that - // we are actually XOR in the incorrect value of the - // children of the active parent (assuming !DEBUG_PRINT_PPRF). - // This is ok since we will later XOR off these incorrect values. - auto& sum = mySums[keep]; - sum[0] = sum[0] ^ child[0]; - sum[1] = sum[1] ^ child[1]; - sum[2] = sum[2] ^ child[2]; - sum[3] = sum[3] ^ child[3]; - sum[4] = sum[4] ^ child[4]; - sum[5] = sum[5] ^ child[5]; - sum[6] = sum[6] ^ child[6]; - sum[7] = sum[7] ^ child[7]; - } - } - } - } - // For everything but the last level we have to - // 1) fix our sums so they dont include the incorrect - // values that are the children of the active parent - // 2) Update the non-active child of the active parent. - if (!programActivePath || d != mDepth - 1) + // we have to update the non-active child of the active parent. + for (u64 i = 0; i < 8; ++i) { - for (u64 i = 0; i < 8; ++i) - { - // the index of the leaf node that is active. - auto leafIdx = mPoints[i + treeIdx]; + // the index of the leaf node that is active. + auto leafIdx = mPoints[i + treeIdx]; - // The index of the active child node. - auto activeChildIdx = leafIdx >> (mDepth - 1 - d); + // The index of the active child node. + auto activeChildIdx = leafIdx >> (mDepth - 1 - d); - // The index of the active child node sibling. - auto inactiveChildIdx = activeChildIdx ^ 1; + // The index of the active child node sibling. + auto inactiveChildIdx = activeChildIdx ^ 1; - // The indicator as to the left or right child is inactive - auto notAi = inactiveChildIdx & 1; + // The indicator as to the left or right child is inactive + auto notAi = inactiveChildIdx & 1; + + // our sums & OTs cancel and we are leaf with the + // correct value for the inactive child. + level1[inactiveChildIdx][i] = + theirSums[d][notAi][i] ^ + mySums[notAi][i] ^ + mBaseOTs[i + treeIdx][d]; + + // we have to set the active child to zero so + // the next children are predictable. + level1[activeChildIdx][i] = ZeroBlock; + } + } - auto& inactiveChild = level1[inactiveChildIdx][i]; - // correct the sum value by XORing off the incorrect - auto correctSum = - inactiveChild ^ - theirSums[d][notAi][i]; + auto d = mDepth - 1; + // The already constructed level. Only missing the + // GGM tree node value along the active path. + auto level0 = levels[d]; - inactiveChild = - correctSum ^ - mySums[notAi][i] ^ - mBaseOTs[i + treeIdx][d]; + // The next level of theGGM tree that we are populating. + std::array child; - } + // We will iterate over each node on this level and + // expand it into it's two children. Note that the + // active node will also be expanded. Later we will just + // overwrite whatever the value was. This is an optimization. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); + + // We change the hash function for the leaf so lets update + // inactiveChildValues to use the new hash and subtract + // these from the leafSums + CoeffCtx::template Vec temp; + CoeffCtx::resize(temp, 2); + std::array, 2> leafSums; + for (u64 k = 0; k < 2; ++k) + { + inactiveChildValues[k] = gGgmAes[k].hashBlock(ZeroBlock); + CoeffCtx::fromBlock(temp[k], inactiveChildValues[k]); + + // leafSum = -inactiveChildValues + CoeffCtx::resize(leafSums[k], 8); + CoeffCtx::zero(leafSums[k].begin(), leafSums[k].end()); + CoeffCtx::minus(leafSums[k][0], leafSums[k][0], temp[0]); + for (u64 i = 1; i < 8; ++i) + CoeffCtx::copy(leafSums[k][i], leafSums[k][0]); + } + + // for leaf nodes both children should be hashed. + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) + { + // The value of the parent. + auto parent = level0[parentIdx]; + + for (u64 keep = 0, outputIdx = outputOffset; keep < 2; ++keep, ++childIdx, outputIdx += 8) + { + // Each parent is expanded into the left and right children + // using a different AES fixed-key. Therefore our OWF is: + // + // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); + // + // where each half defines one of the children. + gGgmAes[keep].hashBlocks<8>(parent.data(), child.data()); + + CoeffCtx::fromBlock(leafLevel[outputIdx + 0], child[0]); + CoeffCtx::fromBlock(leafLevel[outputIdx + 1], child[1]); + CoeffCtx::fromBlock(leafLevel[outputIdx + 2], child[2]); + CoeffCtx::fromBlock(leafLevel[outputIdx + 3], child[3]); + CoeffCtx::fromBlock(leafLevel[outputIdx + 4], child[4]); + CoeffCtx::fromBlock(leafLevel[outputIdx + 5], child[5]); + CoeffCtx::fromBlock(leafLevel[outputIdx + 6], child[6]); + CoeffCtx::fromBlock(leafLevel[outputIdx + 7], child[7]); + + auto& leafSum = leafSums[keep]; + CoeffCtx::plus(leafSum[0], leafSum[0], leafLevel[outputIdx + 0]); + CoeffCtx::plus(leafSum[1], leafSum[1], leafLevel[outputIdx + 1]); + CoeffCtx::plus(leafSum[2], leafSum[2], leafLevel[outputIdx + 2]); + CoeffCtx::plus(leafSum[3], leafSum[3], leafLevel[outputIdx + 3]); + CoeffCtx::plus(leafSum[4], leafSum[4], leafLevel[outputIdx + 4]); + CoeffCtx::plus(leafSum[5], leafSum[5], leafLevel[outputIdx + 5]); + CoeffCtx::plus(leafSum[6], leafSum[6], leafLevel[outputIdx + 6]); + CoeffCtx::plus(leafSum[7], leafSum[7], leafLevel[outputIdx + 7]); } } - // last level. - if (programActivePath) + // leaf level. + if (programPuncturedPoint) { - // Now processes the last level. This one is special + // Now processes the leaf level. This one is special // because we must XOR in the correction value as // before but we must also fixed the child value for // the active child. To do this, we will receive 4 // values. Two for each case (left active or right active). - //timer.setTimePoint("recv.recvLast"); + //timer.setTimePoint("recv.recvleaf"); + VecF leafOts; + CoeffCtx::resize(leafOts, 2); + PRNG otMasker; - auto d = mDepth - 1; for (u64 j = 0; j < 8; ++j) { // The index of the child on the active path. @@ -1045,46 +1121,64 @@ namespace osuCrypto::Subfield // The indicator as to the left or right child is inactive auto notAi = inactiveChildIdx & 1; - std::array masks, maskIn; - - // We are going to expand the 128 bit OT string - // into a 256 bit OT string using AES. - maskIn[0] = mBaseOTs[j + treeIdx][d]; - maskIn[1] = mBaseOTs[j + treeIdx][d] ^ AllOneBlock; - mAesFixedKey.hashBlocks<2>(maskIn.data(), masks.data()); - - // now get the chosen message OT strings by XORing - // the expended (random) OT strings with the lastOts values. - auto& ot0 = lastOts[j][2 * notAi + 0]; - auto& ot1 = lastOts[j][2 * notAi + 1]; - ot0 = TypeTrait::minus(ot0, TypeTrait::fromBlock(masks[0])); - ot1 = TypeTrait::minus(ot1, TypeTrait::fromBlock(masks[1])); - - auto& inactiveChild = lastLevel[inactiveChildIdx][j]; - auto& activeChild = lastLevel[activeChildIdx][j]; - - // Fix the sums we computed previously to not include the - // incorrect child values. - auto inactiveSum = TypeTrait::minus(lastSums[notAi][j], inactiveChild); - auto activeSum = TypeTrait::minus(lastSums[notAi ^ 1][j], activeChild); - - // Update the inactive and active child to have to correct - // value by XORing their full sum with out partial sum, which - // gives us exactly the value we are missing. - inactiveChild = TypeTrait::minus(ot0, inactiveSum); - activeChild = TypeTrait::minus(ot1, activeSum); - } - // pprf.setTimePoint("SilentMultiPprfReceiver.last " + std::to_string(treeIdx)); + // offset to the first or second ot message, based on the one we want + auto offset = CoeffCtx::template byteSize() * 2 * notAi; + + // decrypt the ot string + span buff = leafMsg.subspan(offset, CoeffCtx::byteSize() * 2); + leafMsg = leafMsg.subspan(buff.size() * 2); + otMasker.SetSeed(mBaseOTs[j + treeIdx][d], divCeil(buff.size(), sizeof(block))); + for (u64 i = 0; i < buff.size(); ++i) + buff[i] ^= otMasker.get(); + + CoeffCtx::deserialize(leafOts, buff); - //timer.setTimePoint("recv.expandLast"); + auto out0 = (activeChildIdx & ~1ull) * 8 + j + outputOffset; + auto out1 = (activeChildIdx | 1ull) * 8 + j + outputOffset; + + CoeffCtx::minus(leafLevel[out0], leafOts[0], leafSums[0][j]); + CoeffCtx::minus(leafLevel[out1], leafOts[1], leafSums[1][j]); + } } else { - for (auto j : rng(std::min(8, mPntCount - treeIdx))) + VecF leafOts; + CoeffCtx::resize(leafOts, 1); + PRNG otMasker; + + for (u64 j = 0; j < 8; ++j) { // The index of the child on the active path. auto activeChildIdx = mPoints[j + treeIdx]; - lastLevel[activeChildIdx][j] = F{}; + + // The index of the other (inactive) child. + auto inactiveChildIdx = activeChildIdx ^ 1; + + // The indicator as to the left or right child is inactive + auto notAi = inactiveChildIdx & 1; + + // offset to the first or second ot message, based on the one we want + auto offset = CoeffCtx::template byteSize() * notAi; + + // decrypt the ot string + span buff = leafMsg.subspan(offset, CoeffCtx::byteSize()); + leafMsg = leafMsg.subspan(buff.size() * 2); + otMasker.SetSeed(mBaseOTs[j + treeIdx][d], divCeil(buff.size(), sizeof(block))); + for (u64 i = 0; i < buff.size(); ++i) + buff[i] ^= otMasker.get(); + + CoeffCtx::deserialize(leafOts, buff); + + std::array out{ + (activeChildIdx & ~1ull) * 8 + j + outputOffset, + (activeChildIdx | 1ull) * 8 + j + outputOffset + }; + + auto keep = leafLevel.begin() + out[notAi]; + auto zero = leafLevel.begin() + out[notAi^1]; + + CoeffCtx::minus(*keep, leafOts[0], leafSums[notAi][j]); + CoeffCtx::zero(zero, zero + 1); } } } diff --git a/libOTe/TwoChooseOne/ConfigureCode.cpp b/libOTe/TwoChooseOne/ConfigureCode.cpp index c85019e7..7b84d6bf 100644 --- a/libOTe/TwoChooseOne/ConfigureCode.cpp +++ b/libOTe/TwoChooseOne/ConfigureCode.cpp @@ -8,6 +8,7 @@ #include "libOTe/Tools/QuasiCyclicCode.h" #include "libOTe/Tools/EACode/EACode.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" +#include "libOTe/Tools/ExConvCode/ExConvCode2.h" #include namespace osuCrypto { @@ -142,6 +143,47 @@ namespace osuCrypto mEncoder.config(numOTs, numOTs * mScaler, w, a, true); } + + void ExConvConfigure( + u64 numOTs, u64 secParam, + MultType mMultType, + u64& mRequestedNumOTs, + u64& mNumPartitions, + u64& mSizePer, + u64& mN2, + u64& mN, + ExConvCode2& mEncoder + ) + { + u64 a = 24; + auto mScaler = 2; + u64 w; + double minDist; + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + w = 7; + minDist = 0.1; + break; + case osuCrypto::MultType::ExConv21x24: + w = 21; + minDist = 0.15; + break; + default: + throw RTE_LOC; + break; + } + + mRequestedNumOTs = numOTs; + mNumPartitions = getRegNoiseWeight(minDist, secParam); + mSizePer = roundUpTo((numOTs * mScaler + mNumPartitions - 1) / mNumPartitions, 8); + mN2 = mSizePer * mNumPartitions; + mN = mN2 / mScaler; + + mEncoder.config(numOTs, numOTs * mScaler, w, a, true); + } + + #ifdef ENABLE_INSECURE_SILVER void SilverConfigure( diff --git a/libOTe/TwoChooseOne/ConfigureCode.h b/libOTe/TwoChooseOne/ConfigureCode.h index 47cbaa53..4af551da 100644 --- a/libOTe/TwoChooseOne/ConfigureCode.h +++ b/libOTe/TwoChooseOne/ConfigureCode.h @@ -105,6 +105,19 @@ namespace osuCrypto ExConvCode& mEncoder ); + + class ExConvCode2; + void ExConvConfigure( + u64 numOTs, u64 secParam, + MultType mMultType, + u64& mRequestedNumOTs, + u64& mNumPartitions, + u64& mSizePer, + u64& mN2, + u64& mN, + ExConvCode2& mEncoder + ); + #ifdef ENABLE_INSECURE_SILVER struct SilverEncoder; void SilverConfigure( diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp index 3107f0db..24b890e3 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp @@ -13,7 +13,7 @@ #include #include #include "libOTe/Tools/QuasiCyclicCode.h" - +#include "libOTe/Tools/Subfield/Subfield.h" namespace osuCrypto { @@ -747,13 +747,6 @@ namespace osuCrypto } else { - // allocate and initialize mC - //if (mChoiceSpanSize < mN2) - //{ - // mChoiceSpanSize = mN2; - // mChoicePtr.reset((new u8[mN2]())); - //} - //else mC.resize(mN2); std::memset(mC.data(), 0, mN2); auto cc = mC.data(); diff --git a/libOTe/Vole/Silent/SilentVoleSender.cpp b/libOTe/Vole/Silent/SilentVoleSender.cpp index ef107a62..a78d6736 100644 --- a/libOTe/Vole/Silent/SilentVoleSender.cpp +++ b/libOTe/Vole/Silent/SilentVoleSender.cpp @@ -132,7 +132,7 @@ namespace osuCrypto if (isConfigured() == false) throw std::runtime_error("configure must be called first"); - return mGen.baseOtCount() + mGapOts.size(); + return mGen.baseOtCount(); } void SilentVoleSender::setSilentBaseOts( @@ -145,11 +145,7 @@ namespace osuCrypto if (noiseDeltaShares.size() != baseVoleCount()) throw RTE_LOC; - auto genOt = sendBaseOts.subspan(0, mGen.baseOtCount()); - auto gapOt = sendBaseOts.subspan(genOt.size(), mGapOts.size()); - - mGen.setBase(genOt); - std::copy(gapOt.begin(), gapOt.end(), mGapOts.begin()); + mGen.setBase(sendBaseOts); mNoiseDeltaShares.resize(noiseDeltaShares.size()); std::copy(noiseDeltaShares.begin(), noiseDeltaShares.end(), mNoiseDeltaShares.begin()); } @@ -160,7 +156,6 @@ namespace osuCrypto u64 secParam) { mBaseType = type; - u64 gap = 0; switch (mMultType) { @@ -186,21 +181,6 @@ namespace osuCrypto #endif break; } -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - - SilverConfigure(numOTs, secParam, - mMultType, - mRequestedNumOTs, - mNumPartitions, - mSizePer, - mN2, - mN, - gap, - mEncoder); - break; -#endif case osuCrypto::MultType::ExAcc7: case osuCrypto::MultType::ExAcc11: case osuCrypto::MultType::ExAcc21: @@ -225,7 +205,6 @@ namespace osuCrypto break; } - mGapOts.resize(gap); mGen.configure(mSizePer, mNumPartitions); mState = State::Configured; @@ -299,7 +278,6 @@ namespace osuCrypto Socket& chl) { MC_BEGIN(task<>,this, delta, n, &prng, &chl, - gapVals = std::vector{}, deltaShare = block{}, X = block{}, hash = std::array{}, @@ -339,19 +317,6 @@ namespace osuCrypto mB.resize(0); mB.resize(mN2); - // derandomize the random OTs for the gap - // to have the desired correlation. - gapVals.resize(mGapOts.size()); - for (u64 i = mNumPartitions * mSizePer, j = 0; i < mN2; ++i, ++j) - { - auto v = mGapOts[j][0] ^ mNoiseDeltaShares[mNumPartitions + j]; - gapVals[j] = AES(mGapOts[j][1]).ecbEncBlock(ZeroBlock) ^ v; - mB[i] = mGapOts[j][0]; - } - - if(gapVals.size()) - MC_AWAIT(chl.send(std::move(gapVals))); - if (mTimer) mGen.setTimer(*mTimer); diff --git a/libOTe/Vole/Subfield/NoisyVoleReceiver.h b/libOTe/Vole/Subfield/NoisyVoleReceiver.h index 5cdeb35c..4d739bdf 100644 --- a/libOTe/Vole/Subfield/NoisyVoleReceiver.h +++ b/libOTe/Vole/Subfield/NoisyVoleReceiver.h @@ -18,12 +18,6 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. -// This code implements features described in [Silver: Silent VOLE and Oblivious -// Transfer from Hardness of Decoding Structured LDPC Codes, -// https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative -// Commons Attribution 4.0 International Public License -// (https://creativecommons.org/licenses/by/4.0/legalcode). - #include #if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) @@ -32,18 +26,25 @@ #include "cryptoTools/Crypto/PRNG.h" #include "libOTe/Tools/Coproto.h" #include "libOTe/TwoChooseOne/OTExtInterface.h" +#include "libOTe/Tools/Subfield/Subfield.h" -namespace osuCrypto::Subfield { +namespace osuCrypto { - template - class NoisySubfieldVoleReceiver : public TimerAdapter { + template < + typename F, + typename G = F, + typename CoeffCtx = DefaultCoeffCtx + > + class NoisySubfieldVoleReceiver : public TimerAdapter + { public: - using F = typename TypeTrait::F; - using G = typename TypeTrait::G; - task<> receive(span y, span z, PRNG& prng, - OtSender& ot, Socket& chl) { + + template + task<> receive(VecG&& y, VecF&& z, PRNG& prng, + OtSender& ot, Socket& chl) + { MC_BEGIN(task<>, this, y, z, &prng, &ot, &chl, - otMsg = AlignedUnVector>{ TypeTrait::bitsF }); + otMsg = AlignedUnVector>{}); setTimePoint("NoisyVoleReceiver.ot.begin"); @@ -56,44 +57,69 @@ namespace osuCrypto::Subfield { MC_END(); } - task<> receive(span y, span z, PRNG& _, + template + task<> receive(VecG&& y, VecF&& z, PRNG& _, span> otMsg, - Socket& chl) { - MC_BEGIN(task<>, this, y, z, otMsg, &chl, - msg = Matrix{}, + Socket& chl) + { + MC_BEGIN(task<>, this, y, z, otMsg, &chl, + buff = std::vector{}, + msg = typename CoeffCtx::Vec{}, + temp = typename CoeffCtx::Vec{}, prng = std::move(PRNG{}) ); - if (otMsg.size() != TypeTrait::bitsF) throw RTE_LOC; - if (y.size() != z.size()) throw RTE_LOC; - if (z.size() == 0) throw RTE_LOC; + if (y.size() != z.size()) + throw RTE_LOC; + if (z.size() == 0) + throw RTE_LOC; setTimePoint("NoisyVoleReceiver.begin"); - memset(z.data(), 0, TypeTrait::bytesF * z.size()); - msg.resize(otMsg.size(), z.size(), AllocType::Uninitialized); + CoeffCtx::zero(z.begin(), z.end()); + CoeffCtx::resize(msg, otMsg.size() * z.size()); + CoeffCtx::resize(temp, 2); + + for (size_t i = 0, k = 0; i < otMsg.size(); ++i) + { + prng.SetSeed(otMsg[i][0], z.size()); + + // t1 = 2^i + CoeffCtx::pow(temp[1], i); - for (size_t ii = 0; ii < TypeTrait::bitsF; ++ii) { - prng.SetSeed(otMsg[ii][0], z.size()); - auto& buffer = prng.mBuffer; - auto pow = TypeTrait::pow(ii); - for (size_t j = 0; j < y.size(); ++j) { - auto bufj = TypeTrait::fromBlock(buffer[j]); - z[j] = TypeTrait::plus(z[j], bufj); - F yy = TypeTrait::mul(pow, y[j]); + for (size_t j = 0; j < y.size(); ++j, ++k) + { + // msg[i,j] = otMsg[i,j,0] + CoeffCtx::fromBlock(msg[k], prng.get()); - msg(ii, j) = TypeTrait::plus(yy, bufj); + // z[j] -= otMsg[i,j,0] + CoeffCtx::minus(z[j], z[j], msg[k]); + + // temp = 2^i * y[j] + CoeffCtx::mul(temp[0], temp[1], y[j]); + + // msg[i,j] = otMsg[i,j,0] + 2^i * y[j] + CoeffCtx::plus(msg[k], msg[k], temp[0]); } - prng.SetSeed(otMsg[ii][1], z.size()); + k -= y.size(); + prng.SetSeed(otMsg[i][1], z.size()); + + for (size_t j = 0; j < y.size(); ++j, ++k) + { + // temp = otMsg[i,j,1] + CoeffCtx::fromBlock(temp[0], prng.get()); - for (size_t j = 0; j < y.size(); ++j) { // enc one message under the OT msg. - msg(ii, j) = TypeTrait::plus(msg(ii, j), TypeTrait::fromBlock(prng.mBuffer[j])); + // msg[i,j] = (otMsg[i,j,0] + 2^i * y[j]) - otMsg[i,j,1] + CoeffCtx::minus(msg[k], msg[k], temp[0]); } } - MC_AWAIT(chl.send(std::move(msg))); + buff.resize(msg.size() * CoeffCtx::byteSize()); + CoeffCtx::serialize(buff, msg); + + MC_AWAIT(chl.send(std::move(buff))); setTimePoint("NoisyVoleReceiver.done"); MC_END(); diff --git a/libOTe/Vole/Subfield/NoisyVoleSender.h b/libOTe/Vole/Subfield/NoisyVoleSender.h index 7de8b989..81aede7e 100644 --- a/libOTe/Vole/Subfield/NoisyVoleSender.h +++ b/libOTe/Vole/Subfield/NoisyVoleSender.h @@ -33,18 +33,26 @@ #include "cryptoTools/Crypto/PRNG.h" #include "libOTe/Tools/Coproto.h" #include "libOTe/TwoChooseOne/OTExtInterface.h" +#include "libOTe/Tools/Subfield/Subfield.h" + +namespace osuCrypto { + template < + typename F, + typename G = F, + typename CoeffCtx = DefaultCoeffCtx + > + class NoisySubfieldVoleSender : public TimerAdapter + { -namespace osuCrypto::Subfield { - template - class NoisySubfieldVoleSender : public TimerAdapter { public: - using F = typename TypeTrait::F; - using G = typename TypeTrait::G; - task<> send(F x, span z, PRNG& prng, + + template + task<> send(F x, FVec&& z, PRNG& prng, OtReceiver& ot, Socket& chl) { MC_BEGIN(task<>, this, x, z, &prng, &ot, &chl, - bv = TypeTrait::BitVectorF(x), - otMsg = AlignedUnVector{ TypeTrait::bitsF }); + bv = CoeffCtx::binaryDecomposition(x), + otMsg = AlignedUnVector{ }); + otMsg.resize(bv.size()); setTimePoint("NoisyVoleSender.ot.begin"); @@ -56,34 +64,56 @@ namespace osuCrypto::Subfield { MC_END(); } - task<> send(F x, span z, PRNG& _, + template + task<> send(F x, FVec&& z, PRNG& _, span otMsg, Socket& chl) { MC_BEGIN(task<>, this, x, z, otMsg, &chl, prng = std::move(PRNG{}), - msg = Matrix{}, + buffer = std::vector{}, + msg = typename CoeffCtx::Vec{}, + temp = typename CoeffCtx::Vec{}, xb = BitVector{}); - if (otMsg.size() != TypeTrait::bitsF) + xb = CoeffCtx::binaryDecomposition(x); + + if (otMsg.size() != xb.size()) throw RTE_LOC; setTimePoint("NoisyVoleSender.main"); - memset(z.data(), 0, TypeTrait::bytesF * z.size()); - msg.resize(otMsg.size(), z.size(), AllocType::Uninitialized); + // z = 0; + CoeffCtx::zero(z.begin(), z.end()); - MC_AWAIT(chl.recv(msg)); + // receive the the excrypted one shares. + buffer.resize(otMsg.size() * z.size() * CoeffCtx::byteSize()); + MC_AWAIT(chl.recv(buffer)); + CoeffCtx::deserialize(msg, buffer); setTimePoint("NoisyVoleSender.recvMsg"); - xb = TypeTrait::BitVectorF(x); - for (size_t i = 0; i < TypeTrait::bitsF; ++i) + temp.resize(1); + for (size_t i = 0, k = 0; i < xb.size(); ++i) { + // expand the zero shares or one share masks prng.SetSeed(otMsg[i], z.size()); - for (u64 j = 0; j < (u64)z.size(); ++j) + // otMsg[i,j, bc[i]] + //auto otMsgi = prng.getBufferSpan(z.size()); + + for (u64 j = 0; j < (u64)z.size(); ++j, ++k) { - F bufj = TypeTrait::fromBlock(prng.mBuffer[j]); - F data = xb[i] ? TypeTrait::minus(msg(i, j), bufj) : bufj; - z[j] = TypeTrait::plus(z[j], data); + // temp = otMsg[i,j, xb[i]] + CoeffCtx::fromBlock(temp[0], prng.get()); + + // temp = otMsg[i,j,xb[i]] + xb[i] * msg[i,j] + // = otMsg[i,j,xb[i]] + xb[i] * (otMsg[i,j,0] + 2^i * y[j] - otMsg[i,j,1]) + // = otMsg[i,j,xb[i]] // if 0 + // = otMsg[i,j,0] + 2^i * y[j] // if 1 + // = -z + 2^i * y[j] // if 1 + if (xb[i]) + CoeffCtx::plus(temp[0], msg[k], temp[0]); + + // zj += msg - xb[i] * otMsg[i,j] + CoeffCtx::plus(z[j], z[j], temp[0]); } } setTimePoint("NoisyVoleSender.done"); diff --git a/libOTe/Vole/Subfield/SilentVoleReceiver.h b/libOTe/Vole/Subfield/SilentVoleReceiver.h index c6497b3f..8fa52168 100644 --- a/libOTe/Vole/Subfield/SilentVoleReceiver.h +++ b/libOTe/Vole/Subfield/SilentVoleReceiver.h @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include #include @@ -29,7 +29,11 @@ namespace osuCrypto::Subfield { - template> + template< + typename F, + typename G = F, + typename CoeffCtx = DefaultCoeffCtx + > class SilentSubfieldVoleReceiver : public TimerAdapter { public: @@ -42,6 +46,9 @@ namespace osuCrypto::Subfield HasBase }; + using VecF = typename CoeffCtx::template Vec; + using VecG = typename CoeffCtx::template Vec; + // The current state of the protocol State mState = State::Default; @@ -51,7 +58,7 @@ namespace osuCrypto::Subfield // The number of OTs actually produced (at least the number requested). u64 mN = 0; - // The length of the noisy vectors (2 * mN for the silver codes). + // The length of the noisy vectors (2 * mN for the most codes). u64 mN2 = 0; // We perform regular LPN, so this is the @@ -70,33 +77,31 @@ namespace osuCrypto::Subfield // the sparse vector. MultType mMultType = DefaultMultType; - ExConvCode mExConvEncoder; + ExConvCode2 mExConvEncoder; // The multi-point punctured PRF for generating // the sparse vectors. - SilentSubfieldPprfReceiver mGen; + SilentSubfieldPprfReceiver mGen; // The internal buffers for holding the expanded vectors. // mA + mB = mC * delta - AlignedUnVector mA; + VecF mA; // mA + mB = mC * delta - AlignedUnVector mC; - - std::vector mGapOts; + VecG mC; u64 mNumThreads = 1; bool mDebug = false; - BitVector mIknpSendBaseChoice, mGapBaseChoice; + BitVector mIknpSendBaseChoice; SilentSecType mMalType = SilentSecType::SemiHonest; block mMalCheckSeed, mMalCheckX, mDeltaShare; - AlignedVector mNoiseDeltaShare; - AlignedVector mNoiseValues; + VecF mNoiseDeltaShare; + VecG mNoiseValues; #ifdef ENABLE_SOFTSPOKEN_OT @@ -113,7 +118,7 @@ namespace osuCrypto::Subfield u64 baseVoleCount() const { - return mNumPartitions + mGapOts.size() + 1 * (mMalType == SilentSecType::Malicious); + return mNumPartitions + 1 * (mMalType == SilentSecType::Malicious); } // // returns true if the IKNP base OTs are currently set. @@ -144,9 +149,9 @@ namespace osuCrypto::Subfield baseOt = BaseOT{}, chl2 = Socket{}, prng2 = std::move(PRNG{}), - noiseVals = std::vector{}, - noiseDeltaShares = std::vector{}, - nv = NoisySubfieldVoleReceiver{} + noiseVals = VecG{}, + noiseDeltaShares = VecF{}, + nv = NoisySubfieldVoleReceiver{} ); @@ -172,7 +177,8 @@ namespace osuCrypto::Subfield // other party will program the PPRF to output their share of delta * noiseVals. // noiseVals = sampleBaseVoleVals(prng); - noiseDeltaShares.resize(noiseVals.size()); + CoeffCtx::resize(noiseDeltaShares, noiseVals.size()); + if (mTimer) nv.setTimer(*mTimer); @@ -237,7 +243,6 @@ namespace osuCrypto::Subfield u64 secParam = 128) { mState = State::Configured; - u64 gap = 0; mBaseType = type; switch (mMultType) @@ -245,14 +250,13 @@ namespace osuCrypto::Subfield case osuCrypto::MultType::ExConv7x24: case osuCrypto::MultType::ExConv21x24: - SubfieldExConvConfigure(numOTs, 128, mMultType, mRequestedNumOTs, mNumPartitions, mSizePer, mN2, mN, mExConvEncoder); + ExConvConfigure(numOTs, 128, mMultType, mRequestedNumOTs, mNumPartitions, mSizePer, mN2, mN, mExConvEncoder); break; default: throw RTE_LOC; break; } - mGapOts.resize(gap); mGen.configure(mSizePer, mNumPartitions); } @@ -266,7 +270,7 @@ namespace osuCrypto::Subfield if (isConfigured() == false) throw std::runtime_error("configure must be called first"); - return mGen.baseOtCount() + mGapOts.size(); + return mGen.baseOtCount(); } @@ -281,43 +285,27 @@ namespace osuCrypto::Subfield auto choice = mGen.sampleChoiceBits(mN2, getPprfFormat(), prng); - mGapBaseChoice.resize(mGapOts.size()); - mGapBaseChoice.randomize(prng); - choice.append(mGapBaseChoice); - return choice; } - std::vector sampleBaseVoleVals(PRNG& prng) + VecG sampleBaseVoleVals(PRNG& prng) { if (isConfigured() == false) throw RTE_LOC; - if (mGapBaseChoice.size() != mGapOts.size()) - throw std::runtime_error("sampleBaseChoiceBits must be called before sampleBaseVoleVals. " LOCATION); // sample the values of the noisy coordinate of c // and perform a noicy vole to get x+y = mD * c - auto w = mNumPartitions + mGapOts.size(); + auto w = mNumPartitions; std::vector seeds(w); - mNoiseValues.resize(w); + CoeffCtx::resize(mNoiseValues, w); prng.get(seeds.data(), seeds.size()); for (size_t i = 0; i < w; i++) { - mNoiseValues[i] = TypeTrait::fromBlockG(seeds[i]); + CoeffCtx::fromBlock(mNoiseValues[i], seeds[i]); } mS.resize(mNumPartitions); mGen.getPoints(mS, getPprfFormat()); - auto j = mNumPartitions * mSizePer; - - for (u64 i = 0; i < (u64)mGapBaseChoice.size(); ++i) - { - if (mGapBaseChoice[i]) - { - mS.push_back(j + i); - } - } - // if (mMalType == SilentSecType::Malicious) // { // @@ -352,7 +340,7 @@ namespace osuCrypto::Subfield // return y; // } - return std::vector(mNoiseValues.begin(), mNoiseValues.end()); + return mNoiseValues; } // Set the externally generated base OTs. This choice @@ -366,11 +354,7 @@ namespace osuCrypto::Subfield if (static_cast(recvBaseOts.size()) != silentBaseOtCount()) throw std::runtime_error("wrong number of silent base OTs"); - auto genOts = recvBaseOts.subspan(0, mGen.baseOtCount()); - auto gapOts = recvBaseOts.subspan(mGen.baseOtCount(), mGapOts.size()); - - mGen.setBase(genOts); - std::copy(gapOts.begin(), gapOts.end(), mGapOts.begin()); + mGen.setBase(recvBaseOts); // if (mMalType == SilentSecType::Malicious) // { @@ -378,7 +362,8 @@ namespace osuCrypto::Subfield // noiseDeltaShare = noiseDeltaShare.subspan(0, noiseDeltaShare.size() - 1); // } - mNoiseDeltaShare = AlignedVector(noiseDeltaShare.begin(), noiseDeltaShare.end()); + CoeffCtx::resize(mNoiseDeltaShare, noiseDeltaShare.size()); + CoeffCtx::copy(noiseDeltaShare.begin(), noiseDeltaShare.end(), mNoiseDeltaShare.begin()); mState = State::HasBase; } @@ -389,18 +374,19 @@ namespace osuCrypto::Subfield // the silent base OTs will automatically be performed. task<> silentReceive( span c, - span b, + span a, PRNG& prng, Socket& chl) { - MC_BEGIN(task<>, this, c, b, &prng, &chl); - if (c.size() != b.size()) + MC_BEGIN(task<>, this, c, a, &prng, &chl); + if (c.size() != a.size()) throw RTE_LOC; MC_AWAIT(silentReceiveInplace(c.size(), prng, chl)); - std::memcpy(c.data(), mC.data(), c.size() * TypeTrait::bytesG); - std::memcpy(b.data(), mA.data(), b.size() * TypeTrait::bytesF); + CoeffCtx::copy(mC.begin(), mC.begin() + c.size(), c.begin()); + CoeffCtx::copy(mA.begin(), mA.begin() + a.size(), a.begin()); + clear(); MC_END(); } @@ -415,7 +401,6 @@ namespace osuCrypto::Subfield Socket& chl) { MC_BEGIN(task<>, this, n, &prng, &chl, - gapVals = std::vector{}, myHash = std::array{}, theirHash = std::array{} ); @@ -437,52 +422,21 @@ namespace osuCrypto::Subfield } // allocate mA - mA.resize(0); - mA.resize(mN2); + CoeffCtx::resize(mA, 0); + CoeffCtx::resize(mA, mN2); setTimePoint("SilentVoleReceiver.alloc"); // allocate the space for mC - mC.resize(0); - mC.resize(mN2, AllocType::Zeroed); + CoeffCtx::resize(mC, 0); + CoeffCtx::resize(mC, mN2); + CoeffCtx::zero(mC.begin(), mC.end()); setTimePoint("SilentVoleReceiver.alloc.zero"); - // derandomize the random OTs for the gap - // to have the desired correlation. - gapVals.resize(mGapOts.size()); - - if (gapVals.size()) - MC_AWAIT(chl.recv(gapVals)); - - for (auto g : rng(mGapOts.size())) - { - auto aa = mA.subspan(mNumPartitions * mSizePer); - auto cc = mC.subspan(mNumPartitions * mSizePer); - - auto noise = mNoiseValues.subspan(mNumPartitions); - auto noiseShares = mNoiseDeltaShare.subspan(mNumPartitions); - - if (mGapBaseChoice[g]) - { - cc[g] = noise[g]; - aa[g] = TypeTrait::minus( - TypeTrait::minus(gapVals[g], TypeTrait::fromBlock(AES(mGapOts[g]).ecbEncBlock(ZeroBlock))), - noiseShares[g]); - } - else - { - aa[g] = TypeTrait::fromBlock(mGapOts[g]); - } - } - - setTimePoint("SilentVoleReceiver.recvGap"); - - - if (mTimer) mGen.setTimer(*mTimer); // expand the seeds into mA - MC_AWAIT(mGen.expand(chl, mA.subspan(0, mNumPartitions * mSizePer), PprfOutputFormat::Interleaved, true, mNumThreads)); + MC_AWAIT(mGen.expand(chl, mA, PprfOutputFormat::Interleaved, true, mNumThreads)); setTimePoint("SilentVoleReceiver.expand.pprf_transpose"); @@ -491,11 +445,10 @@ namespace osuCrypto::Subfield for (u64 i = 0; i < mNumPartitions; ++i) { auto pnt = mS[i]; - mC[pnt] = mNoiseValues[i]; - mA[pnt] = TypeTrait::minus(mA[pnt], mNoiseDeltaShare[i]); + CoeffCtx::copy(mC[pnt], mNoiseValues[i]); + CoeffCtx::minus(mA[pnt], mA[pnt], mNoiseDeltaShare[i]); } - if (mDebug) { MC_AWAIT(checkRT(chl)); @@ -503,17 +456,17 @@ namespace osuCrypto::Subfield } - // if (mMalType == SilentSecType::Malicious) - // { - // MC_AWAIT(chl.send(std::move(mMalCheckSeed))); + // if (mMalType == SilentSecType::Malicious) + // { + // MC_AWAIT(chl.send(std::move(mMalCheckSeed))); // - // myHash = ferretMalCheck(mDeltaShare, mNoiseValues); + // myHash = ferretMalCheck(mDeltaShare, mNoiseValues); // - // MC_AWAIT(chl.recv(theirHash)); + // MC_AWAIT(chl.recv(theirHash)); // - // if (theirHash != myHash) - // throw RTE_LOC; - // } + // if (theirHash != myHash) + // throw RTE_LOC; + // } switch (mMultType) { @@ -523,10 +476,10 @@ namespace osuCrypto::Subfield mExConvEncoder.setTimer(getTimer()); } - mExConvEncoder.dualEncode2( - mA.subspan(0, mExConvEncoder.mCodeSize), - mC.subspan(0, mExConvEncoder.mCodeSize) - ); + mExConvEncoder.dualEncode2( + mA.begin(), + mC.begin() + ); break; default: @@ -535,8 +488,8 @@ namespace osuCrypto::Subfield } // resize the buffers down to only contain the real elements. - mA.resize(mRequestedNumOTs); - mC.resize(mRequestedNumOTs); + CoeffCtx::resize(mA, mRequestedNumOTs); + CoeffCtx::resize(mC, mRequestedNumOTs); mNoiseValues = {}; mNoiseDeltaShare = {}; @@ -554,23 +507,29 @@ namespace osuCrypto::Subfield task<> checkRT(Socket& chl) const { MC_BEGIN(task<>, this, &chl, - B = AlignedVector(mA.size()), - sparseNoiseDelta = std::vector(mA.size()), - noiseDeltaShare2 = std::vector(), - delta = F{} + B = typename CoeffCtx::Vec{}, + sparseNoiseDelta = typename CoeffCtx::Vec{}, + noiseDeltaShare2 = typename CoeffCtx::Vec{}, + delta = typename CoeffCtx::Vec{}, + tempF = typename CoeffCtx::Vec{}, + tempG = typename CoeffCtx::Vec{}, + buffer = std::vector{} ); - //std::vector mB(mA.size()); - MC_AWAIT(chl.recv(delta)); - MC_AWAIT(chl.recv(B)); - MC_AWAIT(chl.recvResize(noiseDeltaShare2)); - - for (u64 i = 0; i < mA.size(); i++) { - F left = TypeTrait::mul(delta, mC[i]); - F right = TypeTrait::minus(mA[i], B[i]); - if (left != right) { - throw RTE_LOC; - } - } + + // recv delta + buffer.resize(CoeffCtx::byteSize()); + MC_AWAIT(chl.recv(buffer)); + CoeffCtx::deserialize(delta, buffer); + + // recv B + buffer.resize(CoeffCtx::byteSize() * mA.size()); + MC_AWAIT(chl.recv(buffer)); + CoeffCtx::deserialize(B, buffer); + + // recv the noisy values. + buffer.resize(CoeffCtx::byteSize() * mNoiseDeltaShare.size()); + MC_AWAIT(chl.recvResize(buffer)); + CoeffCtx::deserialize(noiseDeltaShare2, buffer); //check that at locations mS[0],...,mS[..] // that we hold a sharing mA, mB of @@ -584,154 +543,72 @@ namespace osuCrypto::Subfield // delta * mC = mA + mB // -// if (noiseDeltaShare2.size() != mNoiseDeltaShare.size()) -// throw RTE_LOC; -// -// for (auto i : rng(mNoiseDeltaShare.size())) -// { -// if ((mNoiseDeltaShare[i] ^ noiseDeltaShare2[i]) != mNoiseValues[i].gf128Mul(delta)) -// throw RTE_LOC; -// } -// -// { -// -// for (auto i : rng(mNumPartitions* mSizePer)) -// { -// auto iter = std::find(mS.begin(), mS.end(), i); -// if (iter != mS.end()) -// { -// auto d = iter - mS.begin(); -// -// if (mC[i] != mNoiseValues[d]) -// throw RTE_LOC; -// -// if (mNoiseValues[d].gf128Mul(delta) != (mA[i] ^ B[i])) -// { -// std::cout << "bad vole base correlation, mA[i] + mB[i] != mC[i] * delta" << std::endl; -// std::cout << "i " << i << std::endl; -// std::cout << "mA[i] " << mA[i] << std::endl; -// std::cout << "mB[i] " << B[i] << std::endl; -// std::cout << "mC[i] " << mC[i] << std::endl; -// std::cout << "delta " << delta << std::endl; -// std::cout << "mA[i] + mB[i] " << (mA[i] ^ B[i]) << std::endl; -// std::cout << "mC[i] * delta " << (mC[i].gf128Mul(delta)) << std::endl; -// -// throw RTE_LOC; -// } -// } -// else -// { -// if (mA[i] != B[i]) -// { -// std::cout << mA[i] << " " << B[i] << std::endl; -// throw RTE_LOC; -// } -// -// if (mC[i] != oc::ZeroBlock) -// throw RTE_LOC; -// } -// } -// -// u64 d = mNumPartitions; -// for (auto j : rng(mGapBaseChoice.size())) -// { -// auto idx = j + mNumPartitions * mSizePer; -// auto aa = mA.subspan(mNumPartitions * mSizePer); -// auto bb = B.subspan(mNumPartitions * mSizePer); -// auto cc = mC.subspan(mNumPartitions * mSizePer); -// auto noise = mNoiseValues.subspan(mNumPartitions); -// //auto noiseShare = mNoiseValues.subspan(mNumPartitions); -// if (mGapBaseChoice[j]) -// { -// if (mS[d++] != idx) -// throw RTE_LOC; -// -// if (cc[j] != noise[j]) -// { -// std::cout << "sparse noise vector mC is not the expected value" << std::endl; -// std::cout << "i j " << idx << " " << j << std::endl; -// std::cout << "mC[i] " << cc[j] << std::endl; -// std::cout << "noise[j] " << noise[j] << std::endl; -// throw RTE_LOC; -// } -// -// if (noise[j].gf128Mul(delta) != (aa[j] ^ bb[j])) -// { -// -// std::cout << "bad vole base GAP correlation, mA[i] + mB[i] != mC[i] * delta" << std::endl; -// std::cout << "i " << idx << std::endl; -// std::cout << "mA[i] " << aa[j] << std::endl; -// std::cout << "mB[i] " << bb[j] << std::endl; -// std::cout << "mC[i] " << cc[j] << std::endl; -// std::cout << "delta " << delta << std::endl; -// std::cout << "mA[i] + mB[i] " << (aa[j] ^ bb[j]) << std::endl; -// std::cout << "mC[i] * delta " << (cc[j].gf128Mul(delta)) << std::endl; -// std::cout << "noise * delta " << (noise[j].gf128Mul(delta)) << std::endl; -// throw RTE_LOC; -// } -// -// } -// else -// { -// if (aa[j] != bb[j]) -// throw RTE_LOC; -// -// if (cc[j] != oc::ZeroBlock) -// throw RTE_LOC; -// } -// } -// -// if (d != mS.size()) -// throw RTE_LOC; -// } - - - //{ - - // auto cDelta = B; - // for (u64 i = 0; i < cDelta.size(); ++i) - // cDelta[i] = cDelta[i] ^ mA[i]; - - // std::vector exp(mN2); - // for (u64 i = 0; i < mNumPartitions; ++i) - // { - // auto j = mS[i]; - // exp[j] = noiseDeltaShare2[i]; - // } - - // auto iter = mS.begin() + mNumPartitions; - // for (u64 i = 0, j = mNumPartitions * mSizePer; i < mGapOts.size(); ++i, ++j) - // { - // if (mGapBaseChoice[i]) - // { - // if (*iter != j) - // throw RTE_LOC; - // ++iter; - - // exp[j] = noiseDeltaShare2[mNumPartitions + i]; - // } - // } - - // if (iter != mS.end()) - // throw RTE_LOC; - - // bool failed = false; - // for (u64 i = 0; i < mN2; ++i) - // { - // if (neq(cDelta[i], exp[i])) - // { - // std::cout << i << " / " << mN2 << - // " cd = " << cDelta[i] << - // " exp= " << exp[i] << std::endl; - // failed = true; - // } - // } - - // if (failed) - // throw RTE_LOC; - - // std::cout << "debug check ok" << std::endl; - //} + CoeffCtx::resize(tempF, 2); + CoeffCtx::resize(tempG, 1); + CoeffCtx::zero(tempG.begin(), tempG.end()); + + for (auto i : rng(mNoiseDeltaShare.size())) + { + // temp[0] = mNoiseDeltaShare[i] + noiseDeltaShare2[i] + CoeffCtx::plus(tempF[0], mNoiseDeltaShare[i], noiseDeltaShare2[i]); + + // temp[1] = mNoiseValues[i] * delta[0] + CoeffCtx::mul(tempF[1], delta[0], mNoiseValues[i]); + + if (!CoeffCtx::eq(tempF[0], tempF[1])) + throw RTE_LOC; + } + + { + + for (auto i : rng(mNumPartitions* mSizePer)) + { + auto iter = std::find(mS.begin(), mS.end(), i); + if (iter != mS.end()) + { + auto d = iter - mS.begin(); + + if (!CoeffCtx::eq(mC[i], mNoiseValues[d])) + throw RTE_LOC; + + // temp[0] = A[i] + B[i] + CoeffCtx::plus(tempF[0], mA[i], B[i]); + + // temp[1] = mNoiseValues[d] * delta[0] + CoeffCtx::mul(tempF[1], delta[0], mNoiseValues[d]); + + + if (!CoeffCtx::eq(tempF[0], tempF[1])) + { + std::cout << "bad vole base noisy correlation, mA[i] + mB[i] != mC[i] * delta" << std::endl; + std::cout << "i " << i << std::endl; + //std::cout << "mA[i] " << mA[i] << std::endl; + //std::cout << "mB[i] " << B[i] << std::endl; + //std::cout << "mC[i] " << mC[i] << std::endl; + //std::cout << "delta " << delta << std::endl; + //std::cout << "mA[i] + mB[i] " << (mA[i] ^ B[i]) << std::endl; + //std::cout << "mC[i] * delta " << (mC[i].gf128Mul(delta)) << std::endl; + + throw RTE_LOC; + } + } + else + { + if (!CoeffCtx::eq(mA[i], B[i])) + { + std::cout << "bad vole base non-noisy correlation, mA[i] + mB[i] != 0" << std::endl; + //std::cout << mA[i] << " " << B[i] << std::endl; + throw RTE_LOC; + } + + if (!CoeffCtx::eq(mC[i], tempG[0])) + { + std::cout << "bad vole base non-noisy correlation, mC[i] != 0" << std::endl; + throw RTE_LOC; + } + } + } + } MC_END(); } @@ -777,7 +654,6 @@ namespace osuCrypto::Subfield mA = {}; mC = {}; mGen.clear(); - mGapBaseChoice = {}; } }; } diff --git a/libOTe/Vole/Subfield/SilentVoleSender.h b/libOTe/Vole/Subfield/SilentVoleSender.h index d84eef18..f470e87e 100644 --- a/libOTe/Vole/Subfield/SilentVoleSender.h +++ b/libOTe/Vole/Subfield/SilentVoleSender.h @@ -19,57 +19,19 @@ #include #include #include -#include -#include +#include #include #include #include //#define NO_HASH -namespace osuCrypto::Subfield +namespace osuCrypto { - - inline void SubfieldExConvConfigure( - u64 numOTs, u64 secParam, - MultType mMultType, - u64& mRequestedNumOTs, - u64& mNumPartitions, - u64& mSizePer, - u64& mN2, - u64& mN, - ExConvCode& mEncoder - ) - { - u64 a = 24; - auto mScaler = 2; - u64 w; - double minDist; - switch (mMultType) - { - case osuCrypto::MultType::ExConv7x24: - w = 7; - minDist = 0.1; - break; - case osuCrypto::MultType::ExConv21x24: - w = 21; - minDist = 0.15; - break; - default: - throw RTE_LOC; - break; - } - - mRequestedNumOTs = numOTs; - mNumPartitions = getRegNoiseWeight(minDist, secParam); - mSizePer = roundUpTo((numOTs * mScaler + mNumPartitions - 1) / mNumPartitions, 8); - mN2 = mSizePer * mNumPartitions; - mN = mN2 / mScaler; - - mEncoder.config(numOTs, numOTs * mScaler, w, a, true); - } - - - template> + template< + typename F, + typename G = F, + typename CoeffCtx = DefaultCoeffCtx + > class SilentSubfieldVoleSender : public TimerAdapter { public: @@ -82,10 +44,12 @@ namespace osuCrypto::Subfield HasBase }; + using VecF = typename CoeffCtx::template Vec; + using VecG = typename CoeffCtx::template Vec; State mState = State::Default; - SilentSubfieldPprfSender mGen; + SilentSubfieldPprfSender mGen; u64 mRequestedNumOTs = 0; u64 mN2 = 0; @@ -93,9 +57,8 @@ namespace osuCrypto::Subfield u64 mNumPartitions = 0; u64 mSizePer = 0; u64 mNumThreads = 1; - std::vector> mGapOts; SilentBaseType mBaseType; - std::vector mNoiseDeltaShares; + VecF mNoiseDeltaShares; SilentSecType mMalType = SilentSecType::SemiHonest; @@ -105,41 +68,14 @@ namespace osuCrypto::Subfield #endif MultType mMultType = DefaultMultType; -#ifdef ENABLE_INSECURE_SILVER - SilverEncoder mEncoder; -#endif - ExConvCode mExConvEncoder; - - AlignedUnVector mB; - - ///////////////////////////////////////////////////// - // The standard OT extension interface - ///////////////////////////////////////////////////// - -// // the number of IKNP base OTs that should be set. -// u64 baseOtCount() const; -// -// // returns true if the IKNP base OTs are currently set. -// bool hasBaseOts() const; -// -// // sets the IKNP base OTs that are then used to extend -// void setBaseOts( -// span baseRecvOts, -// const BitVector& choices); - - // use the default base OT class to generate the - // IKNP base OTs that are required. -// task<> genBaseOts(PRNG& prng, Socket& chl) -// { -// return mOtExtSender.genBaseOts(prng, chl); -// } - - ///////////////////////////////////////////////////// - // The native silent OT extension interface - ///////////////////////////////////////////////////// + + ExConvCode2 mExConvEncoder; + + VecF mB; + u64 baseVoleCount() const { - return mNumPartitions + mGapOts.size() + 1 * (mMalType == SilentSecType::Malicious); + return mNumPartitions + 1 * (mMalType == SilentSecType::Malicious); } // Generate the silent base OTs. If the Iknp @@ -157,7 +93,7 @@ namespace osuCrypto::Subfield prng2 = std::move(PRNG{}), xx = BitVector{}, chl2 = Socket{}, - nv = NoisySubfieldVoleSender{}, + nv = NoisySubfieldVoleSender{}, noiseDeltaShares = std::vector{} ); setTimePoint("SilentVoleSender.genSilent.begin"); @@ -165,9 +101,10 @@ namespace osuCrypto::Subfield if (isConfigured() == false) throw std::runtime_error("configure must be called first"); + if(!delta) + CoeffCtx::fromBlock(*delta, prng.get()); - delta = delta.value_or(TypeTrait::fromBlock(prng.get())); - xx = TypeTrait::BitVectorF(*delta); + xx = CoeffCtx::binaryDecomposition(*delta); // compute the correlation for the noisy coordinates. noiseDeltaShares.resize(baseVoleCount()); @@ -232,21 +169,19 @@ namespace osuCrypto::Subfield u64 secParam = 128) { mBaseType = type; - u64 gap = 0; switch (mMultType) { case osuCrypto::MultType::ExConv7x24: case osuCrypto::MultType::ExConv21x24: - SubfieldExConvConfigure(numOTs, 128, mMultType, mRequestedNumOTs, mNumPartitions, mSizePer, mN2, mN, mExConvEncoder); + ExConvConfigure(numOTs, 128, mMultType, mRequestedNumOTs, mNumPartitions, mSizePer, mN2, mN, mExConvEncoder); break; default: throw RTE_LOC; break; } - mGapOts.resize(gap); mGen.configure(mSizePer, mNumPartitions); mState = State::Configured; @@ -262,7 +197,7 @@ namespace osuCrypto::Subfield if (isConfigured() == false) throw std::runtime_error("configure must be called first"); - return mGen.baseOtCount() + mGapOts.size(); + return mGen.baseOtCount(); } // Set the externally generated base OTs. This choice @@ -277,11 +212,7 @@ namespace osuCrypto::Subfield if (noiseDeltaShares.size() != baseVoleCount()) throw RTE_LOC; - auto genOt = sendBaseOts.subspan(0, mGen.baseOtCount()); - auto gapOt = sendBaseOts.subspan(genOt.size(), mGapOts.size()); - - mGen.setBase(genOt); - std::copy(gapOt.begin(), gapOt.end(), mGapOts.begin()); + mGen.setBase(sendBaseOts); mNoiseDeltaShares.resize(noiseDeltaShares.size()); std::copy(noiseDeltaShares.begin(), noiseDeltaShares.end(), mNoiseDeltaShares.begin()); } @@ -301,7 +232,8 @@ namespace osuCrypto::Subfield MC_AWAIT(silentSendInplace(delta, b.size(), prng, chl)); - std::memcpy(b.data(), mB.data(), b.size() * TypeTrait::bytesF); + CoeffCtx::copy(mB.begin(), mB.begin() + b.size(), b.begin()); + //std::memcpy(b.data(), mB.data(), b.size() * CoeffCtx::bytesF); clear(); setTimePoint("SilentVoleSender.expand.ldpc.msgCpy"); @@ -320,12 +252,9 @@ namespace osuCrypto::Subfield Socket& chl) { MC_BEGIN(task<>, this, delta, n, &prng, &chl, - gapVals = std::vector{}, deltaShare = block{}, X = block{}, - hash = std::array{}, - noiseShares = span{}, - mbb = span{} + hash = std::array{} ); setTimePoint("SilentVoleSender.ot.enter"); @@ -349,33 +278,16 @@ namespace osuCrypto::Subfield setTimePoint("SilentVoleSender.start"); //gTimer.setTimePoint("SilentVoleSender.iknp.base2"); -// if (mMalType == SilentSecType::Malicious) -// { -// deltaShare = mNoiseDeltaShares.back(); -// mNoiseDeltaShares.pop_back(); -// } + //if (mMalType == SilentSecType::Malicious) + //{ + // deltaShare = mNoiseDeltaShares.back(); + // mNoiseDeltaShares.pop_back(); + //} - // allocate B + // allocate B mB.resize(0); mB.resize(mN2); - // derandomize the random OTs for the gap - // to have the desired correlation. - gapVals.resize(mGapOts.size()); - for (u64 i = mNumPartitions * mSizePer, j = 0; i < mN2; ++i, ++j) - { - auto t = TypeTrait::fromBlock(mGapOts[j][0]); - auto v = TypeTrait::plus(t, mNoiseDeltaShares[mNumPartitions + j]); - gapVals[j] = TypeTrait::plus( - TypeTrait::fromBlock(AES(mGapOts[j][1]).ecbEncBlock(ZeroBlock)), - v); - mB[i] = t; - } - - if (gapVals.size()) - MC_AWAIT(chl.send(std::move(gapVals))); - - if (mTimer) mGen.setTimer(*mTimer); @@ -383,25 +295,22 @@ namespace osuCrypto::Subfield // our secret share of delta * noiseVals. The receiver // can then manually add their shares of this to the // output of the PPRF at the correct locations. - noiseShares = span(mNoiseDeltaShares.data(), mNumPartitions); - mbb = mB.subspan(0, mNumPartitions * mSizePer); - MC_AWAIT(mGen.expand(chl, noiseShares, prng.get(), mbb, + MC_AWAIT(mGen.expand(chl, mNoiseDeltaShares, prng.get(), mB, PprfOutputFormat::Interleaved, true, mNumThreads)); + setTimePoint("SilentVoleSender.expand.pprf"); - setTimePoint("SilentVoleSender.expand.pprf_transpose"); if (mDebug) { MC_AWAIT(checkRT(chl, delta)); setTimePoint("SilentVoleSender.expand.checkRT"); } - - // if (mMalType == SilentSecType::Malicious) - // { - // MC_AWAIT(chl.recv(X)); - // hash = ferretMalCheck(X, deltaShare); - // MC_AWAIT(chl.send(std::move(hash))); - // } + //if (mMalType == SilentSecType::Malicious) + //{ + // MC_AWAIT(chl.recv(X)); + // hash = ferretMalCheck(X, deltaShare); + // MC_AWAIT(chl.send(std::move(hash))); + //} switch (mMultType) { @@ -410,7 +319,7 @@ namespace osuCrypto::Subfield if (mTimer) { mExConvEncoder.setTimer(getTimer()); } - mExConvEncoder.dualEncode(mB.subspan(0, mExConvEncoder.mCodeSize)); + mExConvEncoder.dualEncode(mB.begin()); break; default: throw RTE_LOC; diff --git a/libOTe_Tests/ExConvCode_Tests.cpp b/libOTe_Tests/ExConvCode_Tests.cpp index 22565f4e..c6fe4b0e 100644 --- a/libOTe_Tests/ExConvCode_Tests.cpp +++ b/libOTe_Tests/ExConvCode_Tests.cpp @@ -1,222 +1,232 @@ #include "ExConvCode_Tests.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" -#include "libOTe/Tools/ExConvCode/ExConvCode.h" +#include "libOTe/Tools/ExConvCode/ExConvCode2.h" #include +#include "libOTe/Tools/Subfield/Subfield.h" namespace osuCrypto { - void ExConvCode_encode_basic_test(const oc::CLP& cmd) + + std::ostream& operator<<(std::ostream& o, const std::array&a) { + o << "{" << a[0] << " " << a[1] << " " << a[2] << "}"; + return o; + } - auto k = cmd.getOr("k", 16ul); - auto R = cmd.getOr("R", 2.0); - auto n = cmd.getOr("n", k * R); - auto bw = cmd.getOr("bw", 7); - auto aw = cmd.getOr("aw", 8); + struct mtxPrint + { + mtxPrint(AlignedUnVector& d, u64 r, u64 c) + :mData(d) + , mRows(r) + , mCols(c) + {} - bool v = cmd.isSet("v"); + AlignedUnVector& mData; + u64 mRows, mCols; + }; - for (auto sys : {/* false,*/ true }) + std::ostream& operator<<(std::ostream& o, const mtxPrint& m) + { + for (u64 i = 0; i < m.mRows; ++i) { + for (u64 j = 0; j < m.mCols; ++j) + { + bool color = (int)m.mData[i * m.mCols + j] && &o == &std::cout; + if (color) + o << Color::Green; + + o << (int)m.mData[i * m.mCols + j] << " "; + + if (color) + o << Color::Default; + } + o << std::endl; + } + return o; + } + template + void exConvTest(u64 k, u64 n, u64 bw, u64 aw, bool sys) + { + ExConvCode2 code; + code.config(k, n, bw, aw, sys); - ExConvCode code; - code.config(k, n, bw, aw, sys); + auto accOffset = sys * k; + std::vector x1(n), x2(n), x3(n), x4(n); + PRNG prng(CCBlock); - auto A = code.getA(); - auto B = code.getB(); - auto G = B * A; + for (u64 i = 0; i < x1.size(); ++i) + { + x1[i] = x2[i] = x3[i] = prng.get(); + } - std::vector m0(k), m1(k), a1(n); + std::vector rand(divCeil(aw, 8)); + for (i64 i = 0; i < x1.size() - aw - 1; ++i) + { + prng.get(rand.data(), rand.size()); + code.accOne(x1.begin() + i, x1.end(), rand.data(), std::integral_constant{}); - if (v) - { - std::cout << "B\n" << B << std::endl << std::endl; - std::cout << "A'\n" << code.getAPar() << std::endl << std::endl; - std::cout << "A\n" << A << std::endl << std::endl; - std::cout << "G\n" << G << std::endl; + if (aw == 16) + code.accOne(x2.begin() + i, x2.end(), rand.data(), std::integral_constant{}); - } - const auto c0 = [&]() { - std::vector c0(n); - PRNG prng(ZeroBlock); - prng.get(c0.data(), c0.size()); - return c0; - }(); - - auto a0 = c0; - auto aa0 = a0; - std::vector aa1(n); - for (u64 i = 0; i < n; ++i) + CoeffCtx::plus(x3[i + 1], x3[i + 1], x3[i]); + for (u64 j = 0; j < aw && (i + j + 2) < x3.size(); ++j) { - aa1[i] = aa0[i].get(0); + if (*BitIterator(rand.data(), j)) + { + CoeffCtx::plus(x3[i + j + 2], x3[i + j + 2], x3[i]); + } } - if (code.mSystematic) + + for (u64 j = i; j < x1.size() && j < i + aw + 2; ++j) { - code.accumulate(span(a0.begin() + k, a0.begin() + n)); - code.accumulate( - span(aa0.begin() + k, aa0.begin() + n), - span(aa1.begin() + k, aa1.begin() + n) - ); + if (aw == 16 && x1[j] != x2[j]) + { + std::cout << j << " " << (x1[j]) << " " << (x2[j]) << std::endl; + throw RTE_LOC; + } - for (u64 i = 0; i < n; ++i) + if (x1[j] != x3[j]) { - if (aa0[i] != a0[i]) - throw RTE_LOC; - if (aa1[i] != a0[i].get(0)) - throw RTE_LOC; + std::cout << j << " " << (x1[j]) << " " << (x3[j]) << std::endl; + throw RTE_LOC; } } - else - { - code.accumulate(a0); - } - A.multAdd(c0, a1); - //A.leftMultAdd(c0, a1); - if (a0 != a1) + } + + + x4 = x1; + //std::cout << std::endl; + + code.accumulateFixed(x1.begin() + accOffset); + + if (aw == 16) + { + code.accumulateFixed(x2.begin() + accOffset); + + if (x1 != x2) { - if (v) + for (u64 i = 0; i < x1.size(); ++i) { - - for (u64 i = 0; i < k; ++i) - std::cout << std::hex << std::setw(2) << std::setfill('0') << (a0[i]) << " "; - std::cout << "\n"; - for (u64 i = 0; i < k; ++i) - std::cout << std::hex << std::setw(2) << std::setfill('0') << (a1[i]) << " "; - std::cout << "\n"; + std::cout << i << " " << (x1[i]) << " " << (x2[i]) << std::endl; } - throw RTE_LOC; } + } + { + PRNG coeffGen(code.mSeed ^ OneBlock); + u8* mtxCoeffIter = (u8*)coeffGen.mBuffer.data(); + auto mtxCoeffEnd = mtxCoeffIter + coeffGen.mBuffer.size() * sizeof(block) - divCeil(aw, 8); - - for (u64 q = 0; q < n; ++q) + auto xi = x3.begin() + accOffset; + auto end = x3.end(); + while (xi < end) { - std::vector c0(n); - c0[q] = AllOneBlock; - - //auto q = 0; - auto cc = c0; - auto cc1 = c0; - auto mm1 = m1; - - std::vector cc2(cc1.size()), mm2(mm1.size()); - for (u64 i = 0; i < n; ++i) - cc2[i] = cc1[i].get(0); - for (u64 i = 0; i < k; ++i) - mm2[i] = mm1[i].get(0); - //std::vector cc(n); - //cc[q] = AllOneBlock; - std::fill(m0.begin(), m0.end(), ZeroBlock); - B.multAdd(cc, m0); - - - if (code.mSystematic) + if (mtxCoeffIter > mtxCoeffEnd) { - std::copy(cc.begin(), cc.begin() + k, m1.begin()); - code.mExpander.expand( - span(cc.begin() + k, cc.end()), - m1); - //for (u64 i = 0; i < k; ++i) - // m1[i] ^= cc[i]; - std::copy(cc1.begin(), cc1.begin() + k, mm1.begin()); - std::copy(cc2.begin(), cc2.begin() + k, mm2.begin()); - - code.mExpander.expand( - span(cc1.begin() + k, cc1.end()), - span(cc2.begin() + k, cc2.end()), - mm1, mm2); + // generate more mtx coefficients + ExConvCode2::refill(coeffGen); + mtxCoeffIter = (u8*)coeffGen.mBuffer.data(); } - else + + // add xi to the next positions + auto xj = xi + 1; + if (xj != end) { - code.mExpander.expand(cc, m1); + CoeffCtx::plus(*xj, *xj, *xi); + ++xj; } - if (m0 != m1) + for (u64 j = 0; j < aw && xj != end; ++j, ++xj) { - - std::cout << "B\n" << B << std::endl << std::endl; - for (u64 i = 0; i < n; ++i) - std::cout << (c0[i].get(0) & 1) << " "; - std::cout << std::endl; - - std::cout << "exp act " << q << "\n"; - for (u64 i = 0; i < k; ++i) + if (*BitIterator(mtxCoeffIter, j)) { - std::cout << (m0[i].get(0) & 1) << " " << (m1[i].get(0) & 1) << std::endl; + CoeffCtx::plus(*xj, *xj, *xi); } - throw RTE_LOC; } + ++mtxCoeffIter; - if (code.mSystematic) - { - if (mm1 != m1) - throw RTE_LOC; - - for (u64 i = 0; i < k; ++i) - if (mm2[i] != m1[i].get(0)) - throw RTE_LOC; - } + ++xi; } + } - //for (u64 q = 0; q < n; ++q) + if (x1 != x3) + { + for (u64 i = 0; i < x1.size(); ++i) { - auto q = 0; + std::cout << i << " " << (x1[i]) << " " << (x3[i]) << std::endl; + } + throw RTE_LOC; + } - //std::fill(c0.begin(), c0.end(), ZeroBlock); - //c0[q] = AllOneBlock; - auto cc = c0; - auto cc1 = c0; - std::vector cc2(cc1.size()); - for (u64 i = 0; i < n; ++i) - cc2[i] = cc1[i].get(0); + detail::ExpanderModd expanderCoeff(code.mExpander.mSeed, code.mExpander.mCodeSize); + std::vector y1(k), y2(k); - std::fill(m0.begin(), m0.end(), ZeroBlock); - G.multAdd(c0, m0); + if (sys) + { + std::copy(x1.begin(), x1.begin() + k, y1.begin()); + y2 = y1; + code.mExpander.expand(x1.cbegin() + accOffset, y1.begin()); + } + else + { + code.mExpander.expand(x1.cbegin() + accOffset, y1.begin()); + } - if (code.mSystematic) - { - code.dualEncode(cc); - std::copy(cc.begin(), cc.begin() + k, m1.begin()); - } - else + u64 i = 0; + auto main = k / 8 * 8; + for (; i < main; i += 8) + { + for (u64 j = 0; j < code.mExpander.mExpanderWeight; ++j) + { + for (u64 p = 0; p < 8; ++p) { - code.dualEncode(cc, m1); + auto idx = expanderCoeff.get(); + CoeffCtx::plus(y2[i + p], y2[i + p], x1[idx + accOffset]); } + } + } - if (m0 != m1) - { - std::cout << "G\n" << G << std::endl << std::endl; - for (u64 i = 0; i < n; ++i) - std::cout << (c0[i].get(0) & 1) << " "; - std::cout << std::endl; + for (; i < k; ++i) + { + for (u64 j = 0; j < code.mExpander.mExpanderWeight; ++j) + { + auto idx = expanderCoeff.get(); + CoeffCtx::plus(y2[i], y2[i], x1[idx + accOffset]); + } + } - std::cout << "exp act " << q << "\n"; - for (u64 i = 0; i < k; ++i) - { - std::cout << (m0[i].get(0) & 1) << " " << (m1[i].get(0) & 1) << std::endl; - } - throw RTE_LOC; - } + if (y1 != y2) + throw RTE_LOC; + code.dualEncode(x4.begin()); - if (code.mSystematic) - { - code.dualEncode2(cc1, cc2); + x4.resize(k); + if (x4 != y1) + throw RTE_LOC; + } - for (u64 i = 0; i < k; ++i) - { - if (cc1[i] != cc[i]) - throw RTE_LOC; - if (cc2[i] != cc[i].get(0)) - throw RTE_LOC; - } - } - } + + void ExConvCode_encode_basic_test(const oc::CLP& cmd) + { + + auto K = cmd.getManyOr("k", { 16ul, 64, 4353 }); + auto R = cmd.getManyOr("R", { 2.0, 3.0 }); + auto Bw = cmd.getManyOr("bw", { 7, 21 }); + auto Aw = cmd.getManyOr("aw", { 16, 24, 29 }); + + bool v = cmd.isSet("v"); + for (auto k : K) for (auto r : R) for (auto bw : Bw) for (auto aw : Aw) for (auto sys : { false, true }) + { + auto n = k * r; + exConvTest(k, n, bw, aw, sys); + exConvTest(k, n, bw, aw, sys); + exConvTest(k, n, bw, aw, sys); + exConvTest, CoeffCtxArray>(k, n, bw, aw, sys); } - } + } } \ No newline at end of file diff --git a/libOTe_Tests/SilentOT_Tests.cpp b/libOTe_Tests/SilentOT_Tests.cpp index e1eb82e1..4cca9d7d 100644 --- a/libOTe_Tests/SilentOT_Tests.cpp +++ b/libOTe_Tests/SilentOT_Tests.cpp @@ -841,7 +841,6 @@ void OtExt_Silent_mal_Test(const oc::CLP& cmd) #endif } - void Tools_Pprf_expandOne_test(const oc::CLP& cmd) { #if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) diff --git a/libOTe_Tests/Subfield_Tests.cpp b/libOTe_Tests/Subfield_Tests.cpp index ab99fcc7..7a0414b1 100644 --- a/libOTe_Tests/Subfield_Tests.cpp +++ b/libOTe_Tests/Subfield_Tests.cpp @@ -1,6 +1,6 @@ #include "Subfield_Test.h" #include "libOTe/Tools/Subfield/Subfield.h" -#include "libOTe/Tools/Subfield/ExConvCode.h" +#include "libOTe/Tools/ExConvCode/ExConvCode2.h" #include "libOTe/Vole/Subfield/NoisyVoleSender.h" #include "libOTe/Vole/Subfield/NoisyVoleReceiver.h" #include "libOTe/Vole/Subfield/SilentVoleSender.h" @@ -10,15 +10,16 @@ namespace osuCrypto::Subfield { + static_assert(std::is_trivially_copyable_v>); + static_assert(std::is_trivially_copyable_v); using tests_libOTe::eval; - void Subfield_ExConvCode_encode_test(const oc::CLP& cmd) { { - using TypeTrait = TypeTraitF128; + using CoeffCtx = DefaultCoeffCtx; u64 n = 1024; - ExConvCode code; + ExConvCode2 code; code.config(n / 2, n, 7, 24, true); PRNG prng(ZeroBlock); @@ -31,10 +32,10 @@ namespace osuCrypto::Subfield z1[i] = z0[i] ^ delta.gf128Mul(y[i]); } - code.dualEncode(z1); -// code.dualEncode(z0); -// code.dualEncode(y); - code.dualEncode2(z0, y); + code.dualEncode(z1.begin()); + code.dualEncode(z0.begin()); + code.dualEncode(y.begin()); + //code.dualEncode2(z0, y); for (u64 i = 0; i < n; ++i) { @@ -46,9 +47,9 @@ namespace osuCrypto::Subfield } { - using TypeTrait = TypeTraitPrimitive; + using CoeffCtx = DefaultCoeffCtx; u64 n = 1024; - ExConvCode code; + ExConvCode2 code; code.config(n / 2, n, 7, 24, true); PRNG prng(ZeroBlock); @@ -61,8 +62,11 @@ namespace osuCrypto::Subfield z1[i] = z0[i] + delta * y[i]; } - code.dualEncode(z1); - code.dualEncode2(z0, y); + code.dualEncode(z1.begin()); + code.dualEncode(z0.begin()); + code.dualEncode(y.begin()); + + //code.dualEncode2(z0, y); for (u64 i = 0; i < n; ++i) { @@ -74,9 +78,9 @@ namespace osuCrypto::Subfield } { - using TypeTrait = TypeTrait64; + using CoeffCtx = DefaultCoeffCtx; u64 n = 1024; - ExConvCode code; + ExConvCode2 code; code.config(n / 2, n, 7, 24, true); PRNG prng(ZeroBlock); @@ -89,8 +93,10 @@ namespace osuCrypto::Subfield z1[i] = z0[i] + delta * y[i]; } - code.dualEncode(z1); - code.dualEncode2(z0, y); + code.dualEncode(z1.begin()); + code.dualEncode(z0.begin()); + code.dualEncode(y.begin()); + //code.dualEncode2(z0, y); for (u64 i = 0; i < n; ++i) { @@ -115,8 +121,8 @@ namespace osuCrypto::Subfield // auto sockets = cp::LocalAsyncSocket::makePair(); // auto format = PprfOutputFormat::Interleaved; - // SilentSubfieldPprfSender sender; - // SilentSubfieldPprfReceiver recver; + // SilentSubfieldPprfSender sender; + // SilentSubfieldPprfReceiver recver; // sender.configure(domain, numPoints); // recver.configure(domain, numPoints); @@ -176,8 +182,10 @@ namespace osuCrypto::Subfield std::vector z0(n), z1(n); prng.get(y.data(), y.size()); - NoisySubfieldVoleReceiver recv; - NoisySubfieldVoleSender send; + + using Trait = DefaultCoeffCtx; + NoisySubfieldVoleReceiver recv; + NoisySubfieldVoleSender send; recv.setTimer(timer); send.setTimer(timer); @@ -218,18 +226,23 @@ namespace osuCrypto::Subfield PRNG prng(seed); constexpr size_t N = 3; - using TypeTrait = TypeTraitVec; - u64 bitsF = TypeTrait::bitsF; - using F = TypeTrait::F; - using G = TypeTrait::G; - - F x = TypeTrait::fromBlock(prng.get()); + using G = u32; + using F = std::array; + using CoeffCtx = CoeffCtxArray; + u64 bitsF = sizeof(F) * 8;; + + static_assert( + std::is_standard_layout::value && + std::is_trivial::value + ); + F x; + CoeffCtx::fromBlock(x, prng.get()); std::vector y(n); std::vector z0(n), z1(n); prng.get(y.data(), y.size()); - NoisySubfieldVoleReceiver recv; - NoisySubfieldVoleSender send; + NoisySubfieldVoleReceiver recv; + NoisySubfieldVoleSender send; recv.setTimer(timer); send.setTimer(timer); @@ -245,7 +258,7 @@ namespace osuCrypto::Subfield otRecvMsg[i] = otSendMsg[i][recvChoice[i]]; timer.setTimePoint("ot"); - auto p0 = recv.receive(y, z0, prng, otSendMsg, chls[0]); + auto p0 = recv.receive((span)y, (span)z0, prng, otSendMsg, chls[0]); auto p1 = send.send(x, z1, prng, otRecvMsg, chls[1]); eval(p0, p1); @@ -279,9 +292,10 @@ namespace osuCrypto::Subfield std::vector y(n); std::vector z0(n), z1(n); prng.get(y.data(), y.size()); - - NoisySubfieldVoleReceiver recv; - NoisySubfieldVoleSender send; + using F = block; + using G = block; + NoisySubfieldVoleReceiver recv; + NoisySubfieldVoleSender send; recv.setTimer(timer); send.setTimer(timer); @@ -339,8 +353,8 @@ namespace osuCrypto::Subfield recv.setTimer(timer); send.setTimer(timer); -// recv.mDebug = true; -// send.mDebug = true; + // recv.mDebug = true; + // send.mDebug = true; auto chls = cp::LocalAsyncSocket::makePair(); @@ -368,15 +382,16 @@ namespace osuCrypto::Subfield { PRNG prng(seed); constexpr size_t N = 10; - using F = Vec; using G = u32; - using TypeTrait = TypeTraitVec; - F x = TypeTrait::fromBlock(prng.get()); + using F = std::array; + using CoeffCtx = CoeffCtxArray; + F x; + CoeffCtx::fromBlock(x, prng.get()); std::vector c(n); - std::vector z0(n), z1(n); + std::vector a(n), b(n); - SilentSubfieldVoleReceiver recv; - SilentSubfieldVoleSender send; + SilentSubfieldVoleReceiver recv; + SilentSubfieldVoleSender send; recv.mMultType = MultType::ExConv7x24; send.mMultType = MultType::ExConv7x24; @@ -394,8 +409,8 @@ namespace osuCrypto::Subfield timer.setTimePoint("ot"); // fakeBase(n, nt, prng, x, recv, send); - auto p0 = send.silentSend(x, span(z0), prng, chls[0]); - auto p1 = recv.silentReceive(span(c), span(z1), prng, chls[1]); + auto p0 = send.silentSend(x, span(b), prng, chls[0]); + auto p1 = recv.silentReceive(span(c), span(a), prng, chls[1]); eval(p0, p1); // std::cout << "transferred " << (chls[0].bytesSent() + chls[0].bytesReceived()) << std::endl; @@ -404,13 +419,16 @@ namespace osuCrypto::Subfield timer.setTimePoint("send"); for (u64 i = 0; i < n; i++) { for (u64 j = 0; j < N; j++) { - G left = c[i] * x[j]; - G right = z1[i][j] - z0[i][j]; - if (left != right) { - std::cout << "bad " << i << "\n c[i] " << c[i] << " * x[j] " << x[j] << " = " << left << std::endl; - std::cout << "z0[i][j] " << z0[i][j] << " - z1 " << z1[i][j] << " = " << right << std::endl; - throw RTE_LOC; - } + throw RTE_LOC;// fix this + // c = a x + b + // c - b = a x + //G left = a[i] * x[j]; + //G right = c[i][j] - b[i][j]; + //if (left != right) { + // std::cout << "bad " << i << "\n a[i] " << a[i] << " * x[j] " << x[j] << " = " << left << std::endl; + // std::cout << "c[i][j] " << c[i][j] << " - b " << b[i][j] << " = " << right << std::endl; + // throw RTE_LOC; + //} } } } @@ -429,8 +447,8 @@ namespace osuCrypto::Subfield recv.setTimer(timer); send.setTimer(timer); -// recv.mDebug = true; -// send.mDebug = true; + // recv.mDebug = true; + // send.mDebug = true; auto chls = cp::LocalAsyncSocket::makePair(); diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 97d0c5ac..6dbaeeb5 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -57,13 +57,19 @@ namespace tests_libOTe tc.add("EACode_encode_basic_test ", EACode_encode_basic_test); tc.add("ExConvCode_encode_basic_test ", ExConvCode_encode_basic_test); - + tc.add("Subfield_ExConvCode_encode_test ", Subfield::Subfield_ExConvCode_encode_test); + tc.add("Tools_Pprf_expandOne_test ", Tools_Pprf_expandOne_test); tc.add("Tools_Pprf_test ", Tools_Pprf_test); tc.add("Tools_Pprf_inter_test ", Tools_Pprf_inter_test); tc.add("Tools_Pprf_blockTrans_test ", Tools_Pprf_blockTrans_test); tc.add("Tools_Pprf_callback_test ", Tools_Pprf_callback_test); + + tc.add("Subfield_Tools_Pprf_test ", Subfield::Subfield_Tools_Pprf_test); + tc.add("Subfield_Noisy_Vole_test ", Subfield::Subfield_Noisy_Vole_test); + tc.add("Subfield_Silent_Vole_test ", Subfield::Subfield_Silent_Vole_test); + tc.add("Bot_Simplest_Test ", Bot_Simplest_Test); tc.add("Bot_Simplest_asm_Test ", Bot_Simplest_asm_Test); @@ -117,10 +123,6 @@ namespace tests_libOTe tc.add("NcoOt_Oos_Test ", NcoOt_Oos_Test); tc.add("NcoOt_genBaseOts_Test ", NcoOt_genBaseOts_Test); - tc.add("Subfield_ExConvCode_encode_test ", Subfield::Subfield_ExConvCode_encode_test); - tc.add("Subfield_Tools_Pprf_test ", Subfield::Subfield_Tools_Pprf_test); - tc.add("Subfield_Noisy_Vole_test ", Subfield::Subfield_Noisy_Vole_test); - tc.add("Subfield_Silent_Vole_test ", Subfield::Subfield_Silent_Vole_test); }); } From 7f7a37b6f42b5cd98000dba29c07385179a80dcb Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Thu, 18 Jan 2024 20:28:30 -0800 Subject: [PATCH 10/23] pprf --- libOTe/Tools/ExConvCode/ExConvCode2.h | 84 +-- libOTe/Tools/ExConvCode/Expander2.h | 255 ++++----- libOTe/Tools/Subfield/Subfield.h | 200 +++++-- libOTe/Tools/Subfield/SubfieldPprf.h | 268 +++++----- .../SoftSpokenOT/SoftSpokenMalOtExt.cpp | 4 + libOTe/Vole/Subfield/NoisyVoleReceiver.h | 65 ++- libOTe/Vole/Subfield/NoisyVoleSender.h | 46 +- libOTe/Vole/Subfield/SilentVoleReceiver.h | 8 +- libOTe/Vole/Subfield/SilentVoleSender.h | 6 +- libOTe_Tests/ExConvCode_Tests.cpp | 20 + libOTe_Tests/Pprf_Tests.cpp | 448 ++++++++++++++++ libOTe_Tests/Pprf_Tests.h | 16 + libOTe_Tests/SilentOT_Tests.cpp | 487 ------------------ libOTe_Tests/SilentOT_Tests.h | 6 - libOTe_Tests/Subfield_Test.h | 7 +- libOTe_Tests/Subfield_Tests.cpp | 290 ++--------- libOTe_Tests/UnitTests.cpp | 7 +- 17 files changed, 1080 insertions(+), 1137 deletions(-) create mode 100644 libOTe_Tests/Pprf_Tests.cpp create mode 100644 libOTe_Tests/Pprf_Tests.h diff --git a/libOTe/Tools/ExConvCode/ExConvCode2.h b/libOTe/Tools/ExConvCode/ExConvCode2.h index 9f7517d2..fb7e1bea 100644 --- a/libOTe/Tools/ExConvCode/ExConvCode2.h +++ b/libOTe/Tools/ExConvCode/ExConvCode2.h @@ -100,15 +100,15 @@ namespace osuCrypto // return code size n. u64 generatorCols() const { return mCodeSize; } - // Compute w = G * e. e will be modified in the computation. - // the computation will be done over F using CoeffCtx::plus - template< - typename F, - typename CoeffCtx, - typename SrcIter, - typename DstIter - > - void dualEncode(SrcIter&& e, DstIter&& w); + //// Compute w = G * e. e will be modified in the computation. + //// the computation will be done over F using CoeffCtx::plus + //template< + // typename F, + // typename CoeffCtx, + // typename SrcIter, + // typename DstIter + //> + //void dualEncode(SrcIter&& e, DstIter&& w); // Compute e[0,...,k-1] = G * e. // the computation will be done over F using CoeffCtx::plus @@ -240,41 +240,41 @@ namespace osuCrypto { // Compute w = G * e. e will be modified in the computation. - template< - typename F, - typename CoeffCtx, - typename SrcIter, - typename DstIter - > - void ExConvCode2::dualEncode(SrcIter&& e, DstIter&& w) - { + //template< + // typename F, + // typename CoeffCtx, + // typename SrcIter, + // typename DstIter + //> + //void ExConvCode2::dualEncode(SrcIter&& e, DstIter&& w) + //{ - static_assert(is_iterator::value, "must pass in an iterator to the data, " __FUNCTION__); - static_assert(is_iterator::value, "must pass in an iterator to the data"); + // static_assert(is_iterator::value, "must pass in an iterator to the data, " __FUNCTION__); + // static_assert(is_iterator::value, "must pass in an iterator to the data"); - // try to deref the back. might bounds check. - (void)*(e + mCodeSize - 1); - (void)*(w + mMessageSize - 1); + // // try to deref the back. might bounds check. + // (void)*(e + mCodeSize - 1); + // (void)*(w + mMessageSize - 1); - if (mSystematic) - { - dualEncode(e); - CoeffCtx::copy(w, w + mMessageSize, e); - setTimePoint("ExConv.encode.memcpy"); - } - else - { + // if (mSystematic) + // { + // dualEncode(e); + // CoeffCtx::copy(w, w + mMessageSize, e); + // setTimePoint("ExConv.encode.memcpy"); + // } + // else + // { - setTimePoint("ExConv.encode.begin"); + // setTimePoint("ExConv.encode.begin"); - accumulate(e); + // accumulate(e); - setTimePoint("ExConv.encode.accumulate"); + // setTimePoint("ExConv.encode.accumulate"); - mExpander.expand(e, w); - setTimePoint("ExConv.encode.expand"); - } - } + // mExpander.expand(e, w); + // setTimePoint("ExConv.encode.expand"); + // } + //} // Compute e[0,...,k-1] = G * e. template @@ -295,11 +295,17 @@ namespace osuCrypto } else { + + setTimePoint("ExConv.encode.begin"); + accumulate(e); + setTimePoint("ExConv.encode.accumulate"); + CoeffCtx::template Vec w; CoeffCtx::resize(w, mMessageSize); - dualEncode(e, w.begin()); + mExpander.expand(e, w.begin()); + setTimePoint("ExConv.encode.expand"); + CoeffCtx::copy(w.begin(), w.end(), e); - //memcpy(e.data(), w.data(), w.size() * sizeof(T)); setTimePoint("ExConv.encode.memcpy"); } diff --git a/libOTe/Tools/ExConvCode/Expander2.h b/libOTe/Tools/ExConvCode/Expander2.h index 2dd6420b..fc83a1fe 100644 --- a/libOTe/Tools/ExConvCode/Expander2.h +++ b/libOTe/Tools/ExConvCode/Expander2.h @@ -64,20 +64,6 @@ namespace osuCrypto u64 generatorRows() const { return mMessageSize; } u64 generatorCols() const { return mCodeSize; } - // compute a eight output. - // the result is written to the dst iterator/ptr. - // - template< - typename CoeffCtx, - typename DstIter, - typename SrcIter - > - void expandEight( - DstIter&& dst, - SrcIter&& ee, - detail::ExpanderModd& prng, - CoeffCtx ctx) const; - template< typename F, @@ -99,40 +85,21 @@ namespace osuCrypto > typename CoeffCtx::template Vec getB(CoeffCtx ctx = {}) const; - }; - template< - typename CoeffCtx, - typename DstIter, - typename SrcIter - > - OC_FORCEINLINE void - ExpanderCode2::expandEight( - DstIter&& dst, - SrcIter&& ee, - detail::ExpanderModd& prng, - CoeffCtx ctx) const - { - u64 rr[8]; - rr[0] = prng.get(); - rr[1] = prng.get(); - rr[2] = prng.get(); - rr[3] = prng.get(); - rr[4] = prng.get(); - rr[5] = prng.get(); - rr[6] = prng.get(); - rr[7] = prng.get(); - - ctx.plus(*(dst + 0), *(dst + 0), *(ee + rr[0])); - ctx.plus(*(dst + 1), *(dst + 1), *(ee + rr[1])); - ctx.plus(*(dst + 2), *(dst + 2), *(ee + rr[2])); - ctx.plus(*(dst + 3), *(dst + 3), *(ee + rr[3])); - ctx.plus(*(dst + 4), *(dst + 4), *(ee + rr[4])); - ctx.plus(*(dst + 5), *(dst + 5), *(ee + rr[5])); - ctx.plus(*(dst + 6), *(dst + 6), *(ee + rr[6])); - ctx.plus(*(dst + 7), *(dst + 7), *(ee + rr[7])); - } + + //template< + // bool Add, + // typename CoeffCtx, + // typename... F, + // typename... SrcDstIterPair + //> + //void expandMany( + // std::tuple out, + // CoeffCtx ctx = {})const; + + }; + template< @@ -147,6 +114,12 @@ namespace osuCrypto DstIter&& output, CoeffCtx ctx) const { + //using P = std::pair; + //expandMany( + // std::tuple

{ P{input, output}}, + // ctx + //); + (void)*(input + (mCodeSize - 1)); (void)*(output + (mMessageSize - 1)); @@ -164,16 +137,30 @@ namespace osuCrypto for (auto j = 0ull; j < mExpanderWeight; ++j) { - // temp[0...7] = expand(input) - expandEight( - output, input, - prng, ctx); + u64 rr[8]; + rr[0] = prng.get(); + rr[1] = prng.get(); + rr[2] = prng.get(); + rr[3] = prng.get(); + rr[4] = prng.get(); + rr[5] = prng.get(); + rr[6] = prng.get(); + rr[7] = prng.get(); + + ctx.plus(*(output + 0), *(output + 0), *(input + rr[0])); + ctx.plus(*(output + 1), *(output + 1), *(input + rr[1])); + ctx.plus(*(output + 2), *(output + 2), *(input + rr[2])); + ctx.plus(*(output + 3), *(output + 3), *(input + rr[3])); + ctx.plus(*(output + 4), *(output + 4), *(input + rr[4])); + ctx.plus(*(output + 5), *(output + 5), *(input + rr[5])); + ctx.plus(*(output + 6), *(output + 6), *(input + rr[6])); + ctx.plus(*(output + 7), *(output + 7), *(input + rr[7])); } } if constexpr (Add == false) { - ctx.zero(output, output + (mMessageSize-i)); + ctx.zero(output, output + (mMessageSize - i)); } for (; i < mMessageSize; ++i, ++output) @@ -185,77 +172,103 @@ namespace osuCrypto } } - - template< - typename F, - typename CoeffCtx - > - inline typename CoeffCtx::template Vec ExpanderCode2::getB(CoeffCtx ctx) const - { - - typename CoeffCtx::template Vec e, x; - ctx.resize(e, mCodeSize); - ctx.resize(x, mMessageSize * mCodeSize); - - for (u64 i = 0; i < e.size(); ++i) - { - // construct the i'th unit vector as input. - ctx.zero(e.begin(), e.end()); - ctx.one(e.begin() + i, e.begin() + i + 1); - - // expand it to geth the i'th row of the matrix - expand(e.begin(), x.begin() + i * mCodeSize); - } - - return x; - } + //template< + // bool Add, + // typename CoeffCtx, + // typename... F, + // typename... SrcDstIterPair + //> + //void ExpanderCode2::expandMany( + // std::tuple inOuts, + // CoeffCtx ctx)const + //{ + + // std::apply([&](auto&&... inOut) {( + // [&] { + // (void)*(std::get<0>(inOut) + (mCodeSize - 1)); + // (void)*(std::get<1>(inOut) + (mMessageSize - 1)); + // }(), ...); }, inOuts); + + + // detail::ExpanderModd prng(mSeed, mCodeSize); + + // auto main = mMessageSize / 8 * 8; + // u64 i = 0; + + // std::vector rr(8 * mExpanderWeight); + + // for (; i < main; i += 8) + // { + // for (auto j = 0ull; j < mExpanderWeight; ++j) + // { + // rr[j * 8 + 0] = prng.get(); + // rr[j * 8 + 1] = prng.get(); + // rr[j * 8 + 2] = prng.get(); + // rr[j * 8 + 3] = prng.get(); + // rr[j * 8 + 4] = prng.get(); + // rr[j * 8 + 5] = prng.get(); + // rr[j * 8 + 6] = prng.get(); + // rr[j * 8 + 7] = prng.get(); + // } + + // std::apply([&](auto&&... inOut) {([&] { + + // auto& input = std::get<0>(inOut); + // auto& output = std::get<1>(inOut); + + // if constexpr (Add == false) + // { + // ctx.zero(output, output + 8); + // } + + // for (auto j = 0ull; j < mExpanderWeight; ++j) + // { + // ctx.plus(*(output + 0), *(output + 0), *(input + rr[j * 8 + 0])); + // ctx.plus(*(output + 1), *(output + 1), *(input + rr[j * 8 + 1])); + // ctx.plus(*(output + 2), *(output + 2), *(input + rr[j * 8 + 2])); + // ctx.plus(*(output + 3), *(output + 3), *(input + rr[j * 8 + 3])); + // ctx.plus(*(output + 4), *(output + 4), *(input + rr[j * 8 + 4])); + // ctx.plus(*(output + 5), *(output + 5), *(input + rr[j * 8 + 5])); + // ctx.plus(*(output + 6), *(output + 6), *(input + rr[j * 8 + 6])); + // ctx.plus(*(output + 7), *(output + 7), *(input + rr[j * 8 + 7])); + // } + + // output += 8; + // }(), ...); }, inOuts); + + // } + + // for (auto j = 0ull; j < mExpanderWeight; ++j) + // { + // rr[j * 8 + 0] = prng.get(); + // rr[j * 8 + 1] = prng.get(); + // rr[j * 8 + 2] = prng.get(); + // rr[j * 8 + 3] = prng.get(); + // rr[j * 8 + 4] = prng.get(); + // rr[j * 8 + 5] = prng.get(); + // rr[j * 8 + 6] = prng.get(); + // rr[j * 8 + 7] = prng.get(); + // } + + // std::apply([&](auto&&... inOut) {([&] { + + // auto& input = std::get<0>(inOut); + // auto& output = std::get<1>(inOut); + // if constexpr (Add == false) + // { + // ctx.zero(output, output + (mMessageSize - i)); + // } + + // for (u64 k = 0; i < mMessageSize; ++i, ++output, ++k) + // { + // for (auto j = 0ull; j < mExpanderWeight; ++j) + // { + // ctx.plus(*output, *output, *(input + rr[j*8 + k])); + // } + // } + // }(), ...); }, inOuts); + //} - // //detail::ExpanderModd prng(mSeed, mCodeSize); - // //PointList points(mMessageSize, mCodeSize); - - // //u64 i = 0; - // //auto main = mMessageSize / 8 * 8; - - // //// for the main phase we process 8 expands in parallel. - // //Matrix rows(8, mExpanderWeight); - // //for (; i < main; i += 8) - // //{ - // // for (auto j = 0ull; j < mExpanderWeight; ++j) - // // { - // // for (u64 k = 0; k < 8; ++k) - // // rows(k, j) = prng.get(); - // // } - - // // for (auto j = 0ull; j < mExpanderWeight; ++j) - // // { - // // for (u64 k = 0; k < 8; ++k) - // // { - // // auto rk = rows[k]; - // // // we could have duplicates that cancel. - // // auto count = std::count(rk.begin(), rk.end(), rk[j]); - // // if (count == 1 || (count > 1 && std::find(rk.begin(), rk.end(), rk[j]) == rk.begin() + j)) - // // points.push_back(i + k, rk[j]); - // // } - // // } - // //} - - // //for (; i < mMessageSize; ++i) - // //{ - // // auto rk = rows[0]; - // // for (auto j = 0ull; j < mExpanderWeight; ++j) - // // rk[j] = prng.get(); - - // // for (auto j = 0ull; j < mExpanderWeight; ++j) - // // { - // // // we could have duplicates that cancel. - // // auto count = std::count(rk.begin(), rk.end(), rk[j]); - // // if (count == 1 || (count > 1 && std::find(rk.begin(), rk.end(), rk[j]) == rk.begin() + j)) - // // points.push_back(i, rk[j]); - // // } - // //} - - // //return points; - //} } diff --git a/libOTe/Tools/Subfield/Subfield.h b/libOTe/Tools/Subfield/Subfield.h index 890f54fe..c5218446 100644 --- a/libOTe/Tools/Subfield/Subfield.h +++ b/libOTe/Tools/Subfield/Subfield.h @@ -2,6 +2,7 @@ #include "libOTe/Vole/Noisy/NoisyVoleSender.h" #include "cryptoTools/Common/BitIterator.h" #include "cryptoTools/Common/BitVector.h" +#include namespace osuCrypto { @@ -10,7 +11,6 @@ namespace osuCrypto { */ struct CoeffCtxInteger { - template static OC_FORCEINLINE void plus(R&& ret, F1&& lhs, F2&& rhs) { ret = lhs + rhs; @@ -40,23 +40,31 @@ namespace osuCrypto { return sizeof(F) * 8; } - + // return the binary decomposition of x. This will be used to + // reconstruct x as + // + // x = sum_{i = 0,...,n} 2^i * binaryDecomposition(x)[i] + // template static OC_FORCEINLINE BitVector binaryDecomposition(F& x) { static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); return { (u8*)&x, sizeof(F) * 8 }; } + // sample an F using the randomness b. template static OC_FORCEINLINE void fromBlock(F& ret, const block& b) { static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); if constexpr (sizeof(F) <= sizeof(block)) { + // if small, just return the bytes of b memcpy(&ret, &b, sizeof(F)); } else { + // if large, we need to expand the seed. using fixed key AES in counter mode + // with b as the IV. auto constexpr size = (sizeof(F) + sizeof(block) - 1) / sizeof(block); std::array buffer; mAesFixedKey.ecbEncCounterMode(b, buffer); @@ -64,20 +72,48 @@ namespace osuCrypto { } } + // return the F element with value 2^power template - static OC_FORCEINLINE void pow(F& ret, u64 power) { + static OC_FORCEINLINE void powerOfTwo(F& ret, u64 power) { static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); memset(&ret, 0, sizeof(F)); *BitIterator((u8*)&ret, power) = 1; } + // A vector like type that can be used to store + // temporaries. + // + // must have: + // * size() + // * operator[i] that returns the i'th F element reference. + // * begin() iterator over the F elements + // * end() iterator over the F elements + template + using Vec = AlignedUnVector; + // resize Vec + template + static void resize(Vec& f, u64 size) + { + f.resize(size); + } + + // the size of F when serialized. + template + static u64 byteSize() + { + return sizeof(F); + } + + // copy a single F element. template static OC_FORCEINLINE void copy(F& dst, const F& src) { dst = src; } + // copy [begin,...,end) into [dstBegin, ...) + // the iterators will point to the same types, i.e. F template static OC_FORCEINLINE void copy( SrcIter begin, @@ -92,48 +128,88 @@ namespace osuCrypto { std::copy(begin, end, dstBegin); } - // must have - // .size() - // operator[] that returns the element. - // begin() iterator - // end() iterator - template - using Vec = AlignedUnVector; - - // the size of F when serialized. - template - static u64 byteSize() - { - return sizeof(F); - } - - - // deserialize buff into dst - template - static void deserialize(Vec& dst, span buff) + // deserialize [begin,...,end) into [dstBegin, ...) + // begin will be a byte pointer/iterator. + // dstBegin will be an F pointer/iterator + template + static void deserialize(SrcIter&& begin, SrcIter&& end, DstIter&& dstBegin) { - if (dst.size() * sizeof(F) != buff.size()) + // as written this function is a bit more general than strictly neccessary + // due to serialize(...) redirecting here. + using SrcType = std::remove_reference_t; + using DstType = std::remove_reference_t; + static_assert(std::is_trivially_copyable::value, "source serialization types must be trivially_copyable."); + static_assert(std::is_trivially_copyable::value, "destination serialization types must be trivially_copyable."); + +#if __cplusplus >= 202002L + //std::contiguous_iterator<> + // static_assert contigous iter in cpp20 +#endif + + + // how many source elem do we have? + auto srcN = std::distance(begin, end); + if (srcN) { - std::cout << "bad buffer size " << LOCATION << std::endl; - std::terminate(); + // the source size in bytes + auto n = srcN * sizeof(SrcType); + + // The byte size must be a multiple of the destination element byte size. + if (n % sizeof(DstType)) + { + std::cout << "bad buffer size. the source buffer (byte) size is not a multiple of the distination value type size." LOCATION << std::endl; + std::terminate(); + } + // the number of destination elements. + auto dstN = n / sizeof(DstType); + + // make sure the pointer math work with this iterator type. + auto beginU8 = (u8*)&*begin; + auto dstBeginU8 = (u8*)&*dstBegin; + + auto dstBackPtr = dstBeginU8 + (n -sizeof(DstType)); + auto dstBackIter = dstBegin + (dstN -1); + + // try to deref the back. might bounds check. + // And check that the pointer math works + if (dstBackPtr != (u8*)&*dstBackIter) + { + std::cout << "bad destination iterator type. pointer arithemtic not correct. " LOCATION << std::endl; + std::terminate(); + } + + auto srcBackPtr = beginU8 + (n - sizeof(SrcType)); + auto srcBackIter = begin + (srcN - 1); + + // try to deref the back. might bounds check. + // And check that the pointer math works + if (srcBackPtr != (u8*)&*srcBackIter) + { + std::cout << "bad source iterator type. pointer arithemtic not correct. " LOCATION << std::endl; + std::terminate(); + } + + // memcpy the bytes + std::memcpy(dstBeginU8, beginU8, n); } - static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); - memcpy(dst.data(), buff.data(), buff.size()); } - // serial buff into dst - template - static void serialize(span dst, Vec& buff) + // serialize [begin,...,end) into [dstBegin, ...) + // begin will be an F pointer/iterator + // dstBegin will be a byte pointer/iterator. + template + static void serialize(SrcIter&& begin, SrcIter&& end, DstIter&& dstBegin) { - if (buff.size() * sizeof(F) != dst.size()) - { - std::cout << "bad buffer size " << LOCATION << std::endl; - std::terminate(); - } - static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); - memcpy(dst.data(), buff.data(), dst.size()); + // for primitive types serialization and deserializaion + // are the same, a memcpy. + deserialize(begin, end, dstBegin); } + + + + // fill the range [begin,..., end) with zeros. + // begin will be an F pointer/iterator. template static void zero(Iter begin, Iter end) { @@ -148,18 +224,28 @@ namespace osuCrypto { } } + // fill the range [begin,..., end) with ones. + // begin will be an F pointer/iterator. template static void one(Iter begin, Iter end) { std::fill(begin, end, 1); } - // resize Vec - template - static void resize(FVec&& f, u64 size) + + // convert F into a string + template + static std::string str(F&& f) { - f.resize(size); + std::stringstream ss; + if constexpr (std::is_same_v, u8>) + ss << int(f); + else + ss << f; + + return ss.str(); } + }; // CoeffCtx for GF fields. @@ -232,6 +318,38 @@ namespace osuCrypto { { return lhs == rhs; } + + // convert F into a string + static std::string str(const F& f) + { + auto delim = "{ "; + std::stringstream ss; + for (u64 i = 0; i < f.size(); ++i) + { + ss << std::exchange(delim, ", "); + + if constexpr (std::is_same_v, u8>) + ss << int(f[i]); + else + ss << f[i]; + + } + ss << " }"; + + return ss.str(); + } + + // convert G into a string + static std::string str(const G& g) + { + std::stringstream ss; + if constexpr (std::is_same_v, u8>) + ss << int(g); + else + ss << g; + + return ss.str(); + } }; template diff --git a/libOTe/Tools/Subfield/SubfieldPprf.h b/libOTe/Tools/Subfield/SubfieldPprf.h index 4166bcbe..fdf8f0a1 100644 --- a/libOTe/Tools/Subfield/SubfieldPprf.h +++ b/libOTe/Tools/Subfield/SubfieldPprf.h @@ -116,7 +116,7 @@ namespace osuCrypto template void allocateExpandBuffer( u64 depth, - u64 programPuncturedPoint, + bool programPuncturedPoint, std::vector& buff, span, 2>>& sums, span& leaf) @@ -125,25 +125,24 @@ namespace osuCrypto u64 elementSize = CoeffCtx::byteSize(); using SumType = std::array, 2>; - - // the number of internal levels. We process 8 trees at a time - u64 numSums = (depth - programPuncturedPoint) * 8; - - // the number of leaf level that we will program - u64 numleaf = programPuncturedPoint * 8; - // num of bytes they will take up. - u64 numBytes = numSums * 2 * sizeof(block) + numleaf * 4 * elementSize; + u64 numBytes = + depth * sizeof(SumType) + // each internal level of the tree has a sum + elementSize * 8 * 2 + // we must program 8 inactive F leaves + elementSize * 8 * 2 * programPuncturedPoint; // if we are programing the active lead, then we have 8 more. // allocate the buffer and partition them. buff.resize(numBytes); - sums = span((SumType*)buff.data(), numSums); - leaf = span((u8*)(sums.data() + sums.size()), numleaf * 4 * elementSize); + sums = span((SumType*)buff.data(), depth); + leaf = span((u8*)(sums.data() + sums.size()), + elementSize * 8 * 2 + + elementSize * 8 * 2 * programPuncturedPoint + ); void* sEnd = sums.data() + sums.size(); void* lEnd = leaf.data() + leaf.size(); void* end = buff.data() + buff.size(); - if (sEnd > end || lEnd > end) + if (sEnd > end || lEnd != end) throw RTE_LOC; } @@ -201,7 +200,13 @@ namespace osuCrypto configure(domainSize, pointCount); } - void configure(u64 domainSize, u64 pointCount) { + void configure(u64 domainSize, u64 pointCount) + { + if (domainSize & 1) + throw std::runtime_error("Pprf domain must be even. " LOCATION); + if (mPntCount % 8) + throw std::runtime_error("pointCount must be a multiple of 8 (general case not impl). " LOCATION); + mDomain = domainSize; mDepth = log2ceil(mDomain); mPntCount = pointCount; @@ -236,7 +241,6 @@ namespace osuCrypto // MatrixView o(output.data(), output.size(), 1); // return expand(chls, value, seed, o, oFormat, programPuncturedPoint, numThreads); //} - task<> expand( Socket& chl, const VecF& value, @@ -244,7 +248,8 @@ namespace osuCrypto VecF& output, PprfOutputFormat oFormat, bool programPuncturedPoint, - u64 numThreads) { + u64 numThreads) + { if (programPuncturedPoint) setValue(value); @@ -252,7 +257,7 @@ namespace osuCrypto validateExpandFormat(oFormat, output, mDomain, mPntCount); - MC_BEGIN(task<>, this, numThreads, oFormat, output, seed, &chl, programPuncturedPoint, + MC_BEGIN(task<>, this, numThreads, oFormat, &output, seed, &chl, programPuncturedPoint, treeIndex = u64{}, tree = span>{}, levels = std::vector> >{}, @@ -262,7 +267,6 @@ namespace osuCrypto buff = std::vector{}, encSums = span, 2>>{}, leafMsgs = span{} - ); mTreeAlloc.reserve(numThreads, (1ull << mDepth) + 2); @@ -290,7 +294,7 @@ namespace osuCrypto } // allocate the send buffer and partition it. - allocateExpandBuffer(mDepth, programPuncturedPoint, buff, encSums, leafMsgs); + allocateExpandBuffer(mDepth - 1, programPuncturedPoint, buff, encSums, leafMsgs); // exapnd the tree expandOne(seed, treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs); @@ -341,7 +345,7 @@ namespace osuCrypto bool programPuncturedPoint, span>> levels, VecF& leafLevel, - u64 leafOffset, + const u64 leafOffset, span, 2>> encSums, span leafMsgs) { @@ -443,8 +447,8 @@ namespace osuCrypto // encrypt the sums and write them to the output. for (u64 j = 0; j < 8; ++j) { - encSums[d][0][j] = sums[0][j] ^ mBaseOTs[treeIdx + j][d][0]; - encSums[d][1][j] = sums[1][j] ^ mBaseOTs[treeIdx + j][d][1]; + encSums[d][0][j] = sums[0][j] ^ mBaseOTs[treeIdx + j][mDepth - 1 - d][1]; + encSums[d][1][j] = sums[1][j] ^ mBaseOTs[treeIdx + j][mDepth - 1 - d][0]; } } @@ -468,14 +472,14 @@ namespace osuCrypto CoeffCtx::zero(leafSums[1].begin(), leafSums[1].end()); // for the leaf nodes we need to hash both children. - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) + for (u64 parentIdx = 0, outIdx = leafOffset, childIdx = 0; parentIdx < width; ++parentIdx) { // The value of the parent. auto& parent = level0.data()[parentIdx]; // The bit that indicates if we are on the left child (0) // or on the right child (1). - for (u64 keep = 0; keep < 2; ++keep, ++childIdx, leafOffset += 8) + for (u64 keep = 0; keep < 2; ++keep, ++childIdx, outIdx += 8) { // The child that we will write in this iteration. @@ -487,28 +491,28 @@ namespace osuCrypto // where each half defines one of the children. gGgmAes[keep].hashBlocks<8>(parent.data(), child.data()); - CoeffCtx::fromBlock(leafLevel[leafOffset + 0], child[0]); - CoeffCtx::fromBlock(leafLevel[leafOffset + 1], child[1]); - CoeffCtx::fromBlock(leafLevel[leafOffset + 2], child[2]); - CoeffCtx::fromBlock(leafLevel[leafOffset + 3], child[3]); - CoeffCtx::fromBlock(leafLevel[leafOffset + 4], child[4]); - CoeffCtx::fromBlock(leafLevel[leafOffset + 5], child[5]); - CoeffCtx::fromBlock(leafLevel[leafOffset + 6], child[6]); - CoeffCtx::fromBlock(leafLevel[leafOffset + 7], child[7]); + CoeffCtx::fromBlock(leafLevel[outIdx + 0], child[0]); + CoeffCtx::fromBlock(leafLevel[outIdx + 1], child[1]); + CoeffCtx::fromBlock(leafLevel[outIdx + 2], child[2]); + CoeffCtx::fromBlock(leafLevel[outIdx + 3], child[3]); + CoeffCtx::fromBlock(leafLevel[outIdx + 4], child[4]); + CoeffCtx::fromBlock(leafLevel[outIdx + 5], child[5]); + CoeffCtx::fromBlock(leafLevel[outIdx + 6], child[6]); + CoeffCtx::fromBlock(leafLevel[outIdx + 7], child[7]); // leafSum += child auto& leafSum = leafSums[keep]; - CoeffCtx::plus(leafSum[0], leafSum[0], leafLevel[leafOffset + 0]); - CoeffCtx::plus(leafSum[1], leafSum[1], leafLevel[leafOffset + 1]); - CoeffCtx::plus(leafSum[2], leafSum[2], leafLevel[leafOffset + 2]); - CoeffCtx::plus(leafSum[3], leafSum[3], leafLevel[leafOffset + 3]); - CoeffCtx::plus(leafSum[4], leafSum[4], leafLevel[leafOffset + 4]); - CoeffCtx::plus(leafSum[5], leafSum[5], leafLevel[leafOffset + 5]); - CoeffCtx::plus(leafSum[6], leafSum[6], leafLevel[leafOffset + 6]); - CoeffCtx::plus(leafSum[7], leafSum[7], leafLevel[leafOffset + 7]); + CoeffCtx::plus(leafSum[0], leafSum[0], leafLevel[outIdx + 0]); + CoeffCtx::plus(leafSum[1], leafSum[1], leafLevel[outIdx + 1]); + CoeffCtx::plus(leafSum[2], leafSum[2], leafLevel[outIdx + 2]); + CoeffCtx::plus(leafSum[3], leafSum[3], leafLevel[outIdx + 3]); + CoeffCtx::plus(leafSum[4], leafSum[4], leafLevel[outIdx + 4]); + CoeffCtx::plus(leafSum[5], leafSum[5], leafLevel[outIdx + 5]); + CoeffCtx::plus(leafSum[6], leafSum[6], leafLevel[outIdx + 6]); + CoeffCtx::plus(leafSum[7], leafSum[7], leafLevel[outIdx + 7]); } - } + } if (programPuncturedPoint) { @@ -550,10 +554,10 @@ namespace osuCrypto // copy m0 into the output buffer. span buff = leafMsgs.subspan(0, 2 * CoeffCtx::byteSize()); leafMsgs = leafMsgs.subspan(buff.size()); - CoeffCtx::serialize(buff, leafOts); + CoeffCtx::serialize(leafOts.begin(), leafOts.end(), buff.begin()); // encrypt the output buffer. - otMasker.SetSeed(mBaseOTs[treeIdx + j][d][k], divCeil(buff.size(), sizeof(block))); + otMasker.SetSeed(mBaseOTs[treeIdx + j][0][1 ^ k], divCeil(buff.size(), sizeof(block))); for (u64 i = 0; i < buff.size(); ++i) buff[i] ^= otMasker.get(); @@ -574,10 +578,10 @@ namespace osuCrypto CoeffCtx::copy(leafOts[0], leafSums[k][j]); span buff = leafMsgs.subspan(0, CoeffCtx::byteSize()); leafMsgs = leafMsgs.subspan(buff.size()); - CoeffCtx::serialize(buff, leafOts); + CoeffCtx::serialize(leafOts.begin(), leafOts.end(), buff.begin()); // encrypt the output buffer. - otMasker.SetSeed(mBaseOTs[treeIdx + j][d][k], divCeil(buff.size(), sizeof(block))); + otMasker.SetSeed(mBaseOTs[treeIdx + j][0][1 ^ k], divCeil(buff.size(), sizeof(block))); for (u64 i = 0; i < buff.size(); ++i) buff[i] ^= otMasker.get(); @@ -620,6 +624,8 @@ namespace osuCrypto void configure(u64 domainSize, u64 pointCount) { + if (domainSize & 1) + throw std::runtime_error("Pprf domain must be even. " LOCATION); mDomain = domainSize; mDepth = log2ceil(mDomain); mPntCount = pointCount; @@ -628,62 +634,20 @@ namespace osuCrypto } - // For output format ByLeafIndex or ByTreeIndex, the choice bits it - // samples are in blocks of mDepth, with mPntCount blocks total (one for - // each punctured point). For ByLeafIndex these blocks encode the punctured - // leaf index in big endian, while for ByTreeIndex they are in - // little endian. + // this function sample mPntCount integers in the range + // [0,domain) and returns these as the choice bits. BitVector sampleChoiceBits(u64 modulus, PprfOutputFormat format, PRNG& prng) { BitVector choices(mPntCount * mDepth); // The points are read in blocks of 8, so make sure that there is a // whole number of blocks. - mBaseChoices.resize(roundUpTo(mPntCount, 8), mDepth); + mBaseChoices.resize(mPntCount, mDepth); for (u64 i = 0; i < mPntCount; ++i) { - u64 idx; - switch (format) - { - case osuCrypto::PprfOutputFormat::ByLeafIndex: - case osuCrypto::PprfOutputFormat::ByTreeIndex: - do { - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = prng.getBit(); - idx = getActivePath(mBaseChoices[i]); - } while (idx >= modulus); - - break; - case osuCrypto::PprfOutputFormat::Interleaved: - case osuCrypto::PprfOutputFormat::Callback: - - if (modulus > mPntCount * mDomain) - throw std::runtime_error("modulus too big. " LOCATION); - if (modulus < mPntCount * mDomain / 2) - throw std::runtime_error("modulus too small. " LOCATION); - - // make sure that at least the first element of this tree - // is within the modulus. - idx = interleavedPoint(0, i, mPntCount, mDomain, format); - if (idx >= modulus) - throw RTE_LOC; - - - do { - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = prng.getBit(); - idx = getActivePath(mBaseChoices[i]); - - idx = interleavedPoint(idx, i, mPntCount, mDomain, format); - } while (idx >= modulus); - - - break; - default: - throw RTE_LOC; - break; - } - + u64 idx = prng.get() % mDomain; + for (u64 j = 0; j < mDepth; ++j) + mBaseChoices(i, j) = *BitIterator((u8*)&idx, j); } for (u64 i = 0; i < mBaseChoices.size(); ++i) @@ -701,33 +665,18 @@ namespace osuCrypto if (choices.size() != baseOtCount()) throw RTE_LOC; - mBaseChoices.resize(roundUpTo(mPntCount, 8), mDepth); + mBaseChoices.resize(mPntCount, mDepth); for (u64 i = 0; i < mPntCount; ++i) { + u64 idx = 0; for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = choices[mDepth * i + j]; - - switch (format) - { - case osuCrypto::PprfOutputFormat::ByLeafIndex: - case osuCrypto::PprfOutputFormat::ByTreeIndex: - if (getActivePath(mBaseChoices[i]) >= mDomain) - throw RTE_LOC; - - break; - case osuCrypto::PprfOutputFormat::Interleaved: - case osuCrypto::PprfOutputFormat::Callback: { - auto idx = getActivePath(mBaseChoices[i]); - auto idx2 = interleavedPoint(idx, i, mPntCount, mDomain, format); - if (idx2 > mPntCount * mDomain) - throw std::runtime_error("the base ot choice bits index outside of the domain. see sampleChoiceBits(...). " LOCATION); - break; - } - default: - throw RTE_LOC; - break; + mBaseChoices(i, j) = choices[mDepth * i + j]; + idx |= u64(choices[mDepth * i + j]) << j; } + + if (idx >= mDomain) + throw std::runtime_error("provided choice bits index outside of the domain." LOCATION); } } @@ -764,6 +713,9 @@ namespace osuCrypto } void getPoints(span points, PprfOutputFormat format) { + if ((u64)points.size() != mPntCount) + throw RTE_LOC; + switch (format) { case PprfOutputFormat::ByLeafIndex: @@ -772,20 +724,29 @@ namespace osuCrypto memset(points.data(), 0, points.size() * sizeof(u64)); for (u64 j = 0; j < mPntCount; ++j) { - points[j] = getActivePath(mBaseChoices[j]); + for (u64 k = 0; k < mDepth; ++k) + points[j] |= u64(mBaseChoices(j, k)) << k; + + assert(points[j] < mDomain); } + break; case PprfOutputFormat::Interleaved: case PprfOutputFormat::Callback: - if ((u64)points.size() != mPntCount) - throw RTE_LOC; - if (points.size() % 8) - throw RTE_LOC; - getPoints(points, PprfOutputFormat::ByLeafIndex); - interleavedPoints(points, mDomain, format); + + // in interleaved mode we generate 8 trees in a batch. + // the i'th leaf of these 8 trees are next to eachother. + for (u64 j = 0; j < points.size(); ++j) + { + auto subTree = j % 8; + auto batch = j / 8; + points[j] = (batch * mDomain + points[j]) * 8 + subTree; + } + + //interleavedPoints(points, mDomain, format); break; default: @@ -797,11 +758,16 @@ namespace osuCrypto // programPuncturedPoint says whether the sender is trying to program the // active child to be its correct value XOR delta. If it is not, the // active child will just take a random value. - task<> expand(Socket& chl, VecF& output, PprfOutputFormat oFormat, bool programPuncturedPoint, u64 numThreads) + task<> expand( + Socket& chl, + VecF& output, + PprfOutputFormat oFormat, + bool programPuncturedPoint, + u64 numThreads) { validateExpandFormat(oFormat, output, mDomain, mPntCount); - MC_BEGIN(task<>, this, oFormat, output, &chl, programPuncturedPoint, + MC_BEGIN(task<>, this, oFormat, &output, &chl, programPuncturedPoint, treeIndex = u64{}, tree = span>{}, levels = std::vector>>{}, @@ -842,7 +808,7 @@ namespace osuCrypto } // allocate the send buffer and partition it. - allocateExpandBuffer(mDepth, programPuncturedPoint, buff, encSums, leafMsgs); + allocateExpandBuffer(mDepth - 1, programPuncturedPoint, buff, encSums, leafMsgs); MC_AWAIT(chl.recv(buff)); @@ -880,7 +846,7 @@ namespace osuCrypto u64 treeIdx, bool programPuncturedPoint, span>> levels, - VecF leafLevel, + VecF& leafLevel, const u64 outputOffset, span, 2>> theirSums, span leafMsg) @@ -893,9 +859,16 @@ namespace osuCrypto { // For the non-active path, set the child of the root node // as the OT message XOR'ed with the correction sum. - int notAi = mBaseChoices[i + treeIdx][0]; - l1[notAi][i] = mBaseOTs[i + treeIdx][0] ^ theirSums[0][notAi][i]; - l1[notAi ^ 1][i] = ZeroBlock; + + int active = mBaseChoices[i + treeIdx].back(); + l1[active ^ 1][i] = mBaseOTs[i + treeIdx].back() ^ theirSums[0][active ^ 1][i]; + l1[active][i] = ZeroBlock; + //if (!i) + // std::cout << " unmask " + // << mBaseOTs[i + treeIdx].back() << " ^ " + // << theirSums[0][active ^ 1][i] << " = " + // << l1[active ^ 1][i] << std::endl; + } // space for our sums of each level. @@ -930,7 +903,7 @@ namespace osuCrypto // The next level that we want to construct. auto level1 = levels[d + 1]; - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) { // The value of the parent. auto parent = level0[parentIdx]; @@ -995,6 +968,7 @@ namespace osuCrypto mySums[1][5] = mySums[1][5] ^ child1[5]; mySums[1][6] = mySums[1][6] ^ child1[6]; mySums[1][7] = mySums[1][7] ^ child1[7]; + } @@ -1004,25 +978,25 @@ namespace osuCrypto // the index of the leaf node that is active. auto leafIdx = mPoints[i + treeIdx]; - // The index of the active child node. - auto activeChildIdx = leafIdx >> (mDepth - 1 - d); + // The index of the active (missing) child node. + auto missingChildIdx = leafIdx >> (mDepth - 1 - d); // The index of the active child node sibling. - auto inactiveChildIdx = activeChildIdx ^ 1; + auto siblingIdx = missingChildIdx ^ 1; // The indicator as to the left or right child is inactive - auto notAi = inactiveChildIdx & 1; + auto notAi = siblingIdx & 1; // our sums & OTs cancel and we are leaf with the // correct value for the inactive child. - level1[inactiveChildIdx][i] = + level1[siblingIdx][i] = theirSums[d][notAi][i] ^ mySums[notAi][i] ^ - mBaseOTs[i + treeIdx][d]; + mBaseOTs[i + treeIdx][mDepth - 1 - d]; // we have to set the active child to zero so // the next children are predictable. - level1[activeChildIdx][i] = ZeroBlock; + level1[missingChildIdx][i] = ZeroBlock; } } @@ -1052,21 +1026,23 @@ namespace osuCrypto inactiveChildValues[k] = gGgmAes[k].hashBlock(ZeroBlock); CoeffCtx::fromBlock(temp[k], inactiveChildValues[k]); + + // leafSum = -inactiveChildValues CoeffCtx::resize(leafSums[k], 8); CoeffCtx::zero(leafSums[k].begin(), leafSums[k].end()); - CoeffCtx::minus(leafSums[k][0], leafSums[k][0], temp[0]); + CoeffCtx::minus(leafSums[k][0], leafSums[k][0], temp[k]); for (u64 i = 1; i < 8; ++i) CoeffCtx::copy(leafSums[k][i], leafSums[k][0]); } // for leaf nodes both children should be hashed. - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) + for (u64 parentIdx = 0, childIdx = 0, outputIdx = outputOffset; parentIdx < width; ++parentIdx) { // The value of the parent. auto parent = level0[parentIdx]; - for (u64 keep = 0, outputIdx = outputOffset; keep < 2; ++keep, ++childIdx, outputIdx += 8) + for (u64 keep = 0; keep < 2; ++keep, ++childIdx, outputIdx += 8) { // Each parent is expanded into the left and right children // using a different AES fixed-key. Therefore our OWF is: @@ -1085,6 +1061,8 @@ namespace osuCrypto CoeffCtx::fromBlock(leafLevel[outputIdx + 6], child[6]); CoeffCtx::fromBlock(leafLevel[outputIdx + 7], child[7]); + + auto& leafSum = leafSums[keep]; CoeffCtx::plus(leafSum[0], leafSum[0], leafLevel[outputIdx + 0]); CoeffCtx::plus(leafSum[1], leafSum[1], leafLevel[outputIdx + 1]); @@ -1112,6 +1090,7 @@ namespace osuCrypto for (u64 j = 0; j < 8; ++j) { + // The index of the child on the active path. auto activeChildIdx = mPoints[j + treeIdx]; @@ -1124,14 +1103,15 @@ namespace osuCrypto // offset to the first or second ot message, based on the one we want auto offset = CoeffCtx::template byteSize() * 2 * notAi; + // decrypt the ot string span buff = leafMsg.subspan(offset, CoeffCtx::byteSize() * 2); leafMsg = leafMsg.subspan(buff.size() * 2); - otMasker.SetSeed(mBaseOTs[j + treeIdx][d], divCeil(buff.size(), sizeof(block))); + otMasker.SetSeed(mBaseOTs[j + treeIdx][0], divCeil(buff.size(), sizeof(block))); for (u64 i = 0; i < buff.size(); ++i) buff[i] ^= otMasker.get(); - CoeffCtx::deserialize(leafOts, buff); + CoeffCtx::deserialize(buff.begin(), buff.end(), leafOts.begin()); auto out0 = (activeChildIdx & ~1ull) * 8 + j + outputOffset; auto out1 = (activeChildIdx | 1ull) * 8 + j + outputOffset; @@ -1163,11 +1143,11 @@ namespace osuCrypto // decrypt the ot string span buff = leafMsg.subspan(offset, CoeffCtx::byteSize()); leafMsg = leafMsg.subspan(buff.size() * 2); - otMasker.SetSeed(mBaseOTs[j + treeIdx][d], divCeil(buff.size(), sizeof(block))); + otMasker.SetSeed(mBaseOTs[j + treeIdx][0], divCeil(buff.size(), sizeof(block))); for (u64 i = 0; i < buff.size(); ++i) buff[i] ^= otMasker.get(); - CoeffCtx::deserialize(leafOts, buff); + CoeffCtx::deserialize(buff.begin(), buff.end(), leafOts.begin()); std::array out{ (activeChildIdx & ~1ull) * 8 + j + outputOffset, @@ -1175,7 +1155,7 @@ namespace osuCrypto }; auto keep = leafLevel.begin() + out[notAi]; - auto zero = leafLevel.begin() + out[notAi^1]; + auto zero = leafLevel.begin() + out[notAi ^ 1]; CoeffCtx::minus(*keep, leafOts[0], leafSums[notAi][j]); CoeffCtx::zero(zero, zero + 1); diff --git a/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp b/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp index 14a93af5..85ed17c6 100644 --- a/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp +++ b/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp @@ -16,6 +16,9 @@ namespace osuCrypto if ((u64)messages.data() % 32) throw std::runtime_error("soft spoken requires the messages to by 32 byte aligned. Consider using AlignedUnVector or AlignedVector." LOCATION); + if (messages.size() == 0) + throw std::runtime_error("soft spoken must be called with at least 1 messag." LOCATION); + MC_BEGIN(task<>, this, messages, &prng, &chl, nChunks = u64{}, messagesFullChunks = u64{}, @@ -26,6 +29,7 @@ namespace osuCrypto mHasher = Hasher{} ); + if (!hasBaseOts()) MC_AWAIT(genBaseOts(prng, chl)); diff --git a/libOTe/Vole/Subfield/NoisyVoleReceiver.h b/libOTe/Vole/Subfield/NoisyVoleReceiver.h index 4d739bdf..5587cac3 100644 --- a/libOTe/Vole/Subfield/NoisyVoleReceiver.h +++ b/libOTe/Vole/Subfield/NoisyVoleReceiver.h @@ -39,85 +39,102 @@ namespace osuCrypto { { public: + // for chosen c, compute a such htat + // + // a = b + c * delta + // template - task<> receive(VecG&& y, VecF&& z, PRNG& prng, + task<> receive(VecG& c, VecF& a, PRNG& prng, OtSender& ot, Socket& chl) { - MC_BEGIN(task<>, this, y, z, &prng, &ot, &chl, + MC_BEGIN(task<>, this, &c, &a, &prng, &ot, &chl, otMsg = AlignedUnVector>{}); setTimePoint("NoisyVoleReceiver.ot.begin"); - + otMsg.resize(CoeffCtx::bitSize()); MC_AWAIT(ot.send(otMsg, prng, chl)); setTimePoint("NoisyVoleReceiver.ot.end"); - MC_AWAIT(receive(y, z, prng, otMsg, chl)); + MC_AWAIT(receive(c, a, prng, otMsg, chl)); MC_END(); } + // for chosen c, compute a such htat + // + // a = b + c * delta + // template - task<> receive(VecG&& y, VecF&& z, PRNG& _, + task<> receive(VecG& c, VecF& a, PRNG& _, span> otMsg, Socket& chl) { - MC_BEGIN(task<>, this, y, z, otMsg, &chl, + MC_BEGIN(task<>, this, &c, &a, otMsg, &chl, buff = std::vector{}, msg = typename CoeffCtx::Vec{}, temp = typename CoeffCtx::Vec{}, prng = std::move(PRNG{}) ); - if (y.size() != z.size()) + if (c.size() != a.size()) throw RTE_LOC; - if (z.size() == 0) + if (a.size() == 0) throw RTE_LOC; setTimePoint("NoisyVoleReceiver.begin"); - CoeffCtx::zero(z.begin(), z.end()); - CoeffCtx::resize(msg, otMsg.size() * z.size()); + CoeffCtx::zero(a.begin(), a.end()); + CoeffCtx::resize(msg, otMsg.size() * a.size()); CoeffCtx::resize(temp, 2); for (size_t i = 0, k = 0; i < otMsg.size(); ++i) { - prng.SetSeed(otMsg[i][0], z.size()); + prng.SetSeed(otMsg[i][0], a.size()); // t1 = 2^i - CoeffCtx::pow(temp[1], i); + CoeffCtx::powerOfTwo(temp[1], i); + //std::cout << "2^i " << CoeffCtx::str(temp[1]) << "\n"; - for (size_t j = 0; j < y.size(); ++j, ++k) + for (size_t j = 0; j < c.size(); ++j, ++k) { // msg[i,j] = otMsg[i,j,0] CoeffCtx::fromBlock(msg[k], prng.get()); + //CoeffCtx::zero(msg.begin() + k, msg.begin() + k + 1); + //std::cout << "m" << i << ",0 = " << CoeffCtx::str(msg[k]) << std::endl; - // z[j] -= otMsg[i,j,0] - CoeffCtx::minus(z[j], z[j], msg[k]); + // a[j] += otMsg[i,j,0] + CoeffCtx::plus(a[j], a[j], msg[k]); + //std::cout << "z = " << CoeffCtx::str(a[j]) << std::endl; - // temp = 2^i * y[j] - CoeffCtx::mul(temp[0], temp[1], y[j]); + // temp = 2^i * c[j] + CoeffCtx::mul(temp[0], temp[1], c[j]); + //std::cout << "2^i y = " << CoeffCtx::str(temp[0]) << std::endl; - // msg[i,j] = otMsg[i,j,0] + 2^i * y[j] - CoeffCtx::plus(msg[k], msg[k], temp[0]); + // msg[i,j] = otMsg[i,j,0] + 2^i * c[j] + CoeffCtx::minus(msg[k], msg[k], temp[0]); + //std::cout << "m" << i << ",0 + 2^i y = " << CoeffCtx::str(msg[k]) << std::endl; } - k -= y.size(); - prng.SetSeed(otMsg[i][1], z.size()); + k -= c.size(); + prng.SetSeed(otMsg[i][1], a.size()); - for (size_t j = 0; j < y.size(); ++j, ++k) + for (size_t j = 0; j < c.size(); ++j, ++k) { // temp = otMsg[i,j,1] CoeffCtx::fromBlock(temp[0], prng.get()); + //CoeffCtx::zero(temp.begin(), temp.begin() + 1); + //std::cout << "m" << i << ",1 = " << CoeffCtx::str(temp[0]) << std::endl; // enc one message under the OT msg. - // msg[i,j] = (otMsg[i,j,0] + 2^i * y[j]) - otMsg[i,j,1] + // msg[i,j] = (otMsg[i,j,0] + 2^i * c[j]) - otMsg[i,j,1] CoeffCtx::minus(msg[k], msg[k], temp[0]); + //std::cout << "m" << i << ",0 + 2^i y - m" << i << ",1 = " << CoeffCtx::str(msg[k]) << std::endl << std::endl; } } buff.resize(msg.size() * CoeffCtx::byteSize()); - CoeffCtx::serialize(buff, msg); + CoeffCtx::serialize(msg.begin(), msg.end(), buff.begin()); MC_AWAIT(chl.send(std::move(buff))); setTimePoint("NoisyVoleReceiver.done"); diff --git a/libOTe/Vole/Subfield/NoisyVoleSender.h b/libOTe/Vole/Subfield/NoisyVoleSender.h index 81aede7e..00a75ed9 100644 --- a/libOTe/Vole/Subfield/NoisyVoleSender.h +++ b/libOTe/Vole/Subfield/NoisyVoleSender.h @@ -46,11 +46,15 @@ namespace osuCrypto { public: + // for chosen delta, compute b such htat + // + // a = b + c * delta + // template - task<> send(F x, FVec&& z, PRNG& prng, + task<> send(F delta, FVec& b, PRNG& prng, OtReceiver& ot, Socket& chl) { - MC_BEGIN(task<>, this, x, z, &prng, &ot, &chl, - bv = CoeffCtx::binaryDecomposition(x), + MC_BEGIN(task<>, this, delta, &b, &prng, &ot, &chl, + bv = CoeffCtx::binaryDecomposition(delta), otMsg = AlignedUnVector{ }); otMsg.resize(bv.size()); @@ -59,34 +63,39 @@ namespace osuCrypto { MC_AWAIT(ot.receive(bv, otMsg, prng, chl)); setTimePoint("NoisyVoleSender.ot.end"); - MC_AWAIT(send(x, z, prng, otMsg, chl)); + MC_AWAIT(send(delta, b, prng, otMsg, chl)); MC_END(); } + // for chosen delta, compute b such htat + // + // a = b + c * delta + // template - task<> send(F x, FVec&& z, PRNG& _, + task<> send(F delta, FVec& b, PRNG& _, span otMsg, Socket& chl) { - MC_BEGIN(task<>, this, x, z, otMsg, &chl, + MC_BEGIN(task<>, this, delta, &b, otMsg, &chl, prng = std::move(PRNG{}), buffer = std::vector{}, msg = typename CoeffCtx::Vec{}, temp = typename CoeffCtx::Vec{}, xb = BitVector{}); - xb = CoeffCtx::binaryDecomposition(x); + xb = CoeffCtx::binaryDecomposition(delta); if (otMsg.size() != xb.size()) throw RTE_LOC; setTimePoint("NoisyVoleSender.main"); - // z = 0; - CoeffCtx::zero(z.begin(), z.end()); + // b = 0; + CoeffCtx::zero(b.begin(), b.end()); // receive the the excrypted one shares. - buffer.resize(otMsg.size() * z.size() * CoeffCtx::byteSize()); + buffer.resize(xb.size() * b.size() * CoeffCtx::byteSize()); MC_AWAIT(chl.recv(buffer)); - CoeffCtx::deserialize(msg, buffer); + CoeffCtx::resize(msg, xb.size() * b.size()); + CoeffCtx::deserialize(buffer.begin(), buffer.end(), msg.begin()); setTimePoint("NoisyVoleSender.recvMsg"); @@ -94,26 +103,31 @@ namespace osuCrypto { for (size_t i = 0, k = 0; i < xb.size(); ++i) { // expand the zero shares or one share masks - prng.SetSeed(otMsg[i], z.size()); + prng.SetSeed(otMsg[i], b.size()); // otMsg[i,j, bc[i]] - //auto otMsgi = prng.getBufferSpan(z.size()); + //auto otMsgi = prng.getBufferSpan(b.size()); - for (u64 j = 0; j < (u64)z.size(); ++j, ++k) + for (u64 j = 0; j < (u64)b.size(); ++j, ++k) { // temp = otMsg[i,j, xb[i]] CoeffCtx::fromBlock(temp[0], prng.get()); + //CoeffCtx::zero(temp.begin(), temp.begin() + 1); + //std::cout << "m" << i << ","<()); MC_AWAIT(chl.recv(buffer)); - CoeffCtx::deserialize(delta, buffer); + CoeffCtx::deserialize(buffer.begin(), buffer.end(), delta.begin()); // recv B buffer.resize(CoeffCtx::byteSize() * mA.size()); MC_AWAIT(chl.recv(buffer)); - CoeffCtx::deserialize(B, buffer); + CoeffCtx::deserialize(buffer.begin(), buffer.end(), B.begin()); // recv the noisy values. buffer.resize(CoeffCtx::byteSize() * mNoiseDeltaShare.size()); MC_AWAIT(chl.recvResize(buffer)); - CoeffCtx::deserialize(noiseDeltaShare2, buffer); + CoeffCtx::deserialize(buffer.begin(), buffer.end(), noiseDeltaShare2.begin()); //check that at locations mS[0],...,mS[..] // that we hold a sharing mA, mB of @@ -540,7 +540,7 @@ namespace osuCrypto::Subfield // // That is, I hold mA, mC s.t. // - // delta * mC = mA + mB + // mA = mB + mC * mDelta // CoeffCtx::resize(tempF, 2); diff --git a/libOTe/Vole/Subfield/SilentVoleSender.h b/libOTe/Vole/Subfield/SilentVoleSender.h index f470e87e..bbe38f97 100644 --- a/libOTe/Vole/Subfield/SilentVoleSender.h +++ b/libOTe/Vole/Subfield/SilentVoleSender.h @@ -285,8 +285,8 @@ namespace osuCrypto //} // allocate B - mB.resize(0); - mB.resize(mN2); + CoeffCtx::resize(mB, 0); + CoeffCtx::resize(mB, mN2); if (mTimer) mGen.setTimer(*mTimer); @@ -326,8 +326,8 @@ namespace osuCrypto break; } + CoeffCtx::resize(mB, mRequestedNumOTs); - mB.resize(mRequestedNumOTs); mState = State::Default; mNoiseDeltaShares.clear(); diff --git a/libOTe_Tests/ExConvCode_Tests.cpp b/libOTe_Tests/ExConvCode_Tests.cpp index c6fe4b0e..fa1785a7 100644 --- a/libOTe_Tests/ExConvCode_Tests.cpp +++ b/libOTe_Tests/ExConvCode_Tests.cpp @@ -170,6 +170,11 @@ namespace osuCrypto std::copy(x1.begin(), x1.begin() + k, y1.begin()); y2 = y1; code.mExpander.expand(x1.cbegin() + accOffset, y1.begin()); + //using P = std::pair::const_iterator, typename std::vector::iterator>; + //auto p = P{ x1.cbegin() + accOffset, y1.begin() }; + //code.mExpander.expandMany( + // std::tuple

{ p } + //); } else { @@ -210,9 +215,24 @@ namespace osuCrypto } + + void ExConvCode_encode_basic_test(const oc::CLP& cmd) { + //std::vector i0, o0; + //std::vector i1, o1; + //std::vector i2, o2; + + //ExpanderCode2 ex; + //ex.expandMany( + // std::tuple{ + // std::pair{i0.begin(), o0.begin()}, + // std::pair{i1.begin(), o1.begin()}, + // std::pair{i2.begin(), o2.begin()} + // }, {}); + + auto K = cmd.getManyOr("k", { 16ul, 64, 4353 }); auto R = cmd.getManyOr("R", { 2.0, 3.0 }); auto Bw = cmd.getManyOr("bw", { 7, 21 }); diff --git a/libOTe_Tests/Pprf_Tests.cpp b/libOTe_Tests/Pprf_Tests.cpp new file mode 100644 index 00000000..0d4fd4e6 --- /dev/null +++ b/libOTe_Tests/Pprf_Tests.cpp @@ -0,0 +1,448 @@ +#include "Pprf_Tests.h" + +#include "libOTe/Tools/Subfield/SubfieldPprf.h" +#include "cryptoTools/Common/Log.h" +#include "Common.h" +#include +using namespace osuCrypto; +using namespace tests_libOTe; + + +template +void Tools_Pprf_expandOne_test_impl(u64 domain, bool program) +{ + + u64 depth = log2ceil(domain); + auto pntCount = 8ull; + PRNG prng(CCBlock); + + auto format = PprfOutputFormat::Interleaved; + SilentSubfieldPprfSender sender; + SilentSubfieldPprfReceiver recver; + + sender.configure(domain, pntCount); + recver.configure(domain, pntCount); + + F value = prng.get(); + sender.setValue({ &value, 1 }); + + auto numOTs = sender.baseOtCount(); + std::vector> sendOTs(numOTs); + std::vector recvOTs(numOTs); + BitVector recvBits = recver.sampleChoiceBits(domain * pntCount, format, prng); + + + prng.get(sendOTs.data(), sendOTs.size()); + for (u64 i = 0; i < numOTs; ++i) + { + recvOTs[i] = sendOTs[i][recvBits[i]]; + } + sender.setBase(sendOTs); + recver.setBase(recvOTs); + + std::vector points(8); + recver.getPoints(points, PprfOutputFormat::ByLeafIndex); + + block seed = CCBlock; + + auto sTree = span>{}; + auto sLevels = std::vector>>{}; + auto rTree = span>{}; + auto rLevels = std::vector>>{}; + auto sBuff = std::vector{}; + auto sSums = span, 2>>{}; + auto sLast = span{}; + + TreeAllocator mTreeAlloc; + sLevels.resize(depth); + rLevels.resize(depth); + + + mTreeAlloc.reserve(2, (1ull << depth) + 2); + allocateExpandTree(depth, mTreeAlloc, sTree, sLevels); + allocateExpandTree(depth, mTreeAlloc, rTree, rLevels); + + Ctx::Vec sLeafLevel(8ull << depth); + Ctx::Vec rLeafLevel(8ull << depth); + u64 leafOffset = 0; + + allocateExpandBuffer(depth - 1, program, sBuff, sSums, sLast); + + recver.mPoints.resize(roundUpTo(recver.mPntCount, 8)); + recver.getPoints(recver.mPoints, PprfOutputFormat::ByLeafIndex); + + sender.expandOne(seed, 0, program, sLevels, sLeafLevel, leafOffset, sSums, sLast); + recver.expandOne(0, program, rLevels, rLeafLevel, leafOffset, sSums, sLast); + + bool failed = false; + for (u64 i = 0; i < pntCount; ++i) + { + // the index of the leaf node that is active. + auto leafIdx = points[i]; + //std::cout << "active leaf idx = " << leafIdx << std::endl; + for (u64 d = 1; d < depth; ++d) + { + //u64 width = std::min(domain, 1ull << d); + auto width = divCeil(domain, 1ull << (depth - d)); + + // The index of the active child node. + auto activeChildIdx = leafIdx >> (depth - d); + + // The index of the active child node sibling. + + for (u64 j = 0; j < width; ++j) + { + //std::cout + // << " " << sLevels[d][j][i].get()[0] + // << " " << rLevels[d][j][i].get()[0] + // ; + + if (j == activeChildIdx) + { + //std::cout << "*"; + continue; + } + + + if (sLevels[d][j][i] != rLevels[d][j][i]) + { + //std::cout << " < "; + throw RTE_LOC; + failed = true; + } + + //std::cout << ", "; + } + //std::cout << std::endl; + } + + MatrixView sLeaves(sLeafLevel.data(), sLeafLevel.size() / 8, 8); + MatrixView rLeaves(rLeafLevel.data(), rLeafLevel.size() / 8, 8); + + for (u64 j = 0; j < sLeaves.rows(); ++j) + { + if (j == leafIdx) + { + F exp; + Ctx::plus(exp, sLeaves(j, i), value); + if (program && exp != rLeaves(j, i)) + { + std::cout << i << " exp " << Ctx::str(exp) << " " << Ctx::str(rLeaves(j, i)) << std::endl; + throw RTE_LOC; + } + } + else + { + if (sLeaves(j, i) != rLeaves(j, i)) + throw RTE_LOC; + } + } + } + + if (failed) + throw RTE_LOC; +} + +void Tools_Pprf_expandOne_test(const oc::CLP& cmd) +{ +#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) + + + for (u64 domain : { 8, 128, 4522}) for (bool program : {true, false}) + { + + Tools_Pprf_expandOne_test_impl(domain, program); + Tools_Pprf_expandOne_test_impl(domain, program); + Tools_Pprf_expandOne_test_impl, u32, CoeffCtxArray>(domain, program); + + } + +#else + throw UnitTestSkipped("ENABLE_SILENTOT not defined."); +#endif +} + + +template +void Tools_Pprf_test_impl( + u64 domain, + u64 numPoints, + bool program, + PprfOutputFormat format, + bool verbose) +{ + + u64 depth = log2ceil(domain); + auto threads = 1; + PRNG prng(CCBlock); + using Vec = typename Ctx::Vec; + + auto sockets = cp::LocalAsyncSocket::makePair(); + + SilentSubfieldPprfSender sender; + SilentSubfieldPprfReceiver recver; + Vec delta; + auto seed = prng.get(); + Ctx::resize(delta, numPoints * program); + for (u64 i = 0; i < delta.size(); ++i) + Ctx::fromBlock(delta[i], seed); + + sender.configure(domain, numPoints); + recver.configure(domain, numPoints); + + auto numOTs = sender.baseOtCount(); + std::vector> sendOTs(numOTs); + std::vector recvOTs(numOTs); + BitVector recvBits = recver.sampleChoiceBits(domain * numPoints, format, prng); + + prng.get(sendOTs.data(), sendOTs.size()); + for (u64 i = 0; i < numOTs; ++i) + recvOTs[i] = sendOTs[i][recvBits[i]]; + + sender.setBase(sendOTs); + recver.setBase(recvOTs); + + Vec a(numPoints * domain), a2; + Vec b(numPoints * domain), b2; + if (format == PprfOutputFormat::Callback) + { + a2 = std::move(a); + b2 = std::move(b); + a = {}; + b = {}; + sender.mOutputFn = [&](u64 treeIdx, Vec& data){ + auto offset = treeIdx * domain; + std::copy(data.begin(), data.end(), b2.begin() + offset); + }; + recver.mOutputFn = [&](u64 treeIdx, Vec& data) { + auto offset = treeIdx * domain; + std::copy(data.begin(), data.end(), a2.begin() + offset); + }; + } + + std::vector points(numPoints); + recver.getPoints(points, format); + + // a = b + points * delta + auto p0 = sender.expand(sockets[0], delta, prng.get(), b, format, program, threads); + auto p1 = recver.expand(sockets[1], a, format, program, threads); + + + try + { + eval(p0, p1); + } + catch (std::exception& e) + { + sockets[0].close(); + sockets[1].close(); + macoro::sync_wait(macoro::when_all_ready( + sockets[0].flush(), + sockets[1].flush() + )); + throw; + } + + if (format == PprfOutputFormat::Callback) + { + a = std::move(a2); + b = std::move(b2); + } + + switch (format) + { + case osuCrypto::PprfOutputFormat::ByLeafIndex: + case osuCrypto::PprfOutputFormat::ByTreeIndex: + { + + bool failed = false; + for (u64 j = 0; j < numPoints; ++j) + { + for (u64 i = 0; i < domain; ++i) + { + u64 idx = format == osuCrypto::PprfOutputFormat::ByTreeIndex ? + j * domain + i : + i * numPoints + j; + + F exp; + + if (points[j] == i) + { + if (program) + Ctx::plus(exp, b[idx], delta[j]); + else + Ctx::zero(&exp, &exp + 1); + } + else + exp = b[idx]; + + if (program && exp != a[idx]) + { + failed = true; + + if (verbose) + std::cout << Color::Red; + } + if (verbose) + { + std::cout << "r[" << j << "][" << i << "] " << exp << " " << Ctx::str(a[idx]); + if (points[j] == i) + std::cout << " < "; + + std::cout << std::endl << Color::Default; + } + } + if (verbose) + std::cout << "\n"; + } + + if (failed) + throw RTE_LOC; + + break; + } + case osuCrypto::PprfOutputFormat::Interleaved: + case osuCrypto::PprfOutputFormat::Callback: + { + + bool failed = false; + std::vector index(points.size()); + std::iota(index.begin(), index.end(), 0); + std::sort(index.begin(), index.end(), + [&](std::size_t i, std::size_t j) { return points[i] < points[j]; }); + + auto iIter = index.begin(); + auto leafIdx = points[*iIter]; + F deltaVal; + Ctx::zero(&deltaVal, &deltaVal + 1); + if(program) + deltaVal = delta[*iIter]; + + ++iIter; + for (u64 j = 0; j < a.size(); ++j) + { + F exp, act; + + // a = b + points * delta + + // act = a - b + // = point * delta + Ctx::minus(act, a[j], b[j]); + Ctx::zero(&exp, &exp + 1); + bool active = false; + if (j == leafIdx) + { + active = true; + if (program) + Ctx::copy(exp, deltaVal); + else + Ctx::minus(exp, exp, b[j]); + } + + if (exp != act) + { + failed = true; + if (verbose) + std::cout << Color::Red; + } + + if (verbose) + { + std::cout << j << " exp " << Ctx::str(exp) << " " << Ctx::str(act) + << " a " << Ctx::str(a[j]) << " b " << Ctx::str(b[j]); + + if (active) + std::cout << " < " << deltaVal; + + std::cout << std::endl << Color::Default; + } + + if (j == leafIdx) + { + if (iIter != index.end()) + { + leafIdx = points[*iIter]; + if(program) + deltaVal = delta[*iIter]; + ++iIter; + } + } + } + + if (failed) + throw RTE_LOC; + break; + } + default: + break; + } + + +} + +void Tools_Pprf_inter_test(const CLP& cmd) +{ + auto f = PprfOutputFormat::Interleaved; + auto v = cmd.isSet("v"); + for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true, false }) + { + Tools_Pprf_test_impl(d, n, p, f, v); + Tools_Pprf_test_impl(d, n, p, f, v); + } +} + + + +void Tools_Pprf_ByLeafIndex_test(const CLP& cmd) +{ +#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) + + auto f = PprfOutputFormat::ByLeafIndex; + auto v = cmd.isSet("v"); + for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true, false }) + { + Tools_Pprf_test_impl(d, n, p, f, v); + Tools_Pprf_test_impl(d, n, p, f, v); + } +#else + throw UnitTestSkipped("ENABLE_SILENTOT not defined."); +#endif +} + + + +void Tools_Pprf_ByTreeIndex_test(const oc::CLP& cmd) +{ +#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) + + + auto f = PprfOutputFormat::ByTreeIndex; + auto v = cmd.isSet("v"); + for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true, false }) + { + Tools_Pprf_test_impl(d, n, p, f, v); + Tools_Pprf_test_impl(d, n, p, f, v); + } + +#else + throw UnitTestSkipped("ENABLE_SILENTOT not defined."); +#endif +} + + + +void Tools_Pprf_callback_test(const oc::CLP& cmd) +{ +#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) + + auto f = PprfOutputFormat::Callback; + auto v = cmd.isSet("v"); + for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true, false }) + { + Tools_Pprf_test_impl(d, n, p, f, v); + Tools_Pprf_test_impl(d, n, p, f, v); + } +#else + throw UnitTestSkipped("ENABLE_SILENTOT not defined."); +#endif +} diff --git a/libOTe_Tests/Pprf_Tests.h b/libOTe_Tests/Pprf_Tests.h new file mode 100644 index 00000000..785d6ce5 --- /dev/null +++ b/libOTe_Tests/Pprf_Tests.h @@ -0,0 +1,16 @@ +#pragma once +// © 2020 Peter Rindal. +// © 2022 Visa. +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +#include + +void Tools_Pprf_expandOne_test(const oc::CLP& cmd); +void Tools_Pprf_inter_test(const oc::CLP& cmd); +void Tools_Pprf_ByLeafIndex_test(const oc::CLP& cmd); +void Tools_Pprf_ByTreeIndex_test(const oc::CLP& cmd); +void Tools_Pprf_callback_test(const oc::CLP& cmd); diff --git a/libOTe_Tests/SilentOT_Tests.cpp b/libOTe_Tests/SilentOT_Tests.cpp index 4cca9d7d..3262dd78 100644 --- a/libOTe_Tests/SilentOT_Tests.cpp +++ b/libOTe_Tests/SilentOT_Tests.cpp @@ -840,490 +840,3 @@ void OtExt_Silent_mal_Test(const oc::CLP& cmd) throw UnitTestSkipped("ENABLE_SILENTOT not defined."); #endif } - -void Tools_Pprf_expandOne_test(const oc::CLP& cmd) -{ -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - - u64 depth = cmd.getOr("d", 4);; - u64 domain = (1ull << depth) * 0.75; - auto pntCount = 8ull; - PRNG prng(CCBlock); - - auto format = PprfOutputFormat::Interleaved; - SilentMultiPprfSender sender; - SilentMultiPprfReceiver recver; - - sender.configure(domain, pntCount); - recver.configure(domain, pntCount); - - block value = prng.get(); - sender.setValue({ &value, 1 }); - - auto numOTs = sender.baseOtCount(); - std::vector> sendOTs(numOTs); - std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain * pntCount, format, prng); - - - prng.get(sendOTs.data(), sendOTs.size()); - for (u64 i = 0; i < numOTs; ++i) - { - recvOTs[i] = sendOTs[i][recvBits[i]]; - } - sender.setBase(sendOTs); - recver.setBase(recvOTs); - - std::vector points(8); - recver.getPoints(points, PprfOutputFormat::ByLeafIndex); - - block seed = CCBlock; - bool program = true; - - auto sTree = span>{}; - auto sLevels = std::vector>>{}; - auto rTree = span>{}; - auto rLevels = std::vector>>{}; - //auto rBuff = std::vector{}; - auto sBuff = std::vector{}; - auto sSums = span, 2>>{}; - auto sLast = span>{}; - - TreeAllocator mTreeAlloc; - sLevels.resize(depth + 1); - rLevels.resize(depth + 1); - - - mTreeAlloc.reserve(2, (1ull << (depth + 1)) + 2); - allocateExpandTree(depth + 1, mTreeAlloc, sTree, sLevels); - allocateExpandTree(depth + 1, mTreeAlloc, rTree, rLevels); - - - allocateExpandBuffer(depth, program, sBuff, sSums, sLast); - //allocateExpandBuffer(depth, program, rbuff, sSums, sLast); - - recver.mPoints.resize(roundUpTo(recver.mPntCount, 8)); - recver.getPoints(recver.mPoints, PprfOutputFormat::ByLeafIndex); - - sender.expandOne(seed, 0, program, sLevels, sSums, sLast); - recver.expandOne(0, program, rLevels, sSums, sLast); - - bool failed = false; - for (u64 i = 0; i < pntCount; ++i) - { - // the index of the leaf node that is active. - auto leafIdx = points[i]; - //std::cout << leafIdx << std::endl; - for (u64 d = 1; d <= depth; ++d) - { - //u64 width = std::min(domain, 1ull << d); - auto width = divCeil(domain, 1ull << (depth - d)); - - // The index of the active child node. - auto activeChildIdx = leafIdx >> (depth - d); - - // The index of the active child node sibling. - - for (u64 j = 0; j < width; ++j) - { - //std::cout - // << " " << sLevels[d][j][i].get()[0] - // << " " << rLevels[d][j][i].get()[0] - // << ", "; - - if (j == activeChildIdx) - { - //std::cout << "*"; - continue; - } - - if (sLevels[d][j][i] != rLevels[d][j][i]) - { - //std::cout << " < "; - failed = true; - } - } - //std::cout << std::endl; - } - } - - if (failed) - throw RTE_LOC; -#else - throw UnitTestSkipped("ENABLE_SILENTOT not defined."); -#endif - } - -void Tools_Pprf_test(const CLP& cmd) -{ -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - - u64 depth = cmd.getOr("d", 3);; - u64 domain = 1ull << depth; - auto threads = cmd.getOr("t", 3ull); - u64 numPoints = cmd.getOr("s", 8); - - PRNG prng(ZeroBlock); - - //IOService ios; - //Session s0(ios, "localhost:1212", SessionMode::Server); - //Session s1(ios, "localhost:1212", SessionMode::Client); - //auto sockets[0] = s0.addChannel(); - //auto sockets[1] = s1.addChannel(); - - auto sockets = cp::LocalAsyncSocket::makePair(); - - - auto format = PprfOutputFormat::ByLeafIndex; - SilentMultiPprfSender sender; - SilentMultiPprfReceiver recver; - - sender.configure(domain, numPoints); - recver.configure(domain, numPoints); - - auto numOTs = sender.baseOtCount(); - std::vector> sendOTs(numOTs); - std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain, format, prng); - - prng.get(sendOTs.data(), sendOTs.size()); - //sendOTs[cmd.getOr("i",0)] = prng.get(); - - //recvBits[16] = 1; - for (u64 i = 0; i < numOTs; ++i) - { - //recvBits[i] = 0; - recvOTs[i] = sendOTs[i][recvBits[i]]; - } - sender.setBase(sendOTs); - recver.setBase(recvOTs); - - Matrix sOut(domain, numPoints); - Matrix rOut(domain, numPoints); - std::vector points(numPoints); - recver.getPoints(points, format); - - auto p0 = sender.expand(sockets[0], { &CCBlock,1 }, prng.get(), sOut, format, true, threads); - auto p1 = recver.expand(sockets[1], rOut, format, true, threads); - - eval(p0, p1); - - bool failed = false; - - - for (u64 j = 0; j < numPoints; ++j) - { - - for (u64 i = 0; i < domain; ++i) - { - - auto exp = sOut(i, j); - if (points[j] == i) - exp = exp ^ CCBlock; - - if (neq(exp, rOut(i, j))) - { - failed = true; - - if (cmd.isSet("v")) - std::cout << Color::Red; - } - if (cmd.isSet("v")) - std::cout << "r[" << j << "][" << i << "] " << exp << " " << rOut(i, j) << std::endl << Color::Default; - } - } - - if (failed) - throw RTE_LOC; - -#else - throw UnitTestSkipped("ENABLE_SILENTOT not defined."); -#endif -} - -void Tools_Pprf_inter_test(const CLP& cmd) -{ -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - - //u64 depth = 6; - //u64 domain = 13;// (1ull << depth) - 7; - //u64 numPoints = 40; - - u64 domain = cmd.getOr("d", 334); - auto threads = cmd.getOr("t", 3ull); - u64 numPoints = cmd.getOr("s", 5) * 8; - //bool mal = cmd.isSet("mal"); - - PRNG prng(ZeroBlock); - - auto sockets = cp::LocalAsyncSocket::makePair(); - - - auto format = PprfOutputFormat::Interleaved; - SilentMultiPprfSender sender; - SilentMultiPprfReceiver recver; - - sender.configure(domain, numPoints); - recver.configure(domain, numPoints); - - auto numOTs = sender.baseOtCount(); - std::vector> sendOTs(numOTs); - std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain * numPoints, format, prng); - //recvBits.randomize(prng); - - //recvBits[16] = 1; - prng.get(sendOTs.data(), sendOTs.size()); - for (u64 i = 0; i < numOTs; ++i) - { - //recvBits[i] = 0; - recvOTs[i] = sendOTs[i][recvBits[i]]; - } - sender.setBase(sendOTs); - recver.setBase(recvOTs); - - //auto cols = (numPoints * domain + 127) / 128; - Matrix sOut2(numPoints * domain, 1); - Matrix rOut2(numPoints * domain, 1); - std::vector points(numPoints); - recver.getPoints(points, format); - - - auto p0 = sender.expand(sockets[0], { &AllOneBlock,1 }, prng.get(), sOut2, format, true, threads); - auto p1 = recver.expand(sockets[1], rOut2, format, true, threads); - - try - { - - eval(p0, p1); - } - catch (std::exception& e) - { - sockets[0].close(); - sockets[1].close(); - macoro::sync_wait(macoro::when_all_ready( - sockets[0].flush(), - sockets[1].flush() - )); - throw; - } - for (u64 i = 0; i < rOut2.rows(); ++i) - { - sOut2(i) = (sOut2(i) ^ rOut2(i)); - } - - - bool failed = false; - for (u64 i = 0; i < sOut2.rows(); ++i) - { - - auto f = std::find(points.begin(), points.end(), i) != points.end(); - - auto exp = f ? AllOneBlock : ZeroBlock; - - if (neq(sOut2(i), exp)) - { - failed = true; - - if (cmd.getOr("v", 0) > 1) - std::cout << Color::Red; - } - if (cmd.getOr("v", 0) > 1) - std::cout << i << " " << sOut2(i) << " " << exp << std::endl << Color::Default; - } - - if (failed) - throw RTE_LOC; - - -#else - throw UnitTestSkipped("ENABLE_SILENTOT not defined."); -#endif -} - - - -void Tools_Pprf_blockTrans_test(const oc::CLP& cmd) -{ -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - - - u64 depth = cmd.getOr("d", 2);; - u64 domain = 1ull << depth; - auto threads = cmd.getOr("t", 1ull); - u64 numPoints = cmd.getOr("s", 8); - - PRNG prng(ZeroBlock); - - auto sockets = cp::LocalAsyncSocket::makePair(); - - - auto format = PprfOutputFormat::ByTreeIndex; - SilentMultiPprfSender sender; - SilentMultiPprfReceiver recver; - - sender.configure(domain, numPoints); - recver.configure(domain, numPoints); - - auto numOTs = sender.baseOtCount(); - std::vector> sendOTs(numOTs); - std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain, format, prng); - - prng.get(sendOTs.data(), sendOTs.size()); - //sendOTs[cmd.getOr("i",0)] = prng.get(); - - //recvBits[16] = 1; - for (u64 i = 0; i < numOTs; ++i) - { - //recvBits[i] = 0; - recvOTs[i] = sendOTs[i][recvBits[i]]; - } - sender.setBase(sendOTs); - recver.setBase(recvOTs); - - Matrix sOut(numPoints, domain); - Matrix rOut(numPoints, domain); - std::vector points(numPoints); - recver.getPoints(points, format); - - cp::sync_wait(cp::when_all_ready( - sender.expand(sockets[0], span{}, prng.get(), sOut, format, false, threads), - recver.expand(sockets[1], rOut, format, false, threads) - )); - - bool failed = false; - - for (u64 j = 0; j < numPoints; ++j) - { - - for (u64 i = 0; i < domain; ++i) - { - auto ss = sOut(j, i); - auto rr = rOut(j, i); - - if (points[j] == i) - { - if (ss == rr || rr != ZeroBlock) - { - failed = true; - - if (cmd.isSet("v")) - std::cout << Color::Red; - } - - } - else - { - if (ss != rr || rr == ZeroBlock) - { - failed = true; - - if (cmd.isSet("v")) - std::cout << Color::Red; - } - } - if (cmd.isSet("v")) - std::cout << "r[" << j << "][" << i << "] " << ss << " " << rr << std::endl << Color::Default; - } - } - - if (failed) - throw RTE_LOC; - -#else - throw UnitTestSkipped("ENABLE_SILENTOT not defined."); -#endif -} - - - -void Tools_Pprf_callback_test(const oc::CLP& cmd) -{ -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - - u64 domain = cmd.getOr("d", 512); - auto threads = cmd.getOr("t", 1ull); - u64 numPoints = cmd.getOr("s", 5) * 8; - - PRNG prng(ZeroBlock); - auto sockets = cp::LocalAsyncSocket::makePair(); - - - auto format = PprfOutputFormat::Callback; - SilentMultiPprfSender sender; - SilentMultiPprfReceiver recver; - - sender.configure(domain, numPoints); - recver.configure(domain, numPoints); - - auto numOTs = sender.baseOtCount(); - std::vector> sendOTs(numOTs); - std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain * numPoints, format, prng); - - prng.get(sendOTs.data(), sendOTs.size()); - for (u64 i = 0; i < numOTs; ++i) - { - recvOTs[i] = sendOTs[i][recvBits[i]]; - } - sender.setBase(sendOTs); - recver.setBase(recvOTs); - - //auto cols = (numPoints * domain + 127) / 128; - Matrix sOut2(numPoints * domain, 1); - Matrix rOut2(numPoints * domain, 1); - std::vector points(numPoints); - recver.getPoints(points, format); - - sender.mOutputFn = [&](u64 treeIdx, span> data) - { - span d = sOut2; - d = d.subspan(treeIdx * data.size()); - d = d.subspan(0, std::min(d.size(), data.size() * 8)); - memcpy(d.data(), data.data(), d.size_bytes()); - }; - recver.mOutputFn = [&](u64 treeIdx, span> data) - { - span d = rOut2; - d = d.subspan(treeIdx * data.size()); - d = d.subspan(0, std::min(d.size(), data.size() * 8)); - memcpy(d.data(), data.data(), d.size_bytes()); - }; - - - auto p0 = sender.expand(sockets[0], { &AllOneBlock,1 }, prng.get(), span{}, format, true, threads); - auto p1 = recver.expand(sockets[1], span{}, format, true, threads); - - eval(p0, p1); - for (u64 i = 0; i < rOut2.rows(); ++i) - { - sOut2(i) = (sOut2(i) ^ rOut2(i)); - } - - bool failed = false; - for (u64 i = 0; i < sOut2.rows(); ++i) - { - - auto f = std::find(points.begin(), points.end(), i) != points.end(); - - auto exp = f ? AllOneBlock : ZeroBlock; - - if (neq(sOut2(i), exp)) - { - failed = true; - - if (cmd.getOr("v", 0) > 1) - std::cout << Color::Red; - } - if (cmd.getOr("v", 0) > 1) - std::cout << i << " " << sOut2(i) << " " << exp << std::endl << Color::Default; - } - - if (failed) - throw RTE_LOC; - -#else - throw UnitTestSkipped("ENABLE_SILENTOT not defined."); -#endif - } diff --git a/libOTe_Tests/SilentOT_Tests.h b/libOTe_Tests/SilentOT_Tests.h index d93e09be..26ff014d 100644 --- a/libOTe_Tests/SilentOT_Tests.h +++ b/libOTe_Tests/SilentOT_Tests.h @@ -9,12 +9,6 @@ #include -void Tools_Pprf_expandOne_test(const oc::CLP& cmd); -void Tools_Pprf_test(const oc::CLP& cmd); -void Tools_Pprf_trans_test(const oc::CLP& cmd); -void Tools_Pprf_inter_test(const oc::CLP& cmd); -void Tools_Pprf_blockTrans_test(const oc::CLP& cmd); -void Tools_Pprf_callback_test(const oc::CLP& cmd); void OtExt_Silent_random_Test(const oc::CLP& cmd); void OtExt_Silent_correlated_Test(const oc::CLP& cmd); diff --git a/libOTe_Tests/Subfield_Test.h b/libOTe_Tests/Subfield_Test.h index 92d8babb..8df6c24f 100644 --- a/libOTe_Tests/Subfield_Test.h +++ b/libOTe_Tests/Subfield_Test.h @@ -5,9 +5,8 @@ namespace osuCrypto::Subfield { -void Subfield_ExConvCode_encode_test(const oc::CLP& cmd); -void Subfield_Tools_Pprf_test(const oc::CLP& cmd); -void Subfield_Noisy_Vole_test(const oc::CLP& cmd); -void Subfield_Silent_Vole_test(const oc::CLP& cmd); + void Subfield_Tools_Pprf_test(const oc::CLP& cmd); + void Subfield_Noisy_Vole_test(const oc::CLP& cmd); + void Subfield_Silent_Vole_test(const oc::CLP& cmd); } \ No newline at end of file diff --git a/libOTe_Tests/Subfield_Tests.cpp b/libOTe_Tests/Subfield_Tests.cpp index 7a0414b1..95377320 100644 --- a/libOTe_Tests/Subfield_Tests.cpp +++ b/libOTe_Tests/Subfield_Tests.cpp @@ -14,99 +14,6 @@ namespace osuCrypto::Subfield static_assert(std::is_trivially_copyable_v); using tests_libOTe::eval; - void Subfield_ExConvCode_encode_test(const oc::CLP& cmd) - { - { - using CoeffCtx = DefaultCoeffCtx; - u64 n = 1024; - ExConvCode2 code; - code.config(n / 2, n, 7, 24, true); - - PRNG prng(ZeroBlock); - block delta = prng.get(); - std::vector y(n), z0(n), z1(n); - prng.get(y.data(), y.size()); - prng.get(z0.data(), z0.size()); - for (u64 i = 0; i < n; ++i) - { - z1[i] = z0[i] ^ delta.gf128Mul(y[i]); - } - - code.dualEncode(z1.begin()); - code.dualEncode(z0.begin()); - code.dualEncode(y.begin()); - //code.dualEncode2(z0, y); - - for (u64 i = 0; i < n; ++i) - { - block left = delta.gf128Mul(y[i]); - block right = z1[i] ^ z0[i]; - if (left != right) - throw RTE_LOC; - } - } - - { - using CoeffCtx = DefaultCoeffCtx; - u64 n = 1024; - ExConvCode2 code; - code.config(n / 2, n, 7, 24, true); - - PRNG prng(ZeroBlock); - u8 delta = 111; - std::vector y(n), z0(n), z1(n); - prng.get(y.data(), y.size()); - prng.get(z0.data(), z0.size()); - for (u64 i = 0; i < n; ++i) - { - z1[i] = z0[i] + delta * y[i]; - } - - code.dualEncode(z1.begin()); - code.dualEncode(z0.begin()); - code.dualEncode(y.begin()); - - //code.dualEncode2(z0, y); - - for (u64 i = 0; i < n; ++i) - { - u8 left = delta * y[i]; - u8 right = z1[i] - z0[i]; - if (left != right) - throw RTE_LOC; - } - } - - { - using CoeffCtx = DefaultCoeffCtx; - u64 n = 1024; - ExConvCode2 code; - code.config(n / 2, n, 7, 24, true); - - PRNG prng(ZeroBlock); - u64 delta = 111; - std::vector y(n), z0(n), z1(n); - prng.get(y.data(), y.size()); - prng.get(z0.data(), z0.size()); - for (u64 i = 0; i < n; ++i) - { - z1[i] = z0[i] + delta * y[i]; - } - - code.dualEncode(z1.begin()); - code.dualEncode(z0.begin()); - code.dualEncode(y.begin()); - //code.dualEncode2(z0, y); - - for (u64 i = 0; i < n; ++i) - { - u64 left = delta * y[i]; - u64 right = z1[i] - z0[i]; - if (left != right) - throw RTE_LOC; - } - } - } void Subfield_Tools_Pprf_test(const oc::CLP& cmd) { #if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) @@ -168,165 +75,60 @@ namespace osuCrypto::Subfield #endif } - void Subfield_Noisy_Vole_test(const oc::CLP& cmd) { - - { - Timer timer; - timer.setTimePoint("start"); - u64 n = cmd.getOr("n", 400); - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); - - u64 x = prng.get(); - std::vector y(n); - std::vector z0(n), z1(n); - prng.get(y.data(), y.size()); - - - using Trait = DefaultCoeffCtx; - NoisySubfieldVoleReceiver recv; - NoisySubfieldVoleSender send; - - recv.setTimer(timer); - send.setTimer(timer); + template + void subfield_vole_test(u64 n) + { + PRNG prng(CCBlock); - auto chls = cp::LocalAsyncSocket::makePair(); - timer.setTimePoint("net"); + F delta = prng.get(); + std::vector c(n); + std::vector a(n), b(n); + prng.get(c.data(), c.size()); - BitVector recvChoice((u8*)&x, 64); - std::vector otRecvMsg(64); - std::vector> otSendMsg(64); - prng.get>(otSendMsg); - for (u64 i = 0; i < 64; ++i) - otRecvMsg[i] = otSendMsg[i][recvChoice[i]]; - timer.setTimePoint("ot"); + NoisySubfieldVoleReceiver recv; + NoisySubfieldVoleSender send; - auto p0 = recv.receive(y, z0, prng, otSendMsg, chls[0]); - auto p1 = send.send(x, z1, prng, otRecvMsg, chls[1]); + auto chls = cp::LocalAsyncSocket::makePair(); - eval(p0, p1); + BitVector recvChoice = Trait::binaryDecomposition(delta); + std::vector otRecvMsg(recvChoice.size()); + std::vector> otSendMsg(recvChoice.size()); + prng.get>(otSendMsg); + for (u64 i = 0; i < recvChoice.size(); ++i) + otRecvMsg[i] = otSendMsg[i][recvChoice[i]]; - for (u64 i = 0; i < n; ++i) - { - if (x * y[i] != (z1[i] - z0[i])) - { - throw RTE_LOC; - } - } - timer.setTimePoint("done"); + // compute a,b such that + // + // a = b + c * delta + // + auto p0 = recv.receive(c, a, prng, otSendMsg, chls[0]); + auto p1 = send.send(delta, b, prng, otRecvMsg, chls[1]); - //std::cout << timer << std::endl; - } + eval(p0, p1); + for (u64 i = 0; i < n; ++i) { - Timer timer; - timer.setTimePoint("start"); - u64 n = cmd.getOr("n", 400); - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); + F prod, sum; - constexpr size_t N = 3; - using G = u32; - using F = std::array; - using CoeffCtx = CoeffCtxArray; - u64 bitsF = sizeof(F) * 8;; + Trait::mul(prod, delta, c[i]); + Trait::minus(sum, a[i], b[i]); - static_assert( - std::is_standard_layout::value && - std::is_trivial::value - ); - F x; - CoeffCtx::fromBlock(x, prng.get()); - std::vector y(n); - std::vector z0(n), z1(n); - prng.get(y.data(), y.size()); - - NoisySubfieldVoleReceiver recv; - NoisySubfieldVoleSender send; - - recv.setTimer(timer); - send.setTimer(timer); - - auto chls = cp::LocalAsyncSocket::makePair(); - timer.setTimePoint("net"); - - BitVector recvChoice((u8*)&x, bitsF); - std::vector otRecvMsg(bitsF); - std::vector> otSendMsg(bitsF); - prng.get>(otSendMsg); - for (u64 i = 0; i < bitsF; ++i) - otRecvMsg[i] = otSendMsg[i][recvChoice[i]]; - timer.setTimePoint("ot"); - - auto p0 = recv.receive((span)y, (span)z0, prng, otSendMsg, chls[0]); - auto p1 = send.send(x, z1, prng, otRecvMsg, chls[1]); - - eval(p0, p1); - // std::cout << "transferred " << (chls[0].bytesSent() + chls[0].bytesReceived()) << std::endl; - timer.setTimePoint("verify"); - - for (u64 i = 0; i < n; ++i) + if (prod != sum) { - for (u64 j = 0; j < N; j++) { - G left = x[j] * y[i]; - G right = z1[i][j] - z0[i][j]; - if (left != right) - { - throw RTE_LOC; - } - } + throw RTE_LOC; } - timer.setTimePoint("done"); - - // std::cout << timer << std::endl; } + } - { - Timer timer; - timer.setTimePoint("start"); - u64 n = cmd.getOr("n", 400); - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); - - block x = prng.get(); - std::vector y(n); - std::vector z0(n), z1(n); - prng.get(y.data(), y.size()); - using F = block; - using G = block; - NoisySubfieldVoleReceiver recv; - NoisySubfieldVoleSender send; - - recv.setTimer(timer); - send.setTimer(timer); - - auto chls = cp::LocalAsyncSocket::makePair(); - timer.setTimePoint("net"); - - size_t k = 128; - BitVector recvChoice((u8*)&x, k); - std::vector otRecvMsg(k); - std::vector> otSendMsg(k); - prng.get>(otSendMsg); - for (u64 i = 0; i < k; ++i) - otRecvMsg[i] = otSendMsg[i][recvChoice[i]]; - timer.setTimePoint("ot"); - - auto p0 = recv.receive(y, z0, prng, otSendMsg, chls[0]); - auto p1 = send.send(x, z1, prng, otRecvMsg, chls[1]); - - eval(p0, p1); - - for (u64 i = 0; i < n; ++i) - { - if (x.gf128Mul(y[i]) != (z1[i] ^ z0[i])) - { - throw RTE_LOC; - } - } - timer.setTimePoint("done"); + void Subfield_Noisy_Vole_test(const oc::CLP& cmd) + { - //std::cout << timer << std::endl; + for (u64 n : {1, 8, 433}) + { + subfield_vole_test(n); + subfield_vole_test(n); + subfield_vole_test(n); + subfield_vole_test, u32, CoeffCtxArray>(n); } } @@ -361,7 +163,7 @@ namespace osuCrypto::Subfield timer.setTimePoint("net"); timer.setTimePoint("ot"); - // fakeBase(n, nt, prng, x, recv, send); + // fakeBase(n, nt, prng, delta, recv, send); auto p0 = send.silentSend(x, span(z0), prng, chls[0]); auto p1 = recv.silentReceive(span(c), span(z1), prng, chls[1]); @@ -407,7 +209,7 @@ namespace osuCrypto::Subfield timer.setTimePoint("net"); timer.setTimePoint("ot"); - // fakeBase(n, nt, prng, x, recv, send); + // fakeBase(n, nt, prng, delta, recv, send); auto p0 = send.silentSend(x, span(b), prng, chls[0]); auto p1 = recv.silentReceive(span(c), span(a), prng, chls[1]); @@ -420,12 +222,12 @@ namespace osuCrypto::Subfield for (u64 i = 0; i < n; i++) { for (u64 j = 0; j < N; j++) { throw RTE_LOC;// fix this - // c = a x + b - // c - b = a x - //G left = a[i] * x[j]; + // c = a delta + b + // c - b = a delta + //G left = a[i] * delta[j]; //G right = c[i][j] - b[i][j]; //if (left != right) { - // std::cout << "bad " << i << "\n a[i] " << a[i] << " * x[j] " << x[j] << " = " << left << std::endl; + // std::cout << "bad " << i << "\n a[i] " << a[i] << " * delta[j] " << delta[j] << " = " << left << std::endl; // std::cout << "c[i][j] " << c[i][j] << " - b " << b[i][j] << " = " << right << std::endl; // throw RTE_LOC; //} @@ -455,7 +257,7 @@ namespace osuCrypto::Subfield timer.setTimePoint("net"); timer.setTimePoint("ot"); - // fakeBase(n, nt, prng, x, recv, send); + // fakeBase(n, nt, prng, delta, recv, send); auto p0 = send.silentSend(x, span(z0), prng, chls[0]); auto p1 = recv.silentReceive(span(c), span(z1), prng, chls[1]); diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 6dbaeeb5..5011d77c 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -17,7 +17,7 @@ #include "libOTe_Tests/EACode_Tests.h" #include "libOTe/Tools/LDPC/Mtx.h" #include "libOTe_Tests/Subfield_Test.h" - +#include "libOTe_Tests/Pprf_Tests.h" using namespace osuCrypto; namespace tests_libOTe { @@ -57,12 +57,11 @@ namespace tests_libOTe tc.add("EACode_encode_basic_test ", EACode_encode_basic_test); tc.add("ExConvCode_encode_basic_test ", ExConvCode_encode_basic_test); - tc.add("Subfield_ExConvCode_encode_test ", Subfield::Subfield_ExConvCode_encode_test); tc.add("Tools_Pprf_expandOne_test ", Tools_Pprf_expandOne_test); - tc.add("Tools_Pprf_test ", Tools_Pprf_test); tc.add("Tools_Pprf_inter_test ", Tools_Pprf_inter_test); - tc.add("Tools_Pprf_blockTrans_test ", Tools_Pprf_blockTrans_test); + tc.add("Tools_Pprf_ByLeafIndex_test ", Tools_Pprf_ByLeafIndex_test); + tc.add("Tools_Pprf_ByTreeIndex_test ", Tools_Pprf_ByTreeIndex_test); tc.add("Tools_Pprf_callback_test ", Tools_Pprf_callback_test); From 843207aa5d1e96eeec3676947c758b03048c49cd Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Sat, 20 Jan 2024 11:46:59 -0800 Subject: [PATCH 11/23] mal vole working --- frontend/benchmark.h | 1 - frontend/main.cpp | 8 - libOTe/Tools/LDPC/LdpcDecoder.cpp | 613 --------- libOTe/Tools/LDPC/LdpcDecoder.h | 121 -- libOTe/Tools/LDPC/LdpcEncoder.cpp | 975 -------------- libOTe/Tools/LDPC/LdpcEncoder.h | 1175 ----------------- libOTe/Tools/LDPC/LdpcImpulseDist.cpp | 1132 ---------------- libOTe/Tools/LDPC/LdpcImpulseDist.h | 40 - libOTe/Tools/LDPC/LdpcSampler.cpp | 256 ---- libOTe/Tools/LDPC/LdpcSampler.h | 751 ----------- libOTe/Tools/LDPC/Mtx.cpp | 1 - libOTe/Tools/LDPC/Util.h | 15 +- libOTe/Tools/QuasiCyclicCode.h | 635 ++++----- libOTe/Tools/Subfield/SubfieldPprf.h | 170 +-- libOTe/TwoChooseOne/ConfigureCode.cpp | 60 +- libOTe/TwoChooseOne/ConfigureCode.h | 52 +- .../Silent/SilentOtExtReceiver.cpp | 1 - .../TwoChooseOne/Silent/SilentOtExtReceiver.h | 1 - .../TwoChooseOne/Silent/SilentOtExtSender.h | 5 - libOTe/Vole/Silent/SilentVoleReceiver.cpp | 1 - libOTe/Vole/Silent/SilentVoleReceiver.h | 1 - libOTe/Vole/Silent/SilentVoleSender.cpp | 1 - libOTe/Vole/Silent/SilentVoleSender.h | 4 - libOTe/Vole/Subfield/SilentVoleReceiver.h | 455 ++++--- libOTe/Vole/Subfield/SilentVoleSender.h | 234 ++-- libOTe_Tests/Pprf_Tests.cpp | 19 +- libOTe_Tests/Subfield_Test.h | 2 +- libOTe_Tests/Subfield_Tests.cpp | 291 ++-- libOTe_Tests/UnitTests.cpp | 22 +- libOTe_Tests/Vole_Tests.cpp | 430 ++---- libOTe_Tests/Vole_Tests.h | 2 - 31 files changed, 1083 insertions(+), 6391 deletions(-) delete mode 100644 libOTe/Tools/LDPC/LdpcDecoder.cpp delete mode 100644 libOTe/Tools/LDPC/LdpcDecoder.h delete mode 100644 libOTe/Tools/LDPC/LdpcEncoder.cpp delete mode 100644 libOTe/Tools/LDPC/LdpcEncoder.h delete mode 100644 libOTe/Tools/LDPC/LdpcImpulseDist.cpp delete mode 100644 libOTe/Tools/LDPC/LdpcImpulseDist.h delete mode 100644 libOTe/Tools/LDPC/LdpcSampler.cpp delete mode 100644 libOTe/Tools/LDPC/LdpcSampler.h diff --git a/frontend/benchmark.h b/frontend/benchmark.h index 1f70a490..6d04550d 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -4,7 +4,6 @@ #include #include "libOTe/Tools/Tools.h" -#include "libOTe/Tools/LDPC/LdpcEncoder.h" #include "libOTe/Tools/EACode/EACode.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" #include "libOTe/Tools/QuasiCyclicCode.h" diff --git a/frontend/main.cpp b/frontend/main.cpp index bdbb8550..5fe2b603 100644 --- a/frontend/main.cpp +++ b/frontend/main.cpp @@ -28,7 +28,6 @@ using namespace osuCrypto; #include "ExampleSilent.h" #include "ExampleVole.h" #include "ExampleMessagePassing.h" -#include "libOTe/Tools/LDPC/LdpcImpulseDist.h" #include "libOTe/Tools/LDPC/Util.h" #include "cryptoTools/Crypto/RandomOracle.h" #include "libOTe/Tools/EACode/EAChecker.h" @@ -134,13 +133,6 @@ int main(int argc, char** argv) EAChecker(cmd); return 0; } -#ifdef ENABLE_LDPC - if (cmd.isSet("ldpc")) - { - LdpcDecode_impulse(cmd); - return 0; - } -#endif // unit tests. if (cmd.isSet(unitTestTag)) diff --git a/libOTe/Tools/LDPC/LdpcDecoder.cpp b/libOTe/Tools/LDPC/LdpcDecoder.cpp deleted file mode 100644 index c0fb7edd..00000000 --- a/libOTe/Tools/LDPC/LdpcDecoder.cpp +++ /dev/null @@ -1,613 +0,0 @@ -#include "LdpcDecoder.h" -#ifdef ENABLE_LDPC - -#include -#include "Mtx.h" -#include "LdpcEncoder.h" -#include "LdpcSampler.h" -#include "Util.h" -#include -#include -#include -#include - -#include "LdpcImpulseDist.h" -namespace osuCrypto { - - - auto nan = std::nan(""); - - void LdpcDecoder::init(SparseMtx& H) - { - mH = H; - auto n = mH.cols(); - auto m = mH.rows(); - - assert(n > m); - mK = n - m; - - mR.resize(m, n); - mM.resize(m, n); - - mW.resize(n); - } - - - std::vector LdpcDecoder::bpDecode(span codeword, u64 maxIter) - { - auto n = mH.cols(); - - - assert(codeword.size() == n); - - std::array wVal{ { mP / (1 - mP), (1 - mP) / mP} }; - - // #1 - for (u64 i = 0; i < n; ++i) - { - assert(codeword[i] < 2); - mW[i] = wVal[codeword[i]]; - } - - return bpDecode(mW); - } - - std::vector LdpcDecoder::bpDecode(span lr, u64 maxIter) - { - - - auto n = mH.cols(); - auto m = mH.rows(); - - // #1 - for (u64 i = 0; i < n; ++i) - { - //assert(codeword[i] < 2); - mW[i] = lr[i]; - - for (auto j : mH.mCols[i]) - { - mR(j, i) = mW[i]; - } - } - - std::vector c(n); - std::vector rr; rr.reserve(100); - for (u64 ii = 0; ii < maxIter; ii++) - { - // #2 - for (u64 j = 0; j < m; ++j) - { - rr.resize(mH.mRows[j].size()); - for (u64 i : mH.mRows[j]) - { - // \Pi_{k in Nj \ {i} } (r_k^j + 1)/(r_k^j - 1) - double v = 1; - auto jj = 0; - for (u64 k : mH.mRows[j]) - { - rr[jj++] = mR(j, k); - if (k != i) - { - auto r = mR(j, k); - v *= (r + 1) / (r - 1); - } - } - - // m_j^i - auto mm = (v + 1) / (v - 1); - mM(j, i) = mm; - } - } - - // i indexes a column, [1,...,n] - for (u64 i = 0; i < n; ++i) - { - // j indexes a row, [1,...,m] - for (u64 j : mH.mCols[i]) - { - // r_i^j = w_i * Pi_{k in Ni \ {j} } m_k^i - mR(j, i) = mW[i]; - - // j indexes a row, [1,...,m] - for (u64 k : mH.mCols[i]) - { - if (k != j) - { - mR(j, i) *= mM(k, i); - } - } - } - } - - mL.resize(n); - // i indexes a column, [1,...,n] - for (u64 i = 0; i < n; ++i) - { - //L(ci | wi, m^i) - mL[i] = mW[i]; - - // k indexes a row, [1,...,m] - for (u64 k : mH.mCols[i]) - { - assert(mM(k, i) != nan); - mL[i] *= mM(k, i); - } - - c[i] = (mL[i] >= 1) ? 0 : 1; - - - mL[i] = std::log(mL[i]); - } - - if (check(c)) - { - c.resize(n - m); - return c; - } - } - - return {}; - } - - double sgn(double x) - { - if (x >= 0) - return 1; - return -1; - } - - u8 sgnBool(double x) - { - if (x >= 0) - return 0; - return 1; - } - - double phi(double x) - { - assert(x > 0); - x = std::min(20.0, x); - auto t = std::tanh(x * 0.5); - return -std::log(t); - } - - std::ostream& operator<<(std::ostream& o, const Matrix& m) - { - for (u64 i = 0; i < m.rows(); ++i) - { - for (u64 j = 0; j < m.cols(); ++j) - { - o << std::setw(4) << std::setfill(' ') << m(i, j) << " "; - } - - o << std::endl; - } - - return o; - } - - - std::vector LdpcDecoder::logbpDecode2(span codeword, u64 maxIter) - { - - auto n = mH.cols(); - //auto m = mH.rows(); - - assert(codeword.size() == n); - - std::array wVal{ - {std::log(mP / (1 - mP)), - std::log((1 - mP) / mP) - } - }; - - std::vector w(n); - for (u64 i = 0; i < n; ++i) - w[i] = wVal[codeword[i]]; - - return logbpDecode2(w, maxIter); - - } - std::vector LdpcDecoder::logbpDecode2(span llr, u64 maxIter) - { - auto n = mH.cols(); - auto m = mH.rows(); - std::vector c(n); - mL.resize(c.size()); - - - for (u64 i = 0; i < n; ++i) - { - mW[i] = llr[i]; - - for (auto j : mH.mCols[i]) - { - mR(j, i) = mW[i]; - } - } - - for (u64 ii = 0; ii < maxIter; ii++) - { - // #2 - for (u64 j = 0; j < m; ++j) - { - double v = 0; - u8 s = 1; - for (u64 k : mH.mRows[j]) - { - auto rr = mR(j, k); - v += phi(abs(rr)); - s ^= sgnBool(rr); - } - - - for (u64 k : mH.mRows[j]) - { - auto vv = phi(abs(mR(j, k))); - auto ss = sgnBool(mR(j, k)); - vv = phi(v - vv); - - mM(j, k) = (s ^ ss) ? vv : -vv; - } - } - - // i indexes a column, [1,...,n] - for (u64 i = 0; i < n; ++i) - { - mL[i] = mW[i]; - for (u64 k : mH.mCols[i]) - { - mL[i] += mM(k, i); - } - for (u64 k : mH.mCols[i]) - { - // r_i^j = w_i * Pi_{k in Ni \ {j} } m_k^i - mR(k, i) = mL[i] - mM(k, i); - } - - c[i] = (mL[i] >= 0) ? 0 : 1; - } - - if (check(c)) - { - if (mAllowZero == false && isZero(c)) - continue; - - c.resize(n - m); - return c; - } - } - - return {}; - } - - std::vector LdpcDecoder::altDecode(span codeword, bool minSum, u64 maxIter) - { - auto _N = mH.cols(); - std::array wVal{ - {std::log(mP / (1 - mP)), - std::log((1 - mP) / mP) - } - }; - - std::vector w(_N); - for (u64 i = 0; i < _N; ++i) - { - w[i] = wVal[codeword[i]]; - } - return altDecode(w, minSum, maxIter); - } - - std::vector LdpcDecoder::altDecode(span w, bool min_sum, u64 maxIter) - { - - auto _N = mH.cols(); - auto _M = mH.rows(); - - for (u64 i = 0; i < _N; ++i) - { - mW[i] = w[i]; - } - - mL = mW; - std::vector decoded_cw(_N); - - - std::vector > forward_msg(_M); - std::vector > back_msg(_M); - for (u64 r = 0; r < _M; ++r) { - forward_msg[r].resize(mH.row(r).size()); - back_msg[r].resize(mH.row(r).size()); - } - auto maxLL = 20.0; - - for (u64 iter = 0; iter < maxIter; ++iter) { - - for (u64 r = 0; r < _M; ++r) { - - for (u64 c1 = 0; c1 < (u64)mH.row(r).size(); ++c1) { - double tmp = 1; - if (min_sum) - tmp = maxLL; - - for (u64 c2 = 0; c2 < (u64)mH.row(r).size(); ++c2) { - if (c1 == c2) - continue; - - auto i_col2 = mH.row(r)[c2]; - - double l1 = mL[i_col2] - back_msg[r][c2]; - l1 = std::min(l1, maxLL); - l1 = std::max(l1, -maxLL); - - if (min_sum) { - double sign_tmp = tmp < 0 ? -1 : 1; - double sign_l1 = l1 < 0.0 ? -1 : 1; - - tmp = sign_tmp * sign_l1 * std::min(std::abs(l1), std::abs(tmp)); - } - else - tmp = tmp * tanh(l1 / 2); - } - - - if (min_sum) { - forward_msg[r][c1] = tmp; - } - else { - forward_msg[r][c1] = 2 * atanh(tmp); - } - } - } - - back_msg = forward_msg; - - mL = mW; - - for (u64 r = 0; r < _M; ++r) { - - for (u64 i = 0; i < (u64)mH.row(r).size(); ++i) { - auto c = mH.row(r)[i]; - mL[c] += back_msg[r][i]; - } - } - - for (u64 c = 0; c < _N; ++c) { - decoded_cw[c] = mL[c] > 0 ? 0 : 1; - } - - if (check(decoded_cw)) { - decoded_cw.resize(_N - _M); - return decoded_cw; - } - - } // Iteration loop end - - return {}; - - //} - - } - - std::vector LdpcDecoder::minSumDecode(span codeword, u64 maxIter) - { - - auto n = mH.cols(); - auto m = mH.rows(); - - assert(codeword.size() == n); - - std::array wVal{ - {std::log(mP / (1 - mP)), - std::log((1 - mP) / mP)} }; - - auto nan = std::nan(""); - std::fill(mR.begin(), mR.end(), nan); - std::fill(mM.begin(), mM.end(), nan); - - // #1 - for (u64 i = 0; i < n; ++i) - { - assert(codeword[i] < 2); - mW[i] = wVal[codeword[i]]; - - for (auto j : mH.mCols[i]) - { - mR(j, i) = mW[i]; - } - } - - std::vector c(n); - std::vector rr; rr.reserve(100); - for (u64 ii = 0; ii < maxIter; ii++) - { - // #2 - for (u64 j = 0; j < m; ++j) - { - rr.resize(mH.mRows[j].size()); - for (u64 i : mH.mRows[j]) - { - // \Pi_{k in Nj \ {i} } (r_k^j + 1)/(r_k^j - 1) - double v = std::numeric_limits::max(); - double s = 1; - - for (u64 k : mH.mRows[j]) - { - if (k != i) - { - assert(mR(j, k) != nan); - - v = std::min(v, std::abs(mR(j, k))); - - s *= sgn(mR(j, k)); - } - } - - // m_j^i - mM(j, i) = s * v; - } - } - - // i indexes a column, [1,...,n] - for (u64 i = 0; i < n; ++i) - { - // j indexes a row, [1,...,m] - for (u64 j : mH.mCols[i]) - { - // r_i^j = w_i * Pi_{k in Ni \ {j} } m_k^i - mR(j, i) = mW[i]; - - // j indexes a row, [1,...,m] - for (u64 k : mH.mCols[i]) - { - if (k != j) - { - assert(mM(k, i) != nan); - mR(j, i) += mM(k, i); - } - } - } - } - mL.resize(n); - // i indexes a column, [1,...,n] - for (u64 i = 0; i < n; ++i) - { - //log L(ci | wi, m^i) - mL[i] = mW[i]; - - // k indexes a row, [1,...,m] - for (u64 k : mH.mCols[i]) - { - assert(mM(k, i) != nan); - mL[i] += mM(k, i); - } - - c[i] = (mL[i] >= 0) ? 0 : 1; - } - - if (check(c)) - { - c.resize(n - m); - return c; - } - } - - return {}; - } - - bool LdpcDecoder::check(const span& data) { - - // j indexes a row, [1,...,m] - for (u64 j = 0; j < mH.rows(); ++j) - { - u8 sum = 0; - - // i indexes a column, [1,...,n] - for (u64 i : mH.mRows[j]) - { - sum ^= data[i]; - } - - if (sum) - { - return false; - } - } - return true; - - } - - void tests::LdpcDecode_pb_test(const oc::CLP& cmd) - { - u64 rows = cmd.getOr("r", 40); - u64 cols = static_cast(rows * cmd.getOr("e", 2.0)); - u64 colWeight = cmd.getOr("cw", 3); - u64 dWeight = cmd.getOr("dw", 3); - u64 gap = cmd.getOr("g", 2); - - auto k = cols - rows; - - SparseMtx H; - LdpcEncoder E; - LdpcDecoder D; - - for (u64 i = 0; i < 2; ++i) - { - oc::PRNG prng(block(i, 1)); - bool b = true; - u64 tries = 0; - while (b) - { - H = sampleTriangularBand(rows, cols, - colWeight, gap, dWeight, false, prng); - // H = sampleTriangular(rows, cols, colWeight, gap, prng); - b = !E.init(H, gap); - - ++tries; - } - - D.init(H); - std::vector m(k), m2, code(cols); - - for (auto& mm : m) - mm = prng.getBit(); - - E.encode(code, m); - auto ease = 1ull; - - - u64 min = 9999999; - u64 ee = 3; - while (true) - { - auto c = code; - for (u64 j = 0; j < ease; ++j) - { - c[j] ^= 1; - } - - u64 e = 0; - m2 = D.logbpDecode2(c); - - - if (m2 != m) - { - ++e; - min = std::min(min, ease); - } - m2 = D.altDecode(c, false); - - if (m2 != m) - { - ++e; - min = std::min(min, ease); - } - - - m2 = D.altDecode(c, true); - - if (m2 != m) - { - min = std::min(min, ease); - ++e; - } - if (e == ee) - break; - ++ease; - } - if (ease < 4 || min < 4) - { - throw std::runtime_error(LOCATION); - } - - //std::cout << "high " << ease << std::endl; - } - return; - - } - - - -} -#endif \ No newline at end of file diff --git a/libOTe/Tools/LDPC/LdpcDecoder.h b/libOTe/Tools/LDPC/LdpcDecoder.h deleted file mode 100644 index 9b7009a3..00000000 --- a/libOTe/Tools/LDPC/LdpcDecoder.h +++ /dev/null @@ -1,121 +0,0 @@ -#pragma once -// © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). -#include "libOTe/config.h" - -#ifdef ENABLE_LDPC - -#include -#include "cryptoTools/Common/Defines.h" -#include "cryptoTools/Common/Matrix.h" -#include "cryptoTools/Common/CLP.h" -#include "Mtx.h" -#include -namespace osuCrypto -{ - - class LdpcDecoder - { - public: - u64 mK = 0; - - bool mAllowZero = true; - double mP = 0.9; - Matrix mM, mR; - - std::vector> mMM, mRR; - std::vector mMData, mRData; - std::vector mW, mL; - - SparseMtx mH; - - LdpcDecoder() = default; - LdpcDecoder(const LdpcDecoder&) = default; - LdpcDecoder(LdpcDecoder&&) = default; - - - LdpcDecoder(SparseMtx& H) - { - init(H); - } - - void init(SparseMtx& H); - - std::vector bpDecode(span codeword, u64 maxIter = 1000); - std::vector logbpDecode2(span codeword, u64 maxIter = 1000); - std::vector altDecode(span codeword, bool minSum, u64 maxIter = 1000); - std::vector minSumDecode(span codeword, u64 maxIter = 1000); - - std::vector bpDecode(span codeword, u64 maxIter = 1000); - std::vector logbpDecode2(span codeword, u64 maxIter = 1000); - std::vector altDecode(span codeword, bool minSum, u64 maxIter = 1000); - - std::vector decode(span codeword, u64 maxIter = 1000) - { - return logbpDecode2(codeword, maxIter); - } - - bool check(const span& data); - static bool isZero(span data) - { - for (auto d : data) - if (d) return false; - return true; - } - - - inline static double LLR(double d) - { - assert(d > -1 && d < 1); - return std::log(d / (1 - d)); - } - inline static double LR(double d) - { - assert(d > -1 && d < 1); - return (d / (1 - d)); - } - - - inline static double encodeLLR(double p, bool bit) - { - assert(p >= 0.5); - assert(p < 1); - - p = bit ? (1 - p) : p; - - return LLR(p); - } - - inline static double encodeLR(double p, bool bit) - { - assert(p > 0.5); - assert(p < 1); - - p = bit ? (1 - p) : p; - - return LR(p); - } - - inline static u32 decodeLLR(double l) - { - return (l >= 0 ? 0 : 1); - } - - }; - - - - namespace tests - { - void LdpcDecode_pb_test(const oc::CLP& cmd); - - } - -} -#endif \ No newline at end of file diff --git a/libOTe/Tools/LDPC/LdpcEncoder.cpp b/libOTe/Tools/LDPC/LdpcEncoder.cpp deleted file mode 100644 index fb570cbc..00000000 --- a/libOTe/Tools/LDPC/LdpcEncoder.cpp +++ /dev/null @@ -1,975 +0,0 @@ -#include "LdpcEncoder.h" -#ifdef ENABLE_INSECURE_SILVER - -//#include -#include -#include "cryptoTools/Crypto/PRNG.h" -#include "cryptoTools/Common/Timer.h" -#include "LdpcSampler.h" -#include "libOTe/Tools/Tools.h" -namespace osuCrypto -{ - namespace details - { - constexpr std::array, 16> SilverRightEncoder::diagMtx_g16_w5_seed1_t36; - constexpr std::array, 32> SilverRightEncoder::diagMtx_g32_w11_seed2_t36; - constexpr std::array SilverRightEncoder::mOffsets; - } - - bool LdpcEncoder::init(SparseMtx H, u64 gap) - { - -#ifndef NDEBUG - for (u64 i = H.cols() - H.rows() + gap, j = 0; i < H.cols(); ++i, ++j) - { - auto row = H.row(j); - assert(row[row.size() - 1] == i); - } -#endif - auto c0 = H.cols() - H.rows(); - auto c1 = c0 + gap; - auto r0 = H.rows() - gap; - - mN = H.cols(); - mM = H.rows(); - mGap = gap; - - - mA = H.subMatrix(0, 0, r0, c0); - mB = H.subMatrix(0, c0, r0, gap); - mC = H.subMatrix(0, c1, r0, H.rows() - gap); - mD = H.subMatrix(r0, 0, gap, c0); - mE = H.subMatrix(r0, c0, gap, gap); - mF = H.subMatrix(r0, c1, gap, H.rows() - gap); - mH = std::move(H); - - mCInv.init(mC); - - if (mGap) - { - SparseMtx CB; - - // CB = C^-1 B - mCInv.mult(mB, CB); - - //assert(mC.invert().mult(mB) == CB); - // Ep = F C^-1 B - mEp = mF.mult(CB); - //// Ep = F C^-1 B + E - mEp += mE; - mEp = mEp.invert(); - - return (mEp.rows() != 0); - } - - return true; - } - - void LdpcEncoder::encode(span c, span mm) - { - assert(mm.size() == mM); - assert(c.size() == mN); - - auto s = mM - mGap; - auto iter = c.begin() + mM; - span m(c.begin(), iter); - span p(iter, iter + mGap); iter += mGap; - span pp(iter, c.end()); - - - // m = mm - std::copy(mm.begin(), mm.end(), m.begin()); - std::fill(c.begin() + mM, c.end(), 0); - - // pp = A * m - mA.multAdd(m, pp); - - if (mGap) - { - std::vector t(s); - - // t = C^-1 pp = C^-1 A m - mCInv.mult(pp, t); - - // p = - F t + D m = -F C^-1 A m + D m - mF.multAdd(t, p); - mD.multAdd(m, p); - - // p = - Ep p = -Ep (-F C^-1 A m + D m) - t = mEp.mult(p); - std::copy(t.begin(), t.end(), p.begin()); - - // pp = pp + B p - mB.multAdd(p, pp); - } - - // pp = C^-1 pp - mCInv.mult(pp, pp); - } - - namespace details - { - - void DiagInverter::init(const SparseMtx& c) - { - mC = (&c); - assert(mC->rows() == mC->cols()); - -#ifndef NDEBUG - for (u64 i = 0; i < mC->rows(); ++i) - { - auto row = mC->row(i); - assert(row.size() && row[row.size() - 1] == i); - - for (u64 j = 0; j < row.size() - 1; ++j) - { - assert(row[j] < row[j + 1]); - } - } -#endif - } - - - std::vector DiagInverter::getSteps() - { - std::vector steps; - - u64 n = mC->cols(); - u64 nn = mC->cols() * 2; - - for (u64 i = 0; i < mC->rows(); ++i) - { - auto row = mC->row(i); - PointList points(nn, nn); - - points.push_back({ i, n + i }); - assert(row[row.size() - 1] == i); - for (u64 j = 0; j < (u64)row.size() - 1; ++j) - { - points.push_back({ i,row[j] }); - } - - for (u64 j = 0; j < i; ++j) - { - points.push_back({ j,j }); - } - - for (u64 j = 0; j < n; ++j) - { - points.push_back({ n + j, n + j }); - } - steps.emplace_back(nn, nn, points); - - } - - return steps; - } - - // computes x = mC^-1 * y - void DiagInverter::mult(span y, span x) - { - // solves for x such that y = M x, ie x := H^-1 y - assert(mC); - assert(mC->rows() == y.size()); - assert(mC->cols() == x.size()); - - for (u64 i = 0; i < mC->rows(); ++i) - { - auto row = mC->row(i); - x[i] = y[i]; - - assert(row[row.size() - 1] == i); - for (u64 j = 0; j < (u64)row.size() - 1; ++j) - { - x[i] ^= x[row[j]]; - } - } - } - void DiagInverter::mult(const SparseMtx& y, SparseMtx& x) - { - auto n = mC->rows(); - assert(n == y.rows()); - //assert(n == x.rows()); - //assert(y.cols() == x.cols()); - - auto xNumRows = n; - auto xNumCols = y.cols(); - - std::vector& xCol = x.mDataCol; xCol.reserve(y.mDataCol.size()); - std::vector - colSizes(xNumCols), - rowSizes(xNumRows); - - for (u64 c = 0; c < y.cols(); ++c) - { - auto cc = y.col(c); - auto yIter = cc.begin(); - auto yEnd = cc.end(); - - auto xColBegin = xCol.size(); - for (u64 i = 0; i < n; ++i) - { - u8 bit = 0; - if (yIter != yEnd && *yIter == i) - { - bit = 1; - ++yIter; - } - - auto rr = mC->row(i); - auto mIter = rr.begin(); - auto mEnd = rr.end() - 1; - - auto xIter = xCol.begin() + xColBegin; - auto xEnd = xCol.end(); - - while (mIter != mEnd && xIter != xEnd) - { - if (*mIter < *xIter) - ++mIter; - else if (*xIter < *mIter) - ++xIter; - else - { - bit ^= 1; - ++xIter; - ++mIter; - } - } - - if (bit) - { - xCol.push_back(i); - ++rowSizes[i]; - } - } - colSizes[c] = xCol.size(); - } - - x.mCols.resize(colSizes.size()); - auto iter = xCol.begin(); - for (u64 i = 0; i < colSizes.size(); ++i) - { - auto end = xCol.begin() + colSizes[i]; - x.mCols[i] = end - iter ? SparseMtx::Col(span(iter, end)) : SparseMtx::Col{}; - iter = end; - } - - x.mRows.resize(rowSizes.size()); - x.mDataRow.resize(x.mDataCol.size()); - iter = x.mDataRow.begin(); - //auto prevSize = 0ull; - for (u64 i = 0; i < rowSizes.size(); ++i) - { - auto end = iter + rowSizes[i]; - - rowSizes[i] = 0; - //auto ss = rowSizes[i]; - //rowSizes[i] = rowSizes[i] - prevSize; - //prevSize = ss; - - x.mRows[i] = SparseMtx::Row(iter != end ? span(iter, end) : span{}); - iter = end; - } - - iter = xCol.begin(); - for (u64 i = 0; i < x.cols(); ++i) - { - for (u64 j : x.col(i)) - { - x.mRows[j][rowSizes[j]++] = i; - } - } - - } - - - - void SilverLeftEncoder::init(u64 rows, std::vector rs, u64 rep) - { - mRows = rows; - mWeight = rs.size(); - assert(mWeight > 4); - - mRs = rs; - mYs.resize(rs.size()); - std::set s; - std::vector reps(rep); - - u64 trials = 0; - for (u64 i = 0; i < mWeight; ++i) - { - mYs[i] = u64(rows * rs[i]) % rows; - while ( - //(rep && reps[mYs[i] % rep]) || - (rep && mYs[i] % rep) || - s.insert(mYs[i]).second == false) - { - mYs[i] = (mYs[i] + 1) % rows; - - if (++trials > 1000) - { - std::cout << "these ratios resulted in too many collisions. " LOCATION << std::endl; - throw std::runtime_error("these ratios resulted in too many collisions. " LOCATION); - } - } - - if(rep) - reps[mYs[i] % rep] = 1; - } - } - - void SilverLeftEncoder::init(u64 rows, SilverCode code, u64 rep) - { - auto weight = code.weight(); - switch (weight) - { - case 5: - init(rows, { { 0, 0.372071, 0.576568, 0.608917, 0.854475} }, rep); - - // 0 0.0494143 0.437702 0.603978 0.731941 - - // yset 3,785 - // 0 0.372071 0.576568 0.608917 0.854475 - break; - case 11: - init(rows, { { 0, 0.00278835, 0.0883852, 0.238023, 0.240532, 0.274624, 0.390639, 0.531551, 0.637619, 0.945265, 0.965874} }, rep); - // 0 0.00278835 0.0883852 0.238023 0.240532 0.274624 0.390639 0.531551 0.637619 0.945265 0.965874 - break; - default: - // no preset parameters - throw RTE_LOC; - } - } - - void SilverLeftEncoder::encode(span pp, span m) - { - auto cols = mRows; - assert(pp.size() == mRows); - assert(m.size() == cols); - - // pp = pp + A * m - auto v = mYs; - for (u64 i = 0; i < cols; ++i) - { - for (u64 j = 0; j < mWeight; ++j) - { - auto row = v[j]; - pp[row] ^= m[i]; - - ++v[j]; - if (v[j] == mRows) - v[j] = 0; - } - } - } - - - void SilverLeftEncoder::getPoints(PointList& points) - { - auto cols = mRows; - auto v = mYs; - - for (u64 i = 0; i < cols; ++i) - { - for (u64 j = 0; j < mWeight; ++j) - { - auto row = v[j]; - - points.push_back({ row, i }); - - ++v[j]; - if (v[j] == mRows) - v[j] = 0; - } - } - } - - SparseMtx SilverLeftEncoder::getMatrix() - { - PointList points(mRows, mRows); - getPoints(points); - return SparseMtx(mRows, mRows, points); - } - - void SilverLeftEncoder::getTransPoints(PointList& points) - { - - auto cols = mRows; - auto v = mYs; - - for (u64 i = 0; i < cols; ) - { - auto end = cols; - for (u64 j = 0; j < mWeight; ++j) - { - if (v[j] == mRows) - v[j] = 0; - - auto jEnd = cols - v[j] + i; - end = std::min(end, jEnd); - } - //T* __restrict P = &pp[i]; - //T* __restrict PE = &pp[end]; - - while (i != end) - { - points.push_back({ i,i }); - - for (u64 j = 0; j < mWeight; ++j) - { - auto row = v[j]; - points.push_back({ i, row + cols }); - ++v[j]; - } - ++i; - } - } - - } - - SparseMtx SilverLeftEncoder::getTransMatrix() - { - PointList points(mRows, 2 * mRows); - getTransPoints(points); - return points; - } - - void SilverRightEncoder::init(u64 rows, SilverCode c, bool extend) - { - mGap = c.gap(); - assert(mGap < rows); - mCode = c; - mRows = rows; - mExtend = extend; - mCols = extend ? rows : rows - mGap; - } - - void SilverRightEncoder::encode(span x, span y) - { - assert(mExtend); - for (u64 i = 0; i < mRows; ++i) - { - x[i] = y[i]; - if (mCode == SilverCode::Weight5) - { - for (u64 j = 0; j < 4; ++j) - { - auto col = i - 16 + diagMtx_g16_w5_seed1_t36[i & 15][j]; - if (col < mRows) - x[i] = x[i] ^ x[col]; - } - } - - if (mCode == SilverCode::Weight11) - { - for (u64 j = 0; j < 10; ++j) - { - auto col = i - 32 + diagMtx_g32_w11_seed2_t36[i & 31][j]; - if (col < mRows) - x[i] = x[i] ^ x[col]; - } - } - - for (u64 j = 0; j < mOffsets.size(); ++j) - { - auto p = i - mOffsets[j] - mGap; - if (p >= mRows) - break; - x[i] = x[i] ^ x[p]; - } - } - } - bool gVerbose = false; - void SilverRightEncoder::getPoints(PointList& points, u64 colOffset) - { - auto rr = mRows; - - for (u64 i = 0; i < rr; ++i) - { - if (i < mCols) - points.push_back({ i, i + colOffset }); - - switch (mCode) - { - case SilverCode::Weight5: - - for (u64 j = 0; j < 4; ++j) - { - auto col = i - 16 + diagMtx_g16_w5_seed1_t36[i & 15][j]; - if (col < mCols) - points.push_back({ i, col + colOffset }); - } - - break; - case SilverCode::Weight11: - for (u64 j = 0; j < 10; ++j) - { - auto col = i - 32 + diagMtx_g32_w11_seed2_t36[i & 31][j]; - if (col < mCols) - points.push_back({ i, col + colOffset }); - } - - break; - default: - break; - } - - for (u64 j = 0; j < mOffsets.size(); ++j) - { - auto col = i - mOffsets[j] - mGap; - if (col < mRows) - points.push_back({ i, col + colOffset }); - } - - if(gVerbose) - std::cout << i << "\n" << SparseMtx(points) << std::endl; - - } - - if (mExtend) - { - for (u64 i = rr; i < mRows; ++i) - points.push_back({ i, i + colOffset }); - } - } - - SparseMtx SilverRightEncoder::getMatrix() - { - PointList points(mRows, cols()); - getPoints(points, 0); - return SparseMtx(mRows, cols(), points); - } - std::vector SilverRightEncoder::getTransMatrices() - { - std::vector ret;ret.reserve(mRows); - auto colOffset = 0; - auto rr = mRows; - - for (u64 i = 0; i < rr; ++i) - { - PointList points(mRows, mRows); - for (u64 j = 0; j < mRows; ++j) - points.push_back({ j,j }); - - switch (mCode) - { - case SilverCode::Weight5: - - for (u64 j = 0; j < 4; ++j) - { - auto col = i - 16 + diagMtx_g16_w5_seed1_t36[i & 15][j]; - if (col < mCols) - points.push_back({ i, col + colOffset }); - } - - break; - case SilverCode::Weight11: - for (u64 j = 0; j < 10; ++j) - { - auto col = i - 32 + diagMtx_g32_w11_seed2_t36[i & 31][j]; - if (col < mCols) - points.push_back({ i, col + colOffset }); - } - - break; - default: - break; - } - - for (u64 j = 0; j < mOffsets.size(); ++j) - { - auto col = i - mOffsets[j] - mGap; - if (col < mRows) - points.push_back({ i, col + colOffset }); - } - - if (gVerbose) - std::cout << i << "\n" << SparseMtx(points) << std::endl; - - ret.emplace_back(points); - } - - //if (mExtend) - //{ - // for (u64 i = rr; i < mRows; ++i) - // points.push_back({ i, i + colOffset }); - //} - return ret; - } - SparseMtx SilverRightEncoder::getTransMatrix() - { - auto Es = getTransMatrices(); - SparseMtx E = Es.back(); - //std::cout << "Eb" << E << std::endl; - for (u64 i = Es.size() - 2; i < Es.size(); --i) - { - E = E * Es[i]; - } - - return E; - } - } - - - void tests::LdpcEncoder_diagonalSolver_test() - { - u64 n = 10; - u64 w = 4; - u64 t = 10; - - oc::PRNG prng(block(0, 0)); - std::vector x(n), y(n); - for (u64 tt = 0; tt < t; ++tt) - { - SparseMtx H = sampleTriangular(n, 0.5, prng); - - //std::cout << H << std::endl; - - for (auto& yy : y) - yy = prng.getBit(); - - details::DiagInverter HInv(H); - - HInv.mult(y, x); - - auto z = H.mult(x); - - assert(z == y); - - auto Y = sampleFixedColWeight(n, w, 3, prng, false); - - SparseMtx X; - - HInv.mult(Y, X); - - auto Z = H * X; - - assert(Z == Y); - - } - - - - - return; - } - - void tests::LdpcEncoder_encode_test() - { - - u64 rows = 16; - u64 cols = rows * 2; - u64 colWeight = 4; - u64 dWeight = 3; - u64 gap = 6; - - auto k = cols - rows; - - assert(gap >= dWeight); - - oc::PRNG prng(block(0, 2)); - - - SparseMtx H; - LdpcEncoder E; - - - //while (b) - for (u64 i = 0; i < 40; ++i) - { - bool b = true; - //std::cout << " +====================" << std::endl; - while (b) - { - H = sampleTriangularBand( - rows, cols, - colWeight, gap, - dWeight, false, prng); - //H = sampleTriangular(rows, cols, colWeight, gap, prng); - b = !E.init(H, gap); - } - - //std::cout << H << std::endl; - - std::vector m(k), c(cols); - - for (auto& mm : m) - mm = prng.getBit(); - - - E.encode(c, m); - - auto ss = H.mult(c); - - //for (auto sss : ss) - // std::cout << int(sss) << " "; - //std::cout << std::endl; - assert(ss == std::vector(H.rows(), 0)); - - } - return; - - } - - void tests::LdpcEncoder_encode_g0_test() - { - - u64 rows = 17; - u64 cols = rows * 2; - u64 colWeight = 4; - - auto k = cols - rows; - - oc::PRNG prng(block(0, 2)); - - - SparseMtx H; - LdpcEncoder E; - - - //while (b) - for (u64 i = 0; i < 40; ++i) - { - bool b = true; - //std::cout << " +====================" << std::endl; - while (b) - { - //H = sampleTriangularBand( - // rows, cols, - // colWeight, 0, - // 1, false, prng); - // - // - - - H = sampleTriangularBand( - rows, cols, - colWeight, 8, - colWeight, colWeight, 0, 0, { 5,31 }, true, true, true, prng, prng); - //H = sampleTriangular(rows, cols, colWeight, gap, prng); - b = !E.init(H, 0); - } - - //std::cout << H << std::endl; - - std::vector m(k), c(cols); - - for (auto& mm : m) - mm = prng.getBit(); - - - E.encode(c, m); - - auto ss = H.mult(c); - - assert(ss == std::vector(H.rows(), 0)); - - } - return; - } - - - void tests::LdpcS1Encoder_encode_test() - { - u64 rows = 100; - SilverCode weight = SilverCode::Weight5; - - details::SilverLeftEncoder zz; - zz.init(rows, weight); - - std::vector m(rows), pp(rows); - - PRNG prng(ZeroBlock); - - for (u64 i = 0; i < rows; ++i) - m[i] = prng.getBit(); - - zz.encode(pp, m); - - auto p2 = zz.getMatrix().mult(m); - - if (p2 != pp) - { - throw RTE_LOC; - } - - } - - - - void tests::LdpcS1Encoder_encode_Trans_test() - { - u64 rows = 100; - SilverCode weight = SilverCode::Weight5; - - - details::SilverLeftEncoder zz; - zz.init(rows, weight); - - std::vector m(rows), pp(rows); - - PRNG prng(ZeroBlock); - - for (u64 i = 0; i < rows; ++i) - m[i] = prng.getBit(); - - zz.dualEncode(pp, m); - - auto At = zz.getMatrix().dense().transpose().sparse(); - auto p2 = At.mult(m); - - //std::cout << "At\n" << At << std::endl; - //std::cout << "M\n" << zz.getTransMatrix() << std::endl; - - if (p2 != pp) - { - throw RTE_LOC; - } - } - - - void tests::LdpcComposit_RegRepDiagBand_encode_test() - { - u64 rows = 100; - - - PRNG prng(ZeroBlock); - using namespace details; - using Encoder = SilverEncoder; - - for (auto code : { SilverCode::Weight5 , SilverCode::Weight11 }) - { - - //{ - // gVerbose = false; - // details::SilverRightEncoder rr; - // rr.init(rows, code, true); - // auto RR = rr.getMatrix(); - // std::cout << "R\n" << RR << std::endl; - // gVerbose = false; - //} - - Encoder enc; - enc.mL.init(rows, code); - enc.mR.init(rows, code, true); - - auto H = enc.getMatrix(); - - LdpcEncoder enc2; - enc2.init(H, 0); - - auto cols = enc.cols(); - auto k = cols - rows; - std::vector m(k), c(cols), c2(cols); - - for (auto& mm : m) - mm = prng.getBit(); - - enc.encode(c, m); - enc2.encode(c2, m); - - auto ss = H.mult(c); - - if (ss != std::vector(H.rows(), 0)) - throw RTE_LOC; - if (c2 != c) - throw RTE_LOC; - - auto R = enc.mR.getMatrix(); - auto M = enc.mR.getTransMatrix(); - //std::cout << "R\n" << R << std::endl; - //std::cout << "M\n" << M << std::endl; - - auto g = SilverCode::gap(code); - auto d1 = enc.mR.mOffsets[0] + g; - auto d2 = enc.mR.mOffsets[1] + g; - - - for (u64 i = 0; i < R.cols() - g; ++i) - { - std::set ss; - - auto col = R.col(i); - for (auto cc : col) - { - ss.insert(cc); - } - - auto expSize = SilverCode::weight(code); - - if (d1 < R.rows()) - { - ++expSize; - if (ss.find(d1) == ss.end()) - throw RTE_LOC; - } - - if (d2 < R.rows()) - { - ++expSize; - if (ss.find(d2) == ss.end()) - throw RTE_LOC; - } - - if (col.size() != expSize) - throw RTE_LOC; - - ++d1; - ++d2; - - - } - } - } - - - void tests::LdpcComposit_RegRepDiagBand_Trans_test() - { - - u64 rows = 101; - - using namespace details; - SilverCode code = SilverCode::Weight5; - - - using Encoder = LdpcCompositEncoder; - PRNG prng(ZeroBlock); - - Encoder enc; - enc.mL.init(rows, code); - enc.mR.init(rows, code, true); - - auto H = enc.getMatrix(); - auto HD = H.dense(); - auto Gt = computeGen(HD).transpose(); - - - LdpcEncoder enc2; - enc2.init(H, 0); - - - auto cols = enc.cols(); - auto k = cols - rows; - - std::vector c(cols); - - for (auto& cc : c) - cc = prng.getBit(); - //std::cout << "\n"; - - auto mOld = c; - enc2.dualEncode(mOld); - mOld.resize(k); - - ////std::cout << "R\n" << enc.mR.getMatrix() << std::endl << std::endl; - //std::cout << "L\n" << enc.mL.getTransMatrix() << std::endl << std::endl; - - auto mCur = c; - enc.dualEncode(mCur); - mCur.resize(k); - } - - -} -#endif \ No newline at end of file diff --git a/libOTe/Tools/LDPC/LdpcEncoder.h b/libOTe/Tools/LDPC/LdpcEncoder.h deleted file mode 100644 index e1a53413..00000000 --- a/libOTe/Tools/LDPC/LdpcEncoder.h +++ /dev/null @@ -1,1175 +0,0 @@ -#pragma once -// © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). -#include "libOTe/config.h" -#ifdef ENABLE_INSECURE_SILVER - -#include "Mtx.h" -#include "cryptoTools/Crypto/PRNG.h" -#include "cryptoTools/Common/Timer.h" -#include - -namespace osuCrypto -{ - namespace details - { - class DiagInverter - { - public: - - const SparseMtx* mC = nullptr; - - DiagInverter() = default; - DiagInverter(const DiagInverter&) = default; - DiagInverter& operator=(const DiagInverter&) = default; - - DiagInverter(const SparseMtx& c) { init(c); } - - void init(const SparseMtx& c); - - // returns a list of matrices to multiply with to encode. - std::vector getSteps(); - - // computes x = mC^-1 * y - void mult(span y, span x); - - // computes x = mC^-1 * y - void mult(const SparseMtx& y, SparseMtx& x); - - template - void cirTransMult(span x, span y) - { - // solves for x such that y = M x, ie x := H^-1 y - assert(mC); - assert(mC->cols() == x.size()); - - for (u64 i = mC->rows() - 1; i != ~u64(0); --i) - { - auto row = mC->row(i); - assert(row[row.size() - 1] == i); - for (u64 j = 0; j < (u64)row.size() - 1; ++j) - { - auto col = row[j]; - assert(col < i); - x[col] = x[col] ^ x[i]; - } - } - } - }; - } - - // a generic encoder for any g-ALT LDPC code - class LdpcEncoder - { - public: - - LdpcEncoder() = default; - LdpcEncoder(const LdpcEncoder&) = default; - LdpcEncoder(LdpcEncoder&&) = default; - - - u64 mN, mM, mGap; - SparseMtx mA, mH; - SparseMtx mB; - SparseMtx mC; - SparseMtx mD; - SparseMtx mE, mEp; - SparseMtx mF; - details::DiagInverter mCInv; - - // initialize the encoder with the given matrix which is - // in gap-ALT form. - bool init(SparseMtx mtx, u64 gap); - - // encode the given message m and populate c with - // the resulting codeword. - void encode(span c, span m); - - // perform the circuit transpose of the encoding algorithm. - // the inputs and output is c. - template - void dualEncode(span c) - { - if (mGap) - throw std::runtime_error(LOCATION); - assert(c.size() == mN); - - auto k = mN - mM; - span pp(c.subspan(k, mM)); - span mm(c.subspan(0, k)); - - mCInv.cirTransMult(pp, mm); - - for (u64 i = 0; i < k; ++i) - { - for (auto row : mA.col(i)) - { - c[i] = c[i] ^ pp[row]; - } - } - } - }; - - // INSECURE - // enum struct to specify which silver code variant to use. - // https://eprint.iacr.org/2021/1150 - // see also https://eprint.iacr.org/2023/882 - struct SilverCode - { - enum code - { - Weight5 = 5, - Weight11 = 11, - }; - code mCode; - - SilverCode() = default; - SilverCode(const SilverCode&) = default; - SilverCode& operator=(const SilverCode&) = default; - SilverCode(const code& c) : mCode(c) {} - - bool operator==(code c) { return mCode == c; } - bool operator!=(code c) { return mCode != c; } - operator code() - { - return mCode; - } - - u64 weight() { return weight(mCode); } - u64 gap() { - return gap(mCode); - } - static u64 weight(code c) - { - return (u64)c; - } - static u64 gap(code c) - { - switch (c) - { - case Weight5: - return 16; - break; - case Weight11: - return 32; - break; - default: - throw RTE_LOC; - break; - } - } - }; - - namespace details - { - - // the silver encoder for the left half of the matrix. - // This part of the code constists of mWeight diagonals. - // The positions of these diagonals is determined by - class SilverLeftEncoder - { - public: - - u64 mRows, mWeight; - std::vector mYs; - std::vector mRs; - - // a custom initialization which the specifed diagonal - // factional positions. - void init(u64 rows, std::vector rs, u64 rep = 0); - - // initialize the left half with the given silver presets. - void init(u64 rows, SilverCode code, u64 rep = 0); - - // encode the given message m and populate c with - // the resulting codeword. - void encode(span pp, span m); - - u64 cols() { return mRows; } - - u64 rows() { return mRows; } - - // populates points with the matrix representation of this - // encoder. - void getPoints(PointList& points); - - // return parity check matrix representation of this encoder. - SparseMtx getMatrix(); - - - // return generator matrix representation of this encoder. - void getTransPoints(PointList& points); - - // return generator matrix representation of this encoder. - SparseMtx getTransMatrix(); - - // perform the circuit transpose of the encoding algorithm. - // the output it written to ppp. - template - void dualEncode(span ppp, span mm); - - // perform the circuit transpose of the encoding algorithm twice. - // the output it written to ppp0 and ppp1. - template - void dualEncode2( - span ppp0, span ppp1, - span mm0, span mm1); - }; - - - class SilverRightEncoder - { - public: - - static constexpr std::array, 16> diagMtx_g16_w5_seed1_t36 - { { - { {0, 4, 11, 15 }}, - { {0, 8, 9, 10 } }, - { {1, 2, 10, 14 } }, - { {0, 5, 8, 15 } }, - { {3, 13, 14, 15 } }, - { {2, 4, 7, 8 } }, - { {0, 9, 12, 15 } }, - { {1, 6, 8, 14 } }, - { {4, 5, 6, 14 } }, - { {1, 3, 8, 13 } }, - { {3, 4, 7, 8 } }, - { {3, 5, 9, 13 } }, - { {8, 11, 12, 14 } }, - { {6, 10, 12, 13 } }, - { {2, 7, 8, 13 } }, - { {0, 6, 10, 15 } } - } }; - - - static constexpr std::array, 32> diagMtx_g32_w11_seed2_t36 - { { - { { 6, 7, 8, 12, 16, 17, 20, 22, 24, 25 } }, - { { 0, 1, 6, 10, 12, 13, 17, 19, 30, 31 } }, - { { 1, 4, 7, 10, 12, 16, 21, 22, 30, 31 } }, - { { 3, 5, 9, 13, 15, 21, 23, 25, 26, 27 } }, - { { 3, 8, 9, 14, 17, 19, 24, 25, 26, 28 } }, - { { 3, 11, 12, 13, 14, 16, 17, 21, 22, 30 } }, - { { 2, 4, 5, 11, 12, 17, 22, 24, 30, 31 } }, - { { 5, 8, 11, 12, 13, 17, 18, 20, 27, 29 } }, - { {13, 16, 17, 18, 19, 20, 21, 22, 26, 30 } }, - { { 3, 8, 13, 15, 16, 17, 19, 20, 21, 27 } }, - { { 0, 2, 4, 5, 6, 21, 23, 26, 28, 30 } }, - { { 2, 4, 6, 8, 10, 11, 22, 26, 28, 30 } }, - { { 7, 9, 11, 14, 15, 16, 17, 18, 24, 30 } }, - { { 0, 3, 7, 12, 13, 18, 20, 24, 25, 28 } }, - { { 1, 5, 7, 8, 12, 13, 21, 24, 26, 27 } }, - { { 0, 16, 17, 19, 22, 24, 25, 27, 28, 31 } }, - { { 0, 6, 7, 15, 16, 18, 22, 24, 29, 30 } }, - { { 2, 3, 4, 7, 15, 17, 18, 20, 22, 26 } }, - { { 2, 3, 9, 16, 17, 19, 24, 27, 29, 31 } }, - { { 1, 3, 5, 7, 13, 14, 20, 23, 24, 27 } }, - { { 0, 2, 3, 9, 10, 14, 19, 20, 21, 25 } }, - { { 4, 13, 16, 20, 21, 23, 25, 27, 28, 31 } }, - { { 1, 2, 5, 6, 9, 13, 15, 17, 20, 24 } }, - { { 0, 4, 7, 8, 12, 13, 20, 23, 28, 30 } }, - { { 0, 3, 4, 5, 8, 9, 23, 25, 26, 28 } }, - { { 0, 3, 4, 7, 8, 10, 11, 15, 21, 26 } }, - { { 5, 6, 7, 8, 10, 11, 15, 21, 22, 25 } }, - { { 0, 1, 2, 3, 8, 9, 22, 24, 27, 28 } }, - { { 1, 2, 13, 14, 15, 16, 19, 22, 29, 30 } }, - { { 2, 14, 15, 16, 19, 20, 25, 26, 28, 29 } }, - { { 8, 9, 11, 12, 13, 15, 17, 18, 23, 27 } }, - { { 0, 2, 4, 5, 6, 7, 10, 12, 14, 19 } } - } }; - - static constexpr std::array mOffsets{ {5,31} }; - - u64 mGap; - - u64 mRows, mCols; - SilverCode mCode; - bool mExtend; - - // initialize the right half of the silver encoder - // with the given preset. extend should be true if - // this is used to encode and false to determine the - // effective minimum distance. - void init(u64 rows, SilverCode c, bool extend); - - u64 cols() { return mCols; } - - u64 rows() { return mRows; } - - // encode the given message m and populate c with - // the resulting codeword. - void encode(span x, span y); - - // populates points with the matrix representation of this - // encoder. - void getPoints(PointList& points, u64 colOffset); - - // return matrix representation of this encoder. - SparseMtx getMatrix(); - - - // return generator matrix representation of this encoder. - std::vector getTransMatrices(); - SparseMtx getTransMatrix(); - - - // perform the circuit transpose of the encoding algorithm. - // the inputs and output is x. - template - void dualEncode(span x); - - // perform the circuit transpose of the encoding algorithm twice. - // the inputs and output is x0 and x1. - template - void dualEncode2(span x0, span x1); - }; - - // a full encoder expressed and the left and right encoder. - template - class LdpcCompositEncoder : public TimerAdapter - { - public: - - LEncoder mL; - REncoder mR; - - template - void encode(spanc, span mm) - { - assert(mm.size() == cols() - rows()); - assert(c.size() == cols()); - - auto s = rows(); - auto iter = c.begin() + s; - span m(c.begin(), iter); - span pp(iter, c.end()); - - // m = mm - std::copy(mm.begin(), mm.end(), m.begin()); - std::fill(c.begin() + s, c.end(), 0); - - // pp = A * m - mL.encode(pp, mm); - - // pp = C^-1 pp - mR.encode(pp, pp); - } - - template - void dualEncode(span c) - { - auto k = cols() - rows(); - assert(c.size() == cols()); - setTimePoint("encode_begin"); - span pp(c.subspan(k, rows())); - - mR.template dualEncode(pp); - setTimePoint("diag"); - mL.template dualEncode(c.subspan(0, k), pp); - setTimePoint("L"); - - } - - - template - void dualEncode2(span c0, span c1) - { - auto k = cols() - rows(); - assert(c0.size() == cols()); - - setTimePoint("encode_begin"); - span pp0(c0.subspan(k, rows())); - span pp1(c1.subspan(k, rows())); - - mR.template dualEncode2(pp0, pp1); - - setTimePoint("diag"); - mL.template dualEncode2(c0.subspan(0, k), c1.subspan(0, k), pp0, pp1); - setTimePoint("L"); - } - - - u64 cols() { return mL.cols() + mR.cols(); } - - u64 rows() { return mR.rows(); } - - void getPoints(PointList& points) - { - mL.getPoints(points); - mR.getPoints(points, mL.cols()); - } - - SparseMtx getMatrix() - { - PointList points(rows(), cols()); - getPoints(points); - return SparseMtx(rows(), cols(), points); - } - - }; - - } - - // the full silver encoder which is composed - // of the left and right sub-encoders. - struct SilverEncoder : public details::LdpcCompositEncoder - { - void init(u64 rows, SilverCode code) - { - mL.init(rows, code); - mR.init(rows, code, true); - } - }; - - - namespace tests - { - - void LdpcEncoder_diagonalSolver_test(); - void LdpcEncoder_encode_test(); - void LdpcEncoder_encode_g0_test(); - - void LdpcS1Encoder_encode_test(); - void LdpcS1Encoder_encode_Trans_test(); - - void LdpcComposit_RegRepDiagBand_encode_test(); - void LdpcComposit_RegRepDiagBand_Trans_test(); - - } - - - - - - - - - - - - - - - - - - - - - - - - // perform the circuit transpose of the encoding algorithm. - // the output it written to ppp. - template - void details::SilverLeftEncoder::dualEncode(span ppp, span mm) - { - auto cols = mRows; - assert(ppp.size() == mRows); - assert(mm.size() == cols); - - auto v = mYs; - T* __restrict pp = ppp.data(); - const T* __restrict m = mm.data(); - - for (u64 i = 0; i < cols; ) - { - auto end = cols; - for (u64 j = 0; j < mWeight; ++j) - { - if (v[j] == mRows) - v[j] = 0; - - auto jEnd = cols - v[j] + i; - end = std::min(end, jEnd); - } - T* __restrict P = &pp[i]; - T* __restrict PE = &pp[end]; - - switch (mWeight) - { - case 5: - { - const T* __restrict M0 = &m[v[0]]; - const T* __restrict M1 = &m[v[1]]; - const T* __restrict M2 = &m[v[2]]; - const T* __restrict M3 = &m[v[3]]; - const T* __restrict M4 = &m[v[4]]; - - v[0] += end - i; - v[1] += end - i; - v[2] += end - i; - v[3] += end - i; - v[4] += end - i; - i = end; - - while (P != PE) - { - *P = *P - ^ *M0 - ^ *M1 - ^ *M2 - ^ *M3 - ^ *M4 - ; - - ++M0; - ++M1; - ++M2; - ++M3; - ++M4; - ++P; - } - - - break; - } - case 11: - { - - const T* __restrict M0 = &m[v[0]]; - const T* __restrict M1 = &m[v[1]]; - const T* __restrict M2 = &m[v[2]]; - const T* __restrict M3 = &m[v[3]]; - const T* __restrict M4 = &m[v[4]]; - const T* __restrict M5 = &m[v[5]]; - const T* __restrict M6 = &m[v[6]]; - const T* __restrict M7 = &m[v[7]]; - const T* __restrict M8 = &m[v[8]]; - const T* __restrict M9 = &m[v[9]]; - const T* __restrict M10 = &m[v[10]]; - - v[0] += end - i; - v[1] += end - i; - v[2] += end - i; - v[3] += end - i; - v[4] += end - i; - v[5] += end - i; - v[6] += end - i; - v[7] += end - i; - v[8] += end - i; - v[9] += end - i; - v[10] += end - i; - i = end; - - while (P != PE) - { - *P = *P - ^ *M0 - ^ *M1 - ^ *M2 - ^ *M3 - ^ *M4 - ^ *M5 - ^ *M6 - ^ *M7 - ^ *M8 - ^ *M9 - ^ *M10 - ; - - ++M0; - ++M1; - ++M2; - ++M3; - ++M4; - ++M5; - ++M6; - ++M7; - ++M8; - ++M9; - ++M10; - ++P; - } - - break; - } - default: - while (i != end) - { - for (u64 j = 0; j < mWeight; ++j) - { - auto row = v[j]; - pp[i] = pp[i] ^ m[row]; - ++v[j]; - } - ++i; - } - break; - } - - } - } - - // perform the circuit transpose of the encoding algorithm twice. - // the output it written to ppp0 and ppp1. - template - void details::SilverLeftEncoder::dualEncode2( - span ppp0, span ppp1, - span mm0, span mm1) - { - auto cols = mRows; - // pp = pp + m * A - auto v = mYs; - T0* __restrict pp0 = ppp0.data(); - T1* __restrict pp1 = ppp1.data(); - const T0* __restrict m0 = mm0.data(); - const T1* __restrict m1 = mm1.data(); - - for (u64 i = 0; i < cols; ) - { - auto end = cols; - for (u64 j = 0; j < mWeight; ++j) - { - if (v[j] == mRows) - v[j] = 0; - - auto jEnd = cols - v[j] + i; - end = std::min(end, jEnd); - } - switch (mWeight) - { - case 5: - while (i != end) - { - auto& r0 = v[0]; - auto& r1 = v[1]; - auto& r2 = v[2]; - auto& r3 = v[3]; - auto& r4 = v[4]; - - pp0[i] = pp0[i] - ^ m0[r0] - ^ m0[r1] - ^ m0[r2] - ^ m0[r3] - ^ m0[r4]; - - pp1[i] = pp1[i] - ^ m1[r0] - ^ m1[r1] - ^ m1[r2] - ^ m1[r3] - ^ m1[r4]; - - ++r0; - ++r1; - ++r2; - ++r3; - ++r4; - ++i; - } - break; - case 11: - while (i != end) - { - auto& r0 = v[0]; - auto& r1 = v[1]; - auto& r2 = v[2]; - auto& r3 = v[3]; - auto& r4 = v[4]; - auto& r5 = v[5]; - auto& r6 = v[6]; - auto& r7 = v[7]; - auto& r8 = v[8]; - auto& r9 = v[9]; - auto& r10 = v[10]; - - pp0[i] = pp0[i] - ^ m0[r0] - ^ m0[r1] - ^ m0[r2] - ^ m0[r3] - ^ m0[r4] - ^ m0[r5] - ^ m0[r6] - ^ m0[r7] - ^ m0[r8] - ^ m0[r9] - ^ m0[r10] - ; - - pp1[i] = pp1[i] - ^ m1[r0] - ^ m1[r1] - ^ m1[r2] - ^ m1[r3] - ^ m1[r4] - ^ m1[r5] - ^ m1[r6] - ^ m1[r7] - ^ m1[r8] - ^ m1[r9] - ^ m1[r10] - ; - - ++r0; - ++r1; - ++r2; - ++r3; - ++r4; - ++r5; - ++r6; - ++r7; - ++r8; - ++r9; - ++r10; - ++i; - } - - break; - default: - while (i != end) - { - for (u64 j = 0; j < mWeight; ++j) - { - auto row = v[j]; - pp0[i] = pp0[i] ^ m0[row]; - pp1[i] = pp1[i] ^ m1[row]; - ++v[j]; - } - ++i; - } - break; - } - } - } - - - template - void details::SilverRightEncoder::dualEncode(span x) - { - // solves for x such that y = M x, ie x := H^-1 y - assert(mExtend); - assert(cols() == x.size()); - - constexpr int FIXED_OFFSET_SIZE = 2; - if (mOffsets.size() != FIXED_OFFSET_SIZE) - throw RTE_LOC; - - std::vector offsets(mOffsets.size()); - for (u64 j = 0; j < offsets.size(); ++j) - { - offsets[j] = mRows - 1 - mOffsets[j] - mGap; - } - - u64 i = mRows - 1; - T* __restrict ofCol0 = &x[offsets[0]]; - T* __restrict ofCol1 = &x[offsets[1]]; - T* __restrict xi = &x[i]; - - switch (mCode) - { - case osuCrypto::SilverCode::Weight5: - { - auto mainEnd = - roundUpTo( - *std::max_element(mOffsets.begin(), mOffsets.end()) - + mGap, - 16); - - T* __restrict xx = xi - 16; - - for (; i > mainEnd;) - { - for (u64 jj = 0; jj < 16; ++jj) - { - - auto col0 = diagMtx_g16_w5_seed1_t36[i & 15][0]; - auto col1 = diagMtx_g16_w5_seed1_t36[i & 15][1]; - auto col2 = diagMtx_g16_w5_seed1_t36[i & 15][2]; - auto col3 = diagMtx_g16_w5_seed1_t36[i & 15][3]; - - T* __restrict xc0 = xx + col0; - T* __restrict xc1 = xx + col1; - T* __restrict xc2 = xx + col2; - T* __restrict xc3 = xx + col3; - - *xc0 = *xc0 ^ *xi; - *xc1 = *xc1 ^ *xi; - *xc2 = *xc2 ^ *xi; - *xc3 = *xc3 ^ *xi; - - *ofCol0 = *ofCol0 ^ *xi; - *ofCol1 = *ofCol1 ^ *xi; - - - --ofCol0; - --ofCol1; - - --xx; - --xi; - --i; - } - } - - break; - } - case osuCrypto::SilverCode::Weight11: - { - - - auto mainEnd = - roundUpTo( - *std::max_element(mOffsets.begin(), mOffsets.end()) - + mGap, - 32); - - T* __restrict xx = xi - 32; - - for (; i > mainEnd;) - { - for (u64 jj = 0; jj < 32; ++jj) - { - - auto col0 = diagMtx_g32_w11_seed2_t36[i & 31][0]; - auto col1 = diagMtx_g32_w11_seed2_t36[i & 31][1]; - auto col2 = diagMtx_g32_w11_seed2_t36[i & 31][2]; - auto col3 = diagMtx_g32_w11_seed2_t36[i & 31][3]; - auto col4 = diagMtx_g32_w11_seed2_t36[i & 31][4]; - auto col5 = diagMtx_g32_w11_seed2_t36[i & 31][5]; - auto col6 = diagMtx_g32_w11_seed2_t36[i & 31][6]; - auto col7 = diagMtx_g32_w11_seed2_t36[i & 31][7]; - auto col8 = diagMtx_g32_w11_seed2_t36[i & 31][8]; - auto col9 = diagMtx_g32_w11_seed2_t36[i & 31][9]; - - T* __restrict xc0 = xx + col0; - T* __restrict xc1 = xx + col1; - T* __restrict xc2 = xx + col2; - T* __restrict xc3 = xx + col3; - T* __restrict xc4 = xx + col4; - T* __restrict xc5 = xx + col5; - T* __restrict xc6 = xx + col6; - T* __restrict xc7 = xx + col7; - T* __restrict xc8 = xx + col8; - T* __restrict xc9 = xx + col9; - - *xc0 = *xc0 ^ *xi; - *xc1 = *xc1 ^ *xi; - *xc2 = *xc2 ^ *xi; - *xc3 = *xc3 ^ *xi; - *xc4 = *xc4 ^ *xi; - *xc5 = *xc5 ^ *xi; - *xc6 = *xc6 ^ *xi; - *xc7 = *xc7 ^ *xi; - *xc8 = *xc8 ^ *xi; - *xc9 = *xc9 ^ *xi; - - *ofCol0 = *ofCol0 ^ *xi; - *ofCol1 = *ofCol1 ^ *xi; - - - --ofCol0; - --ofCol1; - - --xx; - --xi; - --i; - } - } - - break; - } - default: - throw RTE_LOC; - break; - } - - offsets[0] = ofCol0 - x.data(); - offsets[1] = ofCol1 - x.data(); - - for (; i != ~u64(0); --i) - { - - switch (mCode) - { - case osuCrypto::SilverCode::Weight5: - - for (u64 j = 0; j < 4; ++j) - { - auto col = diagMtx_g16_w5_seed1_t36[i & 15][j] + i - 16; - if (col < mRows) - x[col] = x[col] ^ x[i]; - } - break; - case osuCrypto::SilverCode::Weight11: - - for (u64 j = 0; j < 10; ++j) - { - auto col = diagMtx_g32_w11_seed2_t36[i & 31][j] + i - 32; - if (col < mRows) - x[col] = x[col] ^ x[i]; - } - break; - default: - break; - } - - for (u64 j = 0; j < FIXED_OFFSET_SIZE; ++j) - { - auto& col = offsets[j]; - - if (col >= mRows) - break; - assert(i - mOffsets[j] - mGap == col); - - x[col] = x[col] ^ x[i]; - --col; - } - } - } - - template - void details::SilverRightEncoder::dualEncode2(span x0, span x1) - { - // solves for x such that y = M x, ie x := H^-1 y - assert(mExtend); - assert(cols() == x0.size()); - assert(cols() == x1.size()); - - constexpr int FIXED_OFFSET_SIZE = 2; - if (mOffsets.size() != FIXED_OFFSET_SIZE) - throw RTE_LOC; - - std::vector offsets(mOffsets.size()); - for (u64 j = 0; j < offsets.size(); ++j) - { - offsets[j] = mRows - 1 - mOffsets[j] - mGap; - } - - u64 i = mRows - 1; - T0* __restrict ofCol00 = &x0[offsets[0]]; - T0* __restrict ofCol10 = &x0[offsets[1]]; - T1* __restrict ofCol01 = &x1[offsets[0]]; - T1* __restrict ofCol11 = &x1[offsets[1]]; - T0* __restrict xi0 = &x0[i]; - T1* __restrict xi1 = &x1[i]; - - switch (mCode) - { - case osuCrypto::SilverCode::Weight5: - { - - auto mainEnd = - roundUpTo( - *std::max_element(mOffsets.begin(), mOffsets.end()) - + mGap, - 16); - - T0* __restrict xx0 = xi0 - 16; - T1* __restrict xx1 = xi1 - 16; - - for (; i > mainEnd;) - { - for (u64 jj = 0; jj < 16; ++jj) - { - - auto col0 = diagMtx_g16_w5_seed1_t36[i & 15][0]; - auto col1 = diagMtx_g16_w5_seed1_t36[i & 15][1]; - auto col2 = diagMtx_g16_w5_seed1_t36[i & 15][2]; - auto col3 = diagMtx_g16_w5_seed1_t36[i & 15][3]; - - T0* __restrict xc00 = xx0 + col0; - T0* __restrict xc10 = xx0 + col1; - T0* __restrict xc20 = xx0 + col2; - T0* __restrict xc30 = xx0 + col3; - T1* __restrict xc01 = xx1 + col0; - T1* __restrict xc11 = xx1 + col1; - T1* __restrict xc21 = xx1 + col2; - T1* __restrict xc31 = xx1 + col3; - - *xc00 = *xc00 ^ *xi0; - *xc10 = *xc10 ^ *xi0; - *xc20 = *xc20 ^ *xi0; - *xc30 = *xc30 ^ *xi0; - - *xc01 = *xc01 ^ *xi1; - *xc11 = *xc11 ^ *xi1; - *xc21 = *xc21 ^ *xi1; - *xc31 = *xc31 ^ *xi1; - - *ofCol00 = *ofCol00 ^ *xi0; - *ofCol10 = *ofCol10 ^ *xi0; - *ofCol01 = *ofCol01 ^ *xi1; - *ofCol11 = *ofCol11 ^ *xi1; - - - --ofCol00; - --ofCol10; - --ofCol01; - --ofCol11; - - --xx0; - --xx1; - --xi0; - --xi1; - --i; - } - } - - break; - } - case osuCrypto::SilverCode::Weight11: - { - - - auto mainEnd = - roundUpTo( - *std::max_element(mOffsets.begin(), mOffsets.end()) - + mGap, - 32); - - T0* __restrict xx0 = xi0 - 32; - T1* __restrict xx1 = xi1 - 32; - - for (; i > mainEnd;) - { - for (u64 jj = 0; jj < 32; ++jj) - { - - auto col0 = diagMtx_g32_w11_seed2_t36[i & 31][0]; - auto col1 = diagMtx_g32_w11_seed2_t36[i & 31][1]; - auto col2 = diagMtx_g32_w11_seed2_t36[i & 31][2]; - auto col3 = diagMtx_g32_w11_seed2_t36[i & 31][3]; - auto col4 = diagMtx_g32_w11_seed2_t36[i & 31][4]; - auto col5 = diagMtx_g32_w11_seed2_t36[i & 31][5]; - auto col6 = diagMtx_g32_w11_seed2_t36[i & 31][6]; - auto col7 = diagMtx_g32_w11_seed2_t36[i & 31][7]; - auto col8 = diagMtx_g32_w11_seed2_t36[i & 31][8]; - auto col9 = diagMtx_g32_w11_seed2_t36[i & 31][9]; - - T0* __restrict xc00 = xx0 + col0; - T0* __restrict xc10 = xx0 + col1; - T0* __restrict xc20 = xx0 + col2; - T0* __restrict xc30 = xx0 + col3; - T0* __restrict xc40 = xx0 + col4; - T0* __restrict xc50 = xx0 + col5; - T0* __restrict xc60 = xx0 + col6; - T0* __restrict xc70 = xx0 + col7; - T0* __restrict xc80 = xx0 + col8; - T0* __restrict xc90 = xx0 + col9; - - T1* __restrict xc01 = xx1 + col0; - T1* __restrict xc11 = xx1 + col1; - T1* __restrict xc21 = xx1 + col2; - T1* __restrict xc31 = xx1 + col3; - T1* __restrict xc41 = xx1 + col4; - T1* __restrict xc51 = xx1 + col5; - T1* __restrict xc61 = xx1 + col6; - T1* __restrict xc71 = xx1 + col7; - T1* __restrict xc81 = xx1 + col8; - T1* __restrict xc91 = xx1 + col9; - - *xc00 = *xc00 ^ *xi0; - *xc10 = *xc10 ^ *xi0; - *xc20 = *xc20 ^ *xi0; - *xc30 = *xc30 ^ *xi0; - *xc40 = *xc40 ^ *xi0; - *xc50 = *xc50 ^ *xi0; - *xc60 = *xc60 ^ *xi0; - *xc70 = *xc70 ^ *xi0; - *xc80 = *xc80 ^ *xi0; - *xc90 = *xc90 ^ *xi0; - - *xc01 = *xc01 ^ *xi1; - *xc11 = *xc11 ^ *xi1; - *xc21 = *xc21 ^ *xi1; - *xc31 = *xc31 ^ *xi1; - *xc41 = *xc41 ^ *xi1; - *xc51 = *xc51 ^ *xi1; - *xc61 = *xc61 ^ *xi1; - *xc71 = *xc71 ^ *xi1; - *xc81 = *xc81 ^ *xi1; - *xc91 = *xc91 ^ *xi1; - - *ofCol00 = *ofCol00 ^ *xi0; - *ofCol10 = *ofCol10 ^ *xi0; - - *ofCol01 = *ofCol01 ^ *xi1; - *ofCol11 = *ofCol11 ^ *xi1; - - - --ofCol00; - --ofCol10; - --ofCol01; - --ofCol11; - - --xx0; - --xx1; - - --xi0; - --xi1; - --i; - } - } - - break; - } - default: - throw RTE_LOC; - break; - } - - offsets[0] = ofCol00 - x0.data(); - offsets[1] = ofCol10 - x0.data(); - - for (; i != ~u64(0); --i) - { - - switch (mCode) - { - case osuCrypto::SilverCode::Weight5: - - for (u64 j = 0; j < 4; ++j) - { - auto col = diagMtx_g16_w5_seed1_t36[i & 15][j] + i - 16; - if (col < mRows) - { - x0[col] = x0[col] ^ x0[i]; - x1[col] = x1[col] ^ x1[i]; - } - } - break; - case osuCrypto::SilverCode::Weight11: - - for (u64 j = 0; j < 10; ++j) - { - auto col = diagMtx_g32_w11_seed2_t36[i & 31][j] + i - 32; - if (col < mRows) - { - x0[col] = x0[col] ^ x0[i]; - x1[col] = x1[col] ^ x1[i]; - } - } - break; - default: - break; - } - - for (u64 j = 0; j < FIXED_OFFSET_SIZE; ++j) - { - auto& col = offsets[j]; - - if (col >= mRows) - break; - assert(i - mOffsets[j] - mGap == col); - - x0[col] = x0[col] ^ x0[i]; - x1[col] = x1[col] ^ x1[i]; - --col; - } - } - } - - - - -} -#endif // ENABLE_INSECURE_SILVER - diff --git a/libOTe/Tools/LDPC/LdpcImpulseDist.cpp b/libOTe/Tools/LDPC/LdpcImpulseDist.cpp deleted file mode 100644 index 325b5595..00000000 --- a/libOTe/Tools/LDPC/LdpcImpulseDist.cpp +++ /dev/null @@ -1,1132 +0,0 @@ - - -#define _CRT_SECURE_NO_WARNINGS -#include "LdpcImpulseDist.h" - -#ifdef ENABLE_LDPC - -#include "LdpcDecoder.h" -#include "Util.h" -#include -#include -#include "LdpcSampler.h" -#include "cryptoTools/Common/Timer.h" -#include "cryptoTools/Crypto/PRNG.h" -#include "libOTe/Tools/LDPC/LdpcEncoder.h" -#include "libOTe/Tools/Tools.h" - -#ifdef ENABLE_ALGO994 -extern "C" { -#include "libOTe/Tools/LDPC/Algo994/data_defs.h" -} -#endif - -#include -#include - -#include // put_time - -namespace osuCrypto -{ - - - struct ListIter - { - std::vector set; - - ListDecoder mType; - u64 mTotalI = 0, mTotalEnd; - u64 mI = 0, mEnd = 1, mN = 0, mCurWeight = 0; - - u64 mD = 1; - void initChase(u64 d) - { - assert(d < 24); - mType = ListDecoder::Chase; - mD = d; - mI = 0; - ++(*this); - } - - void initOsd(u64 d, u64 n, bool startAtZero) - { - assert(d < 24); - mType = ListDecoder::OSD; - - mN = n; - mTotalI = 0; - mTotalEnd = 1ull << d; - - mCurWeight = startAtZero ? 0 : 1; - mI = 0; - mEnd = choose(n, mCurWeight); - - set = ithCombination(mI, n, mCurWeight); - } - - - void operator++() - { - assert(*this); - if (mType == ListDecoder::Chase) - { - set.clear(); - ++mI; - oc::BitIterator ii((u8*)&mI); - for (u64 i = 0; i < mD; ++i) - { - if (*ii) - set.push_back(i); - ++ii; - } - } - else - { - - ++mI; - ++mTotalI; - - if (mI == mEnd) - { - mI = 0; - ++mCurWeight; - mEnd = choose(mN, mCurWeight); - } - - if (mN >= mCurWeight) - { - set.resize(mCurWeight); - ithCombination(mI, mN, set); - } - else - set.clear(); - } - } - - std::vector& operator*() - { - return set; - } - - - operator bool() const - { - if (mType == ListDecoder::Chase) - return mI != (1ull << mD); - else - { - return mTotalI != mTotalEnd && mCurWeight <= mN; - } - } - }; - - - template - void sort_indexes(span v, span idx) { - - // initialize original index locations - assert(v.size() == idx.size()); - std::iota(idx.begin(), idx.end(), 0); - - //std::partial_sort() - std::stable_sort(idx.begin(), idx.end(), - [&v](size_t i1, size_t i2) {return v[i1] < v[i2]; }); - } - - std::mutex minWeightMtx; - u32 minWeight(0); - std::vector minCW; - - std::unordered_set heatSet; - std::vector heatMap; - std::vector heatMapCount; - u64 nextTimeoutIdx; - - struct Worker - { - std::vector llr; - std::vector y; - std::vector llrSetList; - std::vector codeword; - std::vector sortIdx, permute, weights; - std::vector> backProps; - LdpcDecoder D; - DynSparseMtx H; - DenseMtx DH; - std::vector dSet, eSet; - std::unordered_set eeSet; - bool verbose = false; - BPAlgo algo = BPAlgo::LogBP; - ListDecoder listDecoder = ListDecoder::OSD; - u64 Ng = 3; - oc::PRNG prng; - bool abs = false; - double timeout; - - void impulseDist(u64 i, u64 k, u64 Nd, u64 maxIter, bool randImpulse) - { - if (prng.mBufferByteCapacity == 0) - prng.SetSeed(oc::sysRandomSeed()); - - auto n = D.mH.cols(); - auto m = D.mH.rows(); - - llr.resize(n);// , lr.resize(n); - y.resize(m); - codeword.resize(n); - sortIdx.resize(n); - backProps.resize(m); - weights.resize(n); - - //auto p = 0.9999; - auto llr0 = LdpcDecoder::encodeLLR(0.501, 0); - auto llr1 = LdpcDecoder::encodeLLR(0.999, 1); - //auto lr0 = encodeLR(0.501, 0); - //auto lr1 = encodeLR(0.9999999, 1); - - std::fill(llr.begin(), llr.end(), llr0); - std::vector impulse; - if (randImpulse) - { - std::set set; - //u64 w = prng.get(); - //w = (w % (k)) + 1; - //assert(k + 1 < n); - while (set.size() != k) - { - auto i = prng.get() % n; - set.insert(i); - } - impulse.insert(impulse.end(), set.begin(), set.end()); - } - else - { - impulse = ithCombination(i, n, k); - - } - - for (auto i : impulse) - llr[i] = llr1; - - switch (algo) - { - case BPAlgo::LogBP: - D.logbpDecode2(llr, maxIter); - break; - case BPAlgo::AltLogBP: - D.altDecode(llr, false, maxIter); - break; - case BPAlgo::MinSum: - D.altDecode(llr, true, maxIter); - break; - default: - std::cout << "bad algo " << (int)algo << std::endl; - std::abort(); - break; - } - //bpDecode(lr, maxIter); - //for (auto& l : mL) - for (u64 i = 0; i < n; ++i) - { - if (abs) - llr[i] = std::abs(D.mL[i]); - else - llr[i] = (D.mL[i]); - } - - sort_indexes(llr, sortIdx); - - - u64 ii = 0; - dSet.clear(); - eSet.clear(); - - bool sparse = false; - if (sparse) - H = D.mH; - else - DH = D.mH.dense(); - - VecSortSet col; - - while (ii < n && eSet.size() < m) - { - auto c = sortIdx[ii++]; - - if (sparse) - col = H.col(c); - else - DH.colIndexSet(c, col.mData); - - bool set = false; - - for (auto r : col) - { - if (r >= eSet.size()) - { - if (set) - { - if (sparse) - H.rowAdd(r, eSet.size()); - else - DH.row(r) ^= DH.row(eSet.size()); - //assert(H(r, c) == 0); - } - else - { - set = true; - if (sparse) - H.rowSwap(eSet.size(), r); - else - DH.row(eSet.size()).swap(DH.row(r)); - } - } - } - - if (set == false) - { - dSet.push_back(c); - } - else - { - eSet.push_back(c); - } - } - - if (!sparse) - H = DH; - - - //auto HH1 = H.sparse().dense().gausianElimination(); - //auto HH2 = mH.dense().gausianElimination(); - //assert(HH1 == HH2); - - - if (eSet.size() != m) - { - std::cout << "bad eSet size " << LOCATION << std::endl; - abort(); - } - - //auto gap = dSet.size(); - - while (dSet.size() < Nd) - { - auto col = sortIdx[ii++]; - dSet.push_back(col); - } - - permute = eSet; - permute.insert(permute.end(), dSet.begin(), dSet.end()); - permute.insert(permute.end(), sortIdx.begin() + permute.size(), sortIdx.end()); - //if (v) - //{ - - // auto H2 = H.selectColumns(permute); - - // std::cout << " " << gap << "\n" << H2 << std::endl; - // //for (auto l : mL) - - // for (u64 i = 0; i < n; ++i) - // { - // std::cout << decodeLLR(D.mL[permute[i]]) << " "; - // } - // std::cout << std::endl; - //} - - std::fill(y.begin(), y.end(), 0); - - llrSetList.clear(); - - for (u64 i = m; i < n; ++i) - { - auto col = permute[i]; - codeword[col] = LdpcDecoder::decodeLLR(D.mL[col]); - - - if (codeword[col]) - { - llrSetList.push_back(col); - for (auto row : H.col(col)) - { - y[row] ^= 1; - } - } - } - - eeSet.clear(); - eeSet.insert(eSet.begin(), eSet.end()); - for (u64 i = 0; i < m - 1; ++i) - { - backProps[i].clear(); - for (auto c : H.row(i)) - { - if (c != permute[i] && - eeSet.find(c) != eeSet.end()) - { - backProps[i].push_back(c); - } - } - } - - - - ListIter cIter; - - if (listDecoder == ListDecoder::Chase) - cIter.initChase(Nd); - else - cIter.initOsd(Nd, std::min(Ng, n - m), llrSetList.size() > 0); - - //u32 s = 0; - //u32 e = 1 << Nd; - while (cIter) - { - - //for (u64 i = m + Nd; i < n; ++i) - //{ - // auto col = permute[i]; - // codeword[col] = y[i]; - //} - - ////yp = y; - std::fill(codeword.begin(), codeword.end(), 0); - - - // check if BP resulted in any - // of the top bits being set to 1. - // if so, preload the partially - // solved solution for this. - if (llrSetList.size()) - { - for (u64 i = 0; i < m; ++i) - codeword[permute[i]] = y[i]; - for (auto i : llrSetList) - codeword[i] = 1; - } - - // next, iterate over setting some - // of the remaining bits. - auto& oneSet = *cIter; - //for (u64 i = m; i < m + Nd; ++i, ++iter) - for (auto i : oneSet) - { - - auto col = permute[i + m]; - - // check if this was not already set to - // 1 by the BP. - if (codeword[col] == 0) - { - codeword[col] = 1; - for (auto row : H.col(col)) - { - codeword[permute[row]] ^= 1; - } - } - } - - ++cIter; - - // now perform back prop on the remaining - // postions. - for (u64 i = m - 1; i != ~0ull; --i) - { - for (auto c : backProps[i]) - { - if (codeword[c]) - codeword[permute[i]] ^= 1; - } - } - - - // check if its a codework (should always be one) - if (D.check(codeword) == false) { - std::cout << "bad codeword " << LOCATION << std::endl; - abort(); - } - - // record the weight. - auto w = std::accumulate(codeword.begin(), codeword.end(), 0ull); - ++weights[w]; - - if (w && w < minWeight + 10) - { - - oc::RandomOracle ro(sizeof(u64)); - ro.Update(codeword.data(), codeword.size()); - u64 h; - ro.Final(h); - - std::lock_guard lock(minWeightMtx); - - if (w < minWeight) - { - if (verbose) - std::cout << " w=" << w << std::flush; - - minWeight = (u32) w; - minCW.clear(); - minCW.insert(minCW.end(), codeword.begin(), codeword.end()); - - if (timeout > 0) - { - u64 nn = static_cast((i + 1) * timeout); - nextTimeoutIdx = std::max(nextTimeoutIdx, nn); - } - } - - - - if (heatSet.find(h) == heatSet.end()) - { - heatSet.insert(h); - //heatMap[w].resize(n); - for (u64 i = 0; i < n; ++i) - { - if (codeword[i]) - ++heatMap[i]; - ++heatMapCount[w]; - } - } - - } - } - } - - }; - - - - - u64 impulseDist( - SparseMtx& mH, - u64 Nd, u64 Ng, - u64 w, - u64 maxIter, u64 nt, bool randImpulse, u64 trials, BPAlgo algo, - ListDecoder listDecoder, bool verbose, double timeout) - { - assert(Nd < 32); - auto n = mH.cols(); - //auto m = mH.rows(); - - LdpcDecoder D; - D.init(mH); - minWeight = 999999999; - minCW.clear(); - - - nt = nt ? nt : 1; - std::vector wrks(nt); - - for (auto& ww : wrks) - { - ww.algo = algo; - ww.D.init(mH); - ww.verbose = verbose; - - ww.Ng = Ng; - ww.listDecoder = listDecoder; - ww.timeout = timeout; - } - nextTimeoutIdx = 0; - if (randImpulse) - { - std::vector thrds(nt); - - for (u64 t = 0; t < nt; ++t) - { - thrds[t] = std::thread([&, t]() { - - for (u64 i = t; i < trials; i += nt) - { - wrks[t].impulseDist(i, w, Nd, maxIter, randImpulse); - } - }); - } - - for (u64 t = 0; t < nt; ++t) - thrds[t].join(); - } - else - { - - bool timedOut = false; - std::vector thrds(nt); - - for (u64 t = 0; t < nt; ++t) - { - thrds[t] = std::thread([&, t]() { - - u64 ii = t; - for (u64 k = 0; k < w + 1; ++k) - { - - auto f = choose(n, k); - for (; ii < f && timedOut == false; ii += nt) - { - - wrks[t].impulseDist(ii, k, Nd, maxIter, randImpulse); - - - if (k == w && (ii % 100) == 0) - { - std::lock_guard lock(minWeightMtx); - - if (nextTimeoutIdx && nextTimeoutIdx < ii) - timedOut = true; - } - } - - ii -= f; - } - } - ); - } - - for (u64 t = 0; t < nt; ++t) - thrds[t].join(); - - } - - std::vector weights(n); - for (u64 t = 0; t < nt; ++t) - { - for (u64 i = 0; i < wrks[t].weights.size(); ++i) - weights[i] += wrks[t].weights[i]; - } - - auto i = 1; - while (weights[i] == 0) ++i; - auto ret = i; - - if (verbose) - { - std::cout << std::endl; - auto j = 0; - while (j++ < 10) - { - std::cout << i << " " << weights[i] << std::endl; - ++i; - } - } - - return ret; - } - - std::string return_current_time_and_date() - { - auto now = std::chrono::system_clock::now(); - auto in_time_t = std::chrono::system_clock::to_time_t(now); - - std::stringstream ss; - ss << std::put_time(std::localtime(&in_time_t), "%c %X"); - return ss.str(); - } - - void LdpcDecode_impulse(const oc::CLP& cmd) - { - // general parameter - // the number of rows of H, "m" - auto rowVec = cmd.getManyOr("r", { 50 }); - - // the expansion ratio, e="n/m" - double e = cmd.getOr("e", 2.0); - - // the number of trials. - u64 trial = cmd.getOr("trials", 1); - - // which trial to start at. - u64 tStart = cmd.getOr("tStart", 0); - - // a specific set of trials to run. - auto tSet = cmd.getManyOr("tSet", {}); - - // the need used to generate the samples - u64 seed = cmd.getOr("seed", 0); - - // the seed(s) to generate the left half. One for each trial. - auto lSeed = cmd.getManyOr("lSeed", {}); - - // verbose flag. - bool verbose = cmd.isSet("v"); - - // the number of threads - u64 nt = cmd.getOr("nt", cmd.isSet("nt") ? std::thread::hardware_concurrency() : 1); - - // A flag to just test the uniform matrix. - bool uniform = cmd.isSet("u"); - - - // silver parameters - // ================================ - - // use the silver preset for the given column weight. - bool silver = cmd.isSet("slv"); - - // The column weight for the left half - u64 colWeight = cmd.getOr("cw", 5); - - // the column weight for the right half. Counts the - // main diagonal. - u64 dWeight = cmd.getOr("dw", 2); - - // The size of the gap. - u64 gap = cmd.getOr("g", 1); - - // the number of diagonals on the left half. - u64 diag = cmd.getOr("diag", 0); - - // the number of diagonals on the right half. - u64 dDiag = cmd.getOr("dDiag", 0); - - // the extra diagonal bands below the main diagonal (right half) - // The values denote how far below the diagonal they should be. - // requires -trim - auto doubleBand = cmd.getMany("db"); - - // delete the first g columns of the right half - bool trim = cmd.isSet("trim"); - - // extend the right half to be square. - bool extend = cmd.isSet("extend"); - - // how often the right half should repeat. - u64 period = cmd.getOr("period", 0); - - // a flag to randomly sample the position of the left diagonals. - // requires -diag > 0. - bool randY = cmd.isSet("randY"); - - // the slopes of the left diagonals, default 1. - // requires -diag > 0. - slopes_ = cmd.getManyOr("slope", {}); - - // the fixed indexes of the left diagononals - // requires -diag > 0. - ys_ = cmd.getManyOr("ys", {}); - - // the fractional positions of the left diagononals. - // requires -diag > 0. - yr_ = cmd.getManyOr("yr", {}); - - // print the factional positions of the left diagonals. - bool printYs = cmd.isSet("py"); - - // sample the right half to be regular (the same number of - // ones in the rows as the columns). - bool reg = cmd.isSet("reg"); - - // print the heat map for where we found minimum codewords - bool hm = cmd.isSet("hm"); - - // when sample, dont check that H corresponds to a value - // LDPC code - bool noCheck = cmd.isSet("noCheck"); - - // the path to the log file. - std::string logPath = cmd.getOr("log", ""); - - // the amount of "time" that can pass in the impluse technique - // without finding a new minimum. the value denote the faction - // of the total number of impluses which should be performed. - double timeout = cmd.getOr("to", 0.0); - - - // estimator parameters - // ============================== - - // the weight of the noise impulse. - u64 w = cmd.getOr("w", 1); - - // once partial gaussian elemination is performed, - // we will consider all codewords which have the - // next Nd out-of Ng bits set to one. e.g. if we have - // Nd=2, Ng=4, the the codewords - // - // xx...xx 1100 0000... - // xx...xx 1010 0000... - // ... - // xx...xx 0011 0000... - // - // will be tried where the x bits are solved for. - - u64 Nd = cmd.getOr("Nd", 10); - u64 Ng = cmd.getOr("Ng", 50); - - // the number of BD iterations that should be performed. - u64 iter = cmd.getOr("iter", 10); - - // should random noise impulses of the given weight be tried. - bool rand = cmd.isSet("rand"); - - // how many random noise impulse should be tried. - u64 n = cmd.getOr("rand", 100); - - // the type of list decoder. - ListDecoder listDecoder = (ListDecoder)cmd.getOr("ld", 1); - - // print the regular diagonal so it can be used to generate a - // actual code, e.g. diagMtx_g32_w11_seed2_t36. - printDiag = cmd.isSet("printDiag"); - - auto algo = (BPAlgo)cmd.getOr("bp", 2); - - // algo994 parameters - auto trueDist = cmd.isSet("true"); -#ifdef ENABLE_ALGO994 - alg994 = cmd.getOr("algo994", ALG_SAVED_UNROLLED); - num_saved_generators = cmd.getOr("numGen", 5); - num_cores = (int)nt; - num_permutations = cmd.getOr("numPerm", 10); - print_matrices = 0; -#endif - - SparseMtx H; - LdpcDecoder D; - std::stringstream label; - - - label << return_current_time_and_date() << "\n"; - - if (e != 2) - label << "-ldpc -e " << e << " "; - - if (silver) - label << " -slv "; - - label << "-r "; - for (auto rows : rowVec) - label << rows << " "; - - if (trueDist) - label << "-true "; - else - { - label << " -ld " << (int)listDecoder << " -bp " << int(algo) << "-Nd " << Nd << " -Ng " << Ng << " -w " << w; - if (timeout) - label << " -to " << timeout; - } - label << " -nt " << nt; - - if (tSet.size()) - { - label << " -tSet "; - for (auto t : tSet) - label << t << " "; - - } - else - { - label << " -trials " << trial; - if (tStart) - label << " -tStart " << tStart; - } - - label << " -seed " << seed; - - if (uniform) - { - label << " -u "; - if (cmd.isSet("cw")) - label << " -cw " << colWeight; - } - else - { - if (cmd.isSet("lb")) - label << "-ld"; - - label << " -cw " << colWeight << " -g " << gap << " -dw " << dWeight - << " -diag " << diag << " -dDiag " << dDiag; - - if (doubleBand.size()) - { - label << " -db "; - for (auto db : doubleBand) - label << db << " "; - } - - if (trim) - label << " -trim "; - - if (extend) - label << " -extend "; - - label << " -period " << period; - //if(zp) - // label << " -zp "; - - if (reg) - label << " -reg"; - - } - - std::ofstream log; - if (logPath.size()) - { - log.open(logPath, std::ios::out | std::ios::app); - } - - if (log.is_open()) - log << "\n" << label.str() << std::endl; - - if (tSet.size() == 0) - { - for (u64 i = tStart; i < trial; ++i) - tSet.push_back(i); - } - - for (auto rows : rowVec) - { - - //if (zp) - //{ - // if (isPrime(rows + 1) == false) - // rows = nextPrime(rows + 1) - 1; - - //} - u64 cols = static_cast(rows * e); - - if (uniform && trim) - { - cols -= gap; - } - - std::vector dd; - - heatMap.clear(); - heatMap.resize(cols); - heatMapCount.clear(); - heatMapCount.resize(cols); - heatSet.clear(); - - if (log.is_open()) - log << rows << ": "; - - auto lIter = lSeed.begin(); - for (auto i : tSet) - { - oc::PRNG lPrng, rPrng(block(seed, i)); - - if (lIter != lSeed.end()) - { - lPrng.SetSeed(block(3, *lIter++)); - - if (lIter == lSeed.end()) - lIter = lSeed.begin(); - } - else - lPrng.SetSeed(block(seed, i)); - - - if (uniform) - { - if (cmd.isSet("cw")) - { - - - H = sampleFixedColWeight(rows, cols, colWeight, rPrng, !noCheck); - } - else - H = sampleUniformSystematic(rows, cols, rPrng); - } - else if (silver) - { - - SilverEncoder enc; - SilverCode code; - - if (colWeight == 5) - { - code = SilverCode::Weight5; - } - else if (colWeight == 11) - code = SilverCode::Weight11; - else - { - std::cout << "-slv can only be used with -cw 5 or -cw 11" << std::endl; - throw RTE_LOC; - } - - enc.mR.init(rows, code, extend); - enc.mL.init(rows, code, cmd.isSet("rep") ? SilverCode::gap(code) : 0); - H = enc.getMatrix(); - } - else if (reg) - { - H = sampleRegTriangularBand( - rows, cols, - colWeight, gap, - dWeight, diag, dDiag, period, - doubleBand, trim, extend, randY, rPrng); - //std::cout << H << std::endl; - } - else - { - H = sampleTriangularBand( - rows, cols, - colWeight, gap, - dWeight, diag, dDiag, period, - doubleBand, trim, extend, randY, lPrng,rPrng); - } - - //impulseDist(5, 5000); - //oc::Timer timer; - //timer.setTimePoint(""); - - //timer.setTimePoint("e"); - - if(verbose) - std::cout << "\n" << H << std::endl; - - if (trueDist) - { - auto d = minDist2(H.dense(), nt, false); - dd.push_back(d); - } - else - { - auto d = impulseDist(H, Nd, Ng, w, iter, nt, rand, n, algo, listDecoder, verbose, timeout); - dd.push_back(d); - } - - if (log.is_open()) - log << " " << dd.back() << std::flush; - - if (verbose) - { - std::cout << dd.back(); - - for (auto c : minCW) - { - if (c) - std::cout << oc::Color::Green << int(c) << " " << oc::Color::Default; - else - std::cout << int(c) << " "; - } - std::cout << std::endl; - - if (hm || verbose) - { - u64 max = 0ull; - for (u64 i = 0; i < heatMap.size(); ++i) - { - max = std::max(max, heatMap[i]); - } - - double tick = max / 10.0; - - - for (u64 j = 1; j <= 10; ++j) - { - for (u64 i = 0; i < heatMap.size(); ++i) - { - if (heatMap[i] >= j * tick) - std::cout << "* "; - else - std::cout << " "; - } - std::cout << "|\n"; - } - std::cout << std::flush; - - heatMap.clear(); - heatMap.resize(cols); - heatMapCount.clear(); - heatMapCount.resize(cols); - heatSet.clear(); - } - - } - else if (!cmd.isSet("silent")) - { - std::cout << dd.back() << " " << std::flush; - - if (printYs) - { - std::cout << "~ "; - for (auto y : lastYs_ ) - std::cout << y << " "; - - std::cout << "~ "; - for (auto y : yr_) - std::cout << y << " "; - - std::cout << std::endl; - } - } - //std::cout << timer << std::endl;; - - } - - if (log.is_open()) - log << std::endl; - - auto tt = tSet.size(); - auto min = *std::min_element(dd.begin(), dd.end()); - auto max = *std::max_element(dd.begin(), dd.end()); - auto avg = std::accumulate(dd.begin(), dd.end(), 0ull) / double(tt); - //avg = avg / tt; - - //std::cout << "\r"; - //auto str = ss.str(); - //for (u64 i = 0; i < str.size(); ++i) - // std::cout << " "; - //std::cout << ; - - { - std::cout << oc::Color::Green << "\r" << rows << ": "; - std::cout << min << " " << avg << " " << max << " ~ " << oc::Color::Default; - for (auto d : dd) - std::cout << d << " "; - - std::cout << std::endl; - } - - - - if (hm && !verbose) - { - u64 max = 0ull; - for (u64 i = 0; i < heatMap.size(); ++i) - { - max = std::max(max, heatMap[i]); - } - - double tick = max / 10.0; - - - for (u64 j = 1; j <= 10; ++j) - { - for (u64 i = 0; i < heatMap.size(); ++i) - { - if (heatMap[i] >= j * tick) - std::cout << "* "; - else - std::cout << " "; - } - std::cout << "|\n"; - } - std::cout << std::flush; - - heatMap.clear(); - heatMap.resize(cols); - heatMapCount.clear(); - heatMapCount.resize(cols); - heatSet.clear(); - } - - } - - - - - return; - - } - - - -} - -#endif diff --git a/libOTe/Tools/LDPC/LdpcImpulseDist.h b/libOTe/Tools/LDPC/LdpcImpulseDist.h deleted file mode 100644 index e3bcfe34..00000000 --- a/libOTe/Tools/LDPC/LdpcImpulseDist.h +++ /dev/null @@ -1,40 +0,0 @@ -#pragma once -// © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). -#include "libOTe/config.h" -#ifdef ENABLE_LDPC - -#include "Mtx.h" -#include "LdpcDecoder.h" - -namespace osuCrypto -{ - - - enum class BPAlgo - { - LogBP = 0, - AltLogBP = 1, - MinSum = 2 - }; - - enum class ListDecoder - { - Chase = 0 , - OSD = 1 - }; - - //extern std::vector minCW; - - void LdpcDecode_impulse(const oc::CLP& cmd); - - //u64 impulseDist(LdpcDecoder& D, u64 i, u64 n, u64 k, u64 Ne, u64 maxIter); - //u64 impulseDist(SparseMtx& mH, u64 Ne, u64 w, u64 maxIter, u64 numThreads, bool randImpulse, u64 trials, BPAlgo algo, bool verbose); -} -#endif \ No newline at end of file diff --git a/libOTe/Tools/LDPC/LdpcSampler.cpp b/libOTe/Tools/LDPC/LdpcSampler.cpp deleted file mode 100644 index 0693da62..00000000 --- a/libOTe/Tools/LDPC/LdpcSampler.cpp +++ /dev/null @@ -1,256 +0,0 @@ -#include "libOTe/Tools/LDPC/LdpcSampler.h" -#ifdef ENABLE_LDPC - -#include "libOTe/Tools/LDPC/LdpcEncoder.h" -#include -#include "libOTe/Tools/LDPC/Util.h" -#include - -#ifdef ENABLE_ALGO994 -extern "C" { -#include "libOTe/Tools/LDPC/Algo994/data_defs.h" -} -#endif - -namespace osuCrypto -{ - - std::vector slopes_, ys_, lastYs_; - std::vector yr_; - bool printDiag = false; - - - void sampleRegDiag(u64 rows, u64 gap, u64 weight, oc::PRNG& prng, PointList& points) - { - - if (rows < gap * 2) - throw RTE_LOC; - - auto cols = rows - gap; - std::vector rowWeights(cols); - std::vector> rowSets(rows - gap); - - - for (u64 c = 0; c < cols; ++c) - { - std::set s; - //auto remCols = cols - c; - - for (u64 j = 0; j < gap; ++j) - { - auto rowIdx = c + j; - - // how many remaining slots are there to the left (speial case at the start) - u64 remA = std::max(gap - rowIdx, 0); - - // how many remaining slots there are to the right. - u64 remB = std::min(j, cols - c); - - u64 rem = remA + remB; - - - auto& w = rowWeights[rowIdx % cols]; - auto needed = (weight - w); - - if (needed > rem) - throw RTE_LOC; - - if (needed && needed == rem) - { - s.insert(rowIdx); - points.push_back({ rowIdx, c }); - ++w; - } - } - - if (s.size() > weight) - throw RTE_LOC; - - while (s.size() != weight) - { - auto j = (prng.get() % gap); - auto r = c + j; - - auto& w = rowWeights[r % cols]; - - if (w != weight && s.insert(r).second) - { - ++w; - points.push_back({ r, c }); - } - } - - for (auto ss : s) - { - rowSets[ss % cols].insert(c); - if (rowSets[ss % cols].size() > weight) - throw RTE_LOC; - } - - if (c > gap && rowSets[c].size() != weight) - { - SparseMtx H(c + gap + 1, c + 1, points); - std::cout << H << std::endl << std::endl; - throw RTE_LOC; - } - - - } - - if (printDiag) - { - - //std::vector> hh; - auto pp = points; - - for (u64 i = 0; i < rows - 1; ++i) - pp.push_back({ i, i + 1 }); - - SparseMtx H(rows, rows, points); - //std::cout << (SparseMtx(pp)) << std::endl << std::endl; - - std::cout << "{{\n"; - - - for (u64 i = 0; i < cols; ++i) - { - std::cout << "{{ "; - bool first = true; - //hh.emplace_back(); - - for (u64 j = 0; j < (u64)H.col(i).size(); ++j) - { - auto c = H.col(i)[j]; - c = (c - i) % cols; - - if (!first) - std::cout << ", "; - std::cout << c; - - //hh[i].push_back(H.row(i)[j]); - first = false; - } - - //if (i + cols < H.rows() && 0) - //{ - // for (u64 j = 0; j < (u64)H.row(i + cols).size(); ++j) - // { - - // auto c = H.row(i + cols)[j]; - // c = (c + cols - 1 - i) % cols; - - // if (!first) - // std::cout << ", "; - // std::cout << c; - // //hh[i].push_back(H.row(i+cols)[j]); - // first = false; - // } - //} - - - - std::cout << "}},\n"; - } - std::cout << "}}" << std::endl; - - - //{ - // u64 rowIdx = 0; - // for (auto row : hh) - // { - // std::set s; - // std::cout << "("; - // for (auto c : row) - // { - // std::cout << int(c) << " "; - // s.insert(); - // } - // std::cout << std::endl << "{"; - - // for (auto c : s) - // std::cout << c << " "; - // std::cout << std::endl; - // ++rowIdx; - // } - //} - } - } - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - void sampleRegTriangularBand(u64 rows, u64 cols, - u64 weight, u64 gap, u64 dWeight, - u64 diag, u64 dDiag, u64 period, - std::vector doubleBand, - bool trim, bool extend, bool randY, - oc::PRNG& prng, PointList& points) - { - //auto dHeight =; - - assert(extend == false || trim == true); - assert(gap < rows); - assert(dWeight > 0); - assert(dWeight <= gap + 1); - - if (trim == false) - throw RTE_LOC; - - - if (period == 0 || period > rows) - period = rows; - - if (extend) - { - for (u64 i = 0; i < gap; ++i) - { - points.push_back({ rows - gap + i, cols - gap + i }); - } - } - - //auto b = trim ? cols - rows + gap : cols - rows; - auto b = cols - rows; - auto diagOffset = sampleFixedColWeight(rows, b, weight, diag, randY, prng, points); - u64 e = rows - gap; - auto e2 = cols - gap; - - - PointList diagPoints(period + gap, period); - sampleRegDiag(period + gap, gap, dWeight - 1, prng, diagPoints); - //std::cout << "cols " << cols << std::endl; - - std::set> ss; - for (u64 i = 0; i < e; ++i) - { - - points.push_back({ i, b + i }); - if (b + i >= cols) - throw RTE_LOC; - - //if (ss.insert({ i, b + i })); - for (auto db : doubleBand) - { - assert(db >= 1); - u64 j = db + gap + i; - - if (j < rows) - points.push_back({ j, b + i }); - } - } - - auto blks = (e + period - 1) / (period); - for (u64 i = 0; i < blks; ++i) - { - auto ii = i * period; - for (auto p : diagPoints) - { - auto r = ii + p.mRow + 1; - auto c = ii + p.mCol + b; - if (r < rows && c < e2) - points.push_back({ r, c }); - } - } - - } -} -#endif diff --git a/libOTe/Tools/LDPC/LdpcSampler.h b/libOTe/Tools/LDPC/LdpcSampler.h deleted file mode 100644 index 4bcb8848..00000000 --- a/libOTe/Tools/LDPC/LdpcSampler.h +++ /dev/null @@ -1,751 +0,0 @@ -#pragma once -// © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). -#include "libOTe/config.h" -#ifdef ENABLE_LDPC -#include "Mtx.h" -#include "cryptoTools/Crypto/PRNG.h" -#include -#include -#include "cryptoTools/Common/CLP.h" -#include "cryptoTools/Common/BitVector.h" -#include "Util.h" -#include "libOTe/Tools/Tools.h" -#include - -namespace osuCrypto -{ - - //inline void push(PointList& p, Point x) - //{ - // //for (u64 i = 0; i < p.size(); ++i) - // //{ - // // if (p[i].mCol == x.mCol && p[i].mRow == x.mRow) - // // { - // // assert(0); - // // } - // //} - - // //std::cout << "{" << x.mRow << ", " << x.mCol << " } " << std::endl; - // p.push_back(x); - //} - - - extern std::vector slopes_, ys_, lastYs_; - extern std::vector yr_; - extern bool printDiag; - // samples a uniform partiy check matrix with - // each column having weight w. - inline std::vector sampleFixedColWeight( - u64 rows, u64 cols, - u64 w, u64 diag, bool randY, - oc::PRNG& prng, PointList& points) - { - std::vector& diagOffsets = lastYs_; - diagOffsets.clear(); - - diag = std::min(diag, w); - - if (slopes_.size() == 0) - slopes_ = { {1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1} }; - - if (slopes_.size() < diag) - throw RTE_LOC; - - if (diag) - { - if (randY) - { - yr_.clear(); - std::uniform_real_distribution<> dist(0, 1); - yr_.push_back(0); - //s.insert(0); - - while (yr_.size() != diag) - { - auto v = dist(prng); - yr_.push_back(v); - } - - std::sort(yr_.begin(), yr_.end()); - //diagOffsets.insert(diagOffsets.end(), s.begin(), s.end()); - } - - if (yr_.size()) - { - if (yr_.size() < diag) - { - std::cout << "yr.size() < diag" << std::endl; - throw RTE_LOC; - } - std::set ss; - //ss.insert(0); - //diagOffsets.resize(diag); - - for (u64 i = 0; ss.size() < diag; ++i) - { - auto p = u64(rows * yr_[i]) % rows; - while (ss.insert(p).second == false) - p = u64(p + 1) % rows; - } - - //for (u64 i = 0; i < diag; ++i) - //{ - //} - diagOffsets.clear(); - for (auto s : ss) - diagOffsets.push_back(s); - } - else if (ys_.size()) - { - - if (ys_.size() < diag) - throw RTE_LOC; - - diagOffsets = ys_; - diagOffsets.resize(diag); - - } - else - { - diagOffsets.resize(diag); - - for (u64 i = 0; i < diag; ++i) - diagOffsets[i] = rows / 2; - } - - } - std::set set; - for (u64 i = 0; i < cols; ++i) - { - set.clear(); - if (diag && i < rows) - { - //assert(diag <= minWeight); - for (u64 j = 0; j < diag; ++j) - { - i64& r = diagOffsets[j]; - - - - r = (slopes_[j] + r); - - if (r >= i64(rows)) - r -= rows; - - if (r < 0) - r += rows; - - if (r >= i64(rows) || r < 0) - { - //std::cout << i << " " << r << " " << rows << std::endl; - throw RTE_LOC; - } - - - set.insert(r); - //auto& pat = patterns[i % patterns.size()]; - //for (u64 k = 0; k < pat.size(); ++k) - //{ - - // auto nn = set.insert((r + pat[k]) % rows); - // if (!nn.second) - // set.erase(nn.first); - //} - - //auto r2 = (r + 1 + (i & 15)) % rows; - //nn = set.insert(r2).second; - //if (nn) - // points.push_back({ r2, i }); - } - } - - for (auto ss : set) - { - points.push_back({ ss, i }); - } - - while (set.size() < w) - { - auto j = prng.get() % rows; - if (set.insert(j).second) - points.push_back({ j, i }); - } - } - - return diagOffsets; - } - - DenseMtx computeGen(DenseMtx& H); - - // samples a uniform partiy check matrix with - // each column having weight w. - inline SparseMtx sampleFixedColWeight(u64 rows, u64 cols, u64 w, oc::PRNG& prng, bool checked) - { - PointList points(rows, cols); - sampleFixedColWeight(rows, cols, w, false, false, prng, points); - - if (checked) - { - u64 i = 1; - SparseMtx H(rows, cols, points); - - auto D = H.dense(); - auto g = computeGen(D); - - while(g.rows() == 0) - { - ++i; - points.mPoints.clear(); - sampleFixedColWeight(rows, cols, w, false, false, prng, points); - - H = SparseMtx(rows, cols, points); - D = H.dense(); - g = computeGen(D); - } - - //std::cout << "("< - inline void shuffle(Iter begin, Iter end, oc::PRNG& prng) - { - u64 n = u64(end - begin); - - for (u64 i = 0; i < n; ++i) - { - auto j = prng.get() % (n - i); - - std::swap(*begin, *(begin + j)); - - ++begin; - } - } - - // samples a uniform set of size weight in the - // inteveral [begin, end). If diag, then begin - // will always be in the set. - inline std::set sampleCol(u64 begin, u64 end, u64 mod, u64 weight, bool diag, oc::PRNG& prng) - { - //std::cout << "sample " << prng.get() << std::endl; - - std::set idxs; - - auto n = end - begin; - if (n < weight) - { - std::cout << "n < weight " << LOCATION << std::endl; - abort(); - } - - if (diag) - { - idxs.insert(begin % mod); - ++begin; - --n; - --weight; - } - - if (n < 3 * weight) - { - auto nn = std::min(3 * weight, n); - std::vector set(nn); - std::iota(set.begin(), set.end(), begin); - - shuffle(set.begin(), set.end(), prng); - - //for (u64 i = 0; i < weight; ++i) - //{ - // std::cout << set[i] << std::endl; - //} - //for(auto s : set) - //auto iter = set.beg - for (u64 i = 0; i < weight; ++i) - idxs.insert((set[i]) % mod); - } - else - { - while (idxs.size() < weight) - { - auto x = prng.get() % n; - //std::cout << x << std::endl; - idxs.insert((x + begin) % mod); - } - } - - return idxs; - } - - - inline std::set sampleCol(u64 begin, u64 end, u64 weight, bool diag, oc::PRNG& prng) - { - std::set idxs; - - auto n = end - begin; - if (n < weight) - { - std::cout << "n < weight " << LOCATION << std::endl; - abort(); - } - - if (diag) - { - idxs.insert(begin); - ++begin; - --n; - --weight; - } - - if (n < 3 * weight) - { - auto nn = std::min(3 * weight, n); - std::vector set(nn); - std::iota(set.begin(), set.end(), begin); - - shuffle(set.begin(), set.end(), prng); - - //for (u64 i = 0; i < weight; ++i) - //{ - // std::cout << set[i] << std::endl; - //} - - idxs.insert(set.begin(), set.begin() + weight); - } - else - { - while (idxs.size() < weight) - { - auto x = prng.get() % n; - //std::cout << x << std::endl; - idxs.insert(x + begin); - } - } - return idxs; - } - - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline void sampleTriangular(u64 rows, u64 cols, u64 weight, u64 gap, oc::PRNG& prng, PointList& points) - { - auto b = cols - rows + gap; - sampleFixedColWeight(rows, b, weight, 0, false, prng, points); - - for (u64 i = 0; i < rows - gap; ++i) - { - auto w = std::min(weight - 1, (rows - i) / 2); - auto s = sampleCol(i + 1, rows, w, false, prng); - - points.push_back({ i, b + i }); - for (auto ss : s) - points.push_back({ ss, b + i }); - - } - } - - - inline void sampleUniformSystematic(u64 rows, u64 cols, oc::PRNG& prng, PointList& points) - { - - for (u64 i = 0; i < rows; ++i) - { - points.push_back({ i, cols - rows + i }); - - for (u64 j = 0; j < cols - rows; ++j) - { - if (prng.get()) - { - points.push_back({ i,j }); - } - } - } - - - } - - inline SparseMtx sampleUniformSystematic(u64 rows, u64 cols, oc::PRNG& prng) - { - PointList points(rows, cols); - sampleUniformSystematic(rows, cols, prng, points); - return SparseMtx(rows, cols, points); - } - - void sampleRegDiag( - u64 rows, u64 gap, u64 weight, - oc::PRNG& prng, PointList& points - ); - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - void sampleRegTriangularBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, u64 dDiag, u64 period, - std::vector doubleBand, - bool trim, bool extend, bool randY, - oc::PRNG& prng, PointList& points); - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline void sampleTriangularBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, u64 dDiag, u64 period, - std::vector doubleBand, - bool trim, bool extend, bool randY, - oc::PRNG& lPrng, PRNG& rPrng, PointList& points) - { - auto dHeight = gap + 1; - - assert(extend == false || trim == true); - assert(gap < rows); - assert(dWeight > 0); - assert(dWeight <= dHeight); - - if (extend) - { - for (u64 i = 0; i < gap; ++i) - { - points.push_back({ rows - gap + i, cols - gap + i }); - } - } - - //auto b = trim ? cols - rows + gap : cols - rows; - auto b = cols - rows; - - auto diagOffset = sampleFixedColWeight(rows, b, weight, diag, randY, lPrng, points); - - u64 ii = trim ? 0 : rows - gap; - u64 e = trim ? rows - gap : rows; - - - //if (doubleBand.size()) - //{ - // if (dDiag || !trim) - // { - // std::cout << "assumed no dDiag and assumed trim" << std::endl; - // abort(); - // } - - // //for (auto db : doubleBand) - // //{ - // // assert(db >= 1); - - // // for (u64 j = db + gap, c = b; j < rows; ++j, ++c) - // // { - // // points.push_back({ j, c }); - // // } - // //} - //} - - if (period && dDiag) - throw RTE_LOC; - - if (period) - { - if (trim == false) - throw RTE_LOC; - - for (u64 p = 0; p < period; ++p) - { - std::set s; - - auto ww = dWeight - 1; - - assert(ww < dHeight); - - s = sampleCol(1, dHeight, ww, false, rPrng); - - for (auto db : doubleBand) - { - assert(db >= 1); - u64 j = db + gap; - s.insert(j); - - } - - for (u64 i = p; i < e; i += period) - { - - points.push_back({ i % rows, b + i }); - for (auto ss : s) - { - if(i + ss < rows) - points.push_back({ (i+ss), b + i }); - } - } - - } - } - else - { - - for (u64 i = 0; i < e; ++i, ++ii) - { - auto ww = dWeight - 1; - for (auto db : doubleBand) - { - assert(db >= 1); - u64 j = db + gap + ii; - - if (j >= rows) - { - if (dDiag) - ++ww; - } - else - points.push_back({ j, b + i }); - - } - assert(ww < dHeight); - - auto s = sampleCol(ii + 1, ii + dHeight, ww, false, rPrng); - - points.push_back({ ii % rows, b + i }); - for (auto ss : s) - points.push_back({ ss % rows, b + i }); - - } - } - } - - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline SparseMtx sampleTriangularBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, oc::PRNG& prng) - { - PointList points(rows, cols); - sampleTriangularBand(rows, cols, weight, - gap, dWeight, diag, 0, 0, - {}, false, false, false, prng, prng, points); - - return SparseMtx(rows, cols, points); - } - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline SparseMtx sampleTriangularBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, u64 dDiag, u64 period, std::vector doubleBand, - bool trim, bool extend, bool randY, - oc::PRNG& lPrng, - oc::PRNG& rPrng) - { - PointList points(rows, cols); - sampleTriangularBand( - rows, cols, - weight, gap, - dWeight, diag, dDiag, period, doubleBand, trim, extend, randY, - lPrng, rPrng, points); - - auto cc = (trim && !extend) ? cols - gap : cols; - - return SparseMtx(rows, cc, points); - } - - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline SparseMtx sampleRegTriangularBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, u64 dDiag, u64 period, std::vector doubleBand, - bool trim, bool extend, bool randY, - oc::PRNG& prng) - { - auto cc = (trim && !extend) ? cols - gap : cols; - PointList points(rows, cc); - sampleRegTriangularBand( - rows, cols, - weight, gap, - dWeight, diag, dDiag, period, doubleBand, trim, extend, randY, - prng, points); - - - return SparseMtx(rows, cc, points); - } - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline void sampleTriangularLongBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, bool doubleBand, - oc::PRNG& prng, PointList& points) - { - auto dHeight = gap + 1; - assert(gap < rows); - assert(dWeight < weight); - assert(dWeight <= dHeight); - - //sampleFixedColWeight(rows, cols - rows, weight, diag, prng, points); - - std::set s; - for (u64 i = 0, ii = rows - gap; i < cols; ++i, ++ii) - { - if (doubleBand) - { - assert(dWeight >= 2); - s = sampleCol(ii + 1, ii + dHeight - 1, rows, dWeight - 2, false, prng); - s.insert((ii + dHeight) % rows); - //points.push_back({ () % rows, i }); - } - else - s = sampleCol(ii + 1, ii + dHeight, rows, dWeight - 1, false, prng); - - - s.insert(ii % rows); - - if (i < rows) - { - while (s.size() != weight) - { - auto j = prng.get() % rows; - s.insert(j); - } - } - - //points.push_back({ ii % rows, i }); - for (auto ss : s) - points.push_back({ ss % rows, i }); - - - - } - } - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline SparseMtx sampleTriangularLongBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, bool doubleBand, - oc::PRNG& prng) - { - PointList points(rows, cols); - sampleTriangularLongBand( - rows, cols, - weight, gap, - dWeight, diag, doubleBand, - prng, points); - - return SparseMtx(rows, cols, points); - } - - //// sample a parity check which is approx triangular with. - //// The diagonal will have fixed weight = dWeight. - //// The other columns will have weight = weight. - //inline void sampleTriangularBand2(u64 rows, u64 cols, u64 weight, u64 gap, u64 dWeight, oc::PRNG& prng, PointList& points) - //{ - // auto dHeight = gap + 1; - // assert(dWeight > 0); - // assert(dWeight <= dHeight); - - // sampleFixedColWeight(rows, cols - rows, weight, false, prng, points); - - // auto b = cols - rows; - // for (u64 i = 0, ii = rows - gap; i < rows; ++i, ++ii) - // { - // if (ii >= rows) - // { - // auto s = sampleCol(ii + 1, ii + dHeight, dWeight - 1, false, prng); - // points.push_back({ ii % rows, b + i }); - // for (auto ss : s) - // points.push_back({ ss % rows, b + i }); - // } - // else - // { - // auto s = sampleCol(ii, ii + dHeight, dWeight, false, prng); - // for (auto ss : s) - // points.push_back({ ss % rows, b + i }); - // } - - // } - //} - - //// sample a parity check which is approx triangular with. - //// The diagonal will have fixed weight = dWeight. - //// The other columns will have weight = weight. - //inline SparseMtx sampleTriangularBand2(u64 rows, u64 cols, u64 weight, u64 gap, u64 dWeight, oc::PRNG& prng) - //{ - // PointList points; - // sampleTriangularBand2(rows, cols, weight, gap, dWeight, prng, points); - // return SparseMtx(rows, cols, points); - //} - - - // sample a parity check which is approx triangular with. - // The other columns will have weight = weight. - inline void sampleTriangular(u64 rows, double density, oc::PRNG& prng, PointList& points) - { - assert(density > 0); - - u64 t = static_cast(~u64{ 0 } *density); - - for (u64 i = 0; i < rows; ++i) - { - points.push_back({ i, i }); - - for (u64 j = 0; j < i; ++j) - { - if (prng.get() < t) - { - points.push_back({ i, j }); - } - } - } - } - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline SparseMtx sampleTriangular(u64 rows, double density, oc::PRNG& prng) - { - PointList points(rows, rows); - sampleTriangular(rows, density, prng, points); - return SparseMtx(rows, rows, points); - } - - - inline SparseMtx sampleTriangular(u64 rows, u64 cols, u64 weight, u64 gap, oc::PRNG& prng) - { - PointList points(rows, cols); - sampleTriangular(rows, cols, weight, gap, prng, points); - return SparseMtx(rows, cols, points); - } - - - -} -#endif \ No newline at end of file diff --git a/libOTe/Tools/LDPC/Mtx.cpp b/libOTe/Tools/LDPC/Mtx.cpp index 6430a207..1354ba1f 100644 --- a/libOTe/Tools/LDPC/Mtx.cpp +++ b/libOTe/Tools/LDPC/Mtx.cpp @@ -2,7 +2,6 @@ #include "cryptoTools/Crypto/PRNG.h" #include "Util.h" #include "cryptoTools/Common/Matrix.h" -#include "LdpcSampler.h" namespace osuCrypto { diff --git a/libOTe/Tools/LDPC/Util.h b/libOTe/Tools/LDPC/Util.h index a05844ce..783a1b36 100644 --- a/libOTe/Tools/LDPC/Util.h +++ b/libOTe/Tools/LDPC/Util.h @@ -61,13 +61,6 @@ namespace osuCrypto while (i < mK - 1 && mSet[i] + 1 == mSet[i + 1]) ++i; - //if (i == mK - 1 && mSet.back() == mN - 1) - //{ - // mSet.clear(); - // return; - // //assert(mSet.back() != mN - 1); - //} - ++mSet[i]; for (u64 j = 0; j < i; ++j) mSet[j] = j; @@ -79,13 +72,7 @@ namespace osuCrypto } }; -#ifdef ENABLE_ALGO994 - extern int alg994; - extern int num_saved_generators; - extern int num_cores; - extern int num_permutations; - extern int print_matrices; -#endif + } diff --git a/libOTe/Tools/QuasiCyclicCode.h b/libOTe/Tools/QuasiCyclicCode.h index d1429fc9..d9de8859 100644 --- a/libOTe/Tools/QuasiCyclicCode.h +++ b/libOTe/Tools/QuasiCyclicCode.h @@ -25,384 +25,333 @@ namespace osuCrypto { - // https://eprint.iacr.org/2019/1159.pdf - struct QuasiCyclicCode : public TimerAdapter - { - private: - u64 mScaler = 0; - u64 mNumThreads = 1; + // https://eprint.iacr.org/2019/1159.pdf + struct QuasiCyclicCode : public TimerAdapter + { + private: + + // the length of the encoding + u64 mMessageSize = 0; + + + // the length of the input. mCodeSize = mMessageSize * mScaler; + u64 mCodeSize = 0; + + // the next prime starting at mMessageSize. The + // real code size will in fact be of size mScaler * mPrimeMosulus + u64 mPrimeModulus = 0; + public: + + //size of the input + u64 size() { return mCodeSize; } + + // initialize the compressing matrix that maps a + // vector of size n * scaler to a vector of size n. + void init(u64 n, u64 scaler = 2) + { + if (scaler <= 1) + throw RTE_LOC; + + mMessageSize = n; + mPrimeModulus = nextPrime(n); + mCodeSize = mMessageSize * scaler; + } + void init2(u64 messageSize, u64 codeSize) + { + auto scaler = divCeil(codeSize, messageSize); + + mMessageSize = messageSize; + mPrimeModulus = nextPrime(messageSize); + mCodeSize = codeSize; + } + + static void bitShiftXor(span dest, span in, u8 bitShift) + { + if (bitShift > 127) + throw RTE_LOC; + if (u64(in.data()) % 16) + throw RTE_LOC; + + if (bitShift >= 64) + { + bitShift -= 64; + const int bitShift2 = 64 - bitShift; + u8* inPtr = ((u8*)in.data()) + sizeof(u64); + + auto end = std::min(dest.size(), in.size() - 1); + for (u64 i = 0; i < end; ++i, inPtr += sizeof(block)) + { + block + b0 = toBlock(inPtr), + b1 = toBlock(inPtr + sizeof(u64)); - // the length of the encoding - u64 mP = 0; + b0 = (b0 >> bitShift); + b1 = (b1 << bitShift2); - // the length of the input. mM = mP * mScaler; - u64 mM = 0; - - public: - - //size of the input - u64 size() { return mM; } - - // initialize the compressing matrix that maps a - // vector of size n * scaler to a vector of size n. - void init(u64 n, u64 scaler = 2) - { - if (scaler <= 1) - throw RTE_LOC; - - if (isPrime(n) == false) - throw RTE_LOC; - - mP = n;// nextPrime(n); - //mN = n; - mM = mP * scaler; - mScaler = scaler; - } - - static void bitShiftXor(span dest, span in, u8 bitShift) - { - if (bitShift > 127) - throw RTE_LOC; - if (u64(in.data()) % 16) - throw RTE_LOC; - - if (bitShift >= 64) - { - bitShift -= 64; - const int bitShift2 = 64 - bitShift; - u8* inPtr = ((u8*)in.data()) + sizeof(u64); - - auto end = std::min(dest.size(), in.size() - 1); - for (u64 i = 0; i < end; ++i, inPtr += sizeof(block)) - { - block - b0 = toBlock(inPtr), - b1 = toBlock(inPtr + sizeof(u64)); + dest[i] = dest[i] ^ b0 ^ b1; + } - b0 = (b0 >> bitShift); - b1 = (b1 << bitShift2); + if (end != static_cast(dest.size())) + { + u64 b0 = *(u64*)inPtr; + b0 = (b0 >> bitShift); - dest[i] = dest[i] ^ b0 ^ b1; - } + *(u64*)(&dest[end]) ^= b0; + } + } + else if (bitShift) + { + const int bitShift2 = 64 - bitShift; + u8* inPtr = (u8*)in.data(); + + auto end = std::min(dest.size(), in.size() - 1); + for (u64 i = 0; i < end; ++i, inPtr += sizeof(block)) + { + block + b0 = toBlock(inPtr), + b1 = toBlock(inPtr + sizeof(u64)); + + b0 = (b0 >> bitShift); + b1 = (b1 << bitShift2); + + //bv0.append((u8*)&b0, 128); + //bv1.append((u8*)&b1, 128); + + dest[i] = dest[i] ^ b0 ^ b1; + } + + if (end != static_cast(dest.size())) + { + block b0 = toBlock(inPtr); + b0 = (b0 >> bitShift); + + //bv0.append((u8*)&b0, 128); + + dest[end] = dest[end] ^ b0; + + u64 b1 = *(u64*)(inPtr + sizeof(u64)); + b1 = (b1 << bitShift2); + + //bv1.append((u8*)&b1, 64); + + *(u64*)&dest[end] ^= b1; + } + + + + //std::cout << " b0 " << bv0 << std::endl; + //std::cout << " b1 " << bv1 << std::endl; + } + else + { + auto end = std::min(dest.size(), in.size()); + for (u64 i = 0; i < end; ++i) + { + dest[i] = dest[i] ^ in[i]; + } + } + } - if (end != static_cast(dest.size())) - { - u64 b0 = *(u64*)inPtr; - b0 = (b0 >> bitShift); + static void modp(span dest, span in, u64 p) + { + auto pBlocks = (p + 127) / 128; + auto pBytes = (p + 7) / 8; - *(u64*)(&dest[end]) ^= b0; - } - } - else if (bitShift) - { - const int bitShift2 = 64 - bitShift; - u8* inPtr = (u8*)in.data(); - - auto end = std::min(dest.size(), in.size() - 1); - for (u64 i = 0; i < end; ++i, inPtr += sizeof(block)) - { - block - b0 = toBlock(inPtr), - b1 = toBlock(inPtr + sizeof(u64)); - - b0 = (b0 >> bitShift); - b1 = (b1 << bitShift2); - - //bv0.append((u8*)&b0, 128); - //bv1.append((u8*)&b1, 128); - - dest[i] = dest[i] ^ b0 ^ b1; - } - - if (end != static_cast(dest.size())) - { - block b0 = toBlock(inPtr); - b0 = (b0 >> bitShift); - - //bv0.append((u8*)&b0, 128); - - dest[end] = dest[end] ^ b0; - - u64 b1 = *(u64*)(inPtr + sizeof(u64)); - b1 = (b1 << bitShift2); - - //bv1.append((u8*)&b1, 64); - - *(u64*)&dest[end] ^= b1; - } - - - - //std::cout << " b0 " << bv0 << std::endl; - //std::cout << " b1 " << bv1 << std::endl; - } - else - { - auto end = std::min(dest.size(), in.size()); - for (u64 i = 0; i < end; ++i) - { - dest[i] = dest[i] ^ in[i]; - } - } - } - - static void modp(span dest, span in, u64 p) - { - auto pBlocks = (p + 127) / 128; - auto pBytes = (p + 7) / 8; - - if (static_cast(dest.size()) < pBlocks) - throw RTE_LOC; - - if (static_cast(in.size()) < pBlocks) - throw RTE_LOC; - - auto count = (in.size() * 128 + p - 1) / p; + if (static_cast(dest.size()) < pBlocks) + throw RTE_LOC; - memcpy(dest.data(), in.data(), pBytes); - - for (u64 i = 1; i < count; ++i) - { - auto begin = i * p; - auto end = std::min(i * p + p, in.size() * 128); + if (static_cast(in.size()) < pBlocks) + throw RTE_LOC; - auto shift = begin & 127; - auto beginBlock = in.data() + (begin / 128); - auto endBlock = in.data() + ((end + 127) / 128); + auto count = (in.size() * 128 + p - 1) / p; - if (endBlock > in.data() + in.size()) - throw RTE_LOC; + memcpy(dest.data(), in.data(), pBytes); + for (u64 i = 1; i < count; ++i) + { + auto begin = i * p; + auto end = std::min(i * p + p, in.size() * 128); - auto in_i = span(beginBlock, endBlock); + auto shift = begin & 127; + auto beginBlock = in.data() + (begin / 128); + auto endBlock = in.data() + ((end + 127) / 128); - bitShiftXor(dest, in_i, static_cast(shift)); - } + if (endBlock > in.data() + in.size()) + throw RTE_LOC; - auto offset = (p & 7); - if (offset) - { - u8 mask = (1 << offset) - 1; - auto idx = p / 8; - ((u8*)dest.data())[idx] &= mask; - } + auto in_i = span(beginBlock, endBlock); - auto rem = dest.size() * 16 - pBytes; - if (rem) - memset(((u8*)dest.data()) + pBytes, 0, rem); - } + bitShiftXor(dest, in_i, static_cast(shift)); + } - void dualEncode(span X) - { - std::vector XX(X.size()); - for (auto i : rng(X.size())) - { - if (X[i] > 1) - throw RTE_LOC; - XX[i] = block(X[i], X[i]); - } - dualEncode(XX); - for (auto i : rng(X.size())) - { - X[i] = XX[i] == ZeroBlock ? 0 : 1; - } - } + auto offset = (p & 7); + if (offset) + { + u8 mask = (1 << offset) - 1; + auto idx = p / 8; + ((u8*)dest.data())[idx] &= mask; + } - inline void transpose(span s, MatrixView r) - { - MatrixView ss((u8*)s.data(), s.size(), sizeof(block)); - MatrixView rr((u8*)r.data(), r.rows(), r.cols() * sizeof(block)); - ::oc::transpose(ss, rr); - } + auto rem = dest.size() * 16 - pBytes; + if (rem) + memset(((u8*)dest.data()) + pBytes, 0, rem); + } + void dualEncode(span X) + { + std::vector XX(X.size()); + for (auto i : rng(X.size())) + { + if (X[i] > 1) + throw RTE_LOC; - void dualEncode(span X) - { - if(X.size() != mM) - throw RTE_LOC; - const u64 rows(128); + XX[i] = block(X[i], X[i]); + } + dualEncode(XX); + for (auto i : rng(X.size())) + { + X[i] = XX[i] == ZeroBlock ? 0 : 1; + } + } - auto nBlocks = (mP + rows-1) / rows; - auto n2Blocks = ((mM-mP) + rows -1) / rows; + inline void transpose(span s, MatrixView r) + { + MatrixView ss((u8*)s.data(), s.size(), sizeof(block)); - Matrix XT(rows, n2Blocks); - transpose(X.subspan(mP), XT); + auto colLen = r.cols() * sizeof(block); + if (colLen < r.size() / 8) + throw RTE_LOC; - auto n64 = i64(nBlocks * 2); - - std::vector a(mScaler - 1); + MatrixView rr((u8*)r.data(), r.rows(), colLen); + ::oc::transpose(ss, rr); + } - MatrixcModP1(128, nBlocks, AllocType::Uninitialized); - //std::unique_ptr brs(new ThreadBarrier[mScaler + 1]); - //for (u64 i = 0; i <= mScaler; ++i) - //brs[i].reset(mNumThreads); + void dualEncode(span X) + { + if (X.size() != size()) + throw RTE_LOC; + const u64 rows(128); - //auto routine = [&](u64 index) - { - //u64 j = 0; + //auto scaler = divCeil(mCodeSize, mPrimeModulus); + auto remSize = X.size() - mMessageSize; + auto scalerMinusOne = divCeil(remSize, mPrimeModulus); + + // the number of blocks required to represent a single poly + auto polyBlockSize = divCeil(mPrimeModulus, 128); + + // the number of blocks required to represent scalerMinusOne poly's + auto multiPolyBlockSize = polyBlockSize * scalerMinusOne; + + Matrix XT(rows, multiPolyBlockSize); + transpose(X.subspan(mMessageSize), XT); + + auto polyU64Size = i64(polyBlockSize * 2); - //{ - // std::array tpBuffer; - // auto numBlocks = mM / 128; - // auto begin = index * numBlocks / mNumThreads; - // auto end = (index + 1) * numBlocks / mNumThreads; + std::vector a(scalerMinusOne); - // for (u64 i = begin; i < end; ++i) - // { - // u64 j = i * tpBuffer.size(); + MatrixcModP1(128, polyBlockSize, AllocType::Uninitialized); - // for (u64 k = 0; k < tpBuffer.size(); ++k) - // tpBuffer[k] = X[j + k]; + FFTPoly bPoly; + FFTPoly cPoly; - // transpose128(tpBuffer); + AlignedUnVector temp128(2 * polyBlockSize); - // auto end = i * tpBuffer.size() + 128; - // for (u64 k = 0; j < end; ++j, ++k) - // X[j] = tpBuffer[k]; - // } + FFTPoly::DecodeCache cache; + for (u64 s = 0; s < scalerMinusOne; s += 1) + { + auto a64 = spanCast(temp128).subspan(polyU64Size); + PRNG pubPrng(toBlock(s) ^ CCBlock); + pubPrng.get(a64.data(), a64.size()); + a[s].encode(a64); + } + + for (u64 i = 0; i < rows; i += 1) + { + + for (u64 s = 0; s < scalerMinusOne; ++s) + { + auto& aPoly = a[s]; + auto b64 = spanCast(XT[i]).subspan(s * polyU64Size, polyU64Size); + + bPoly.encode(b64); + + if (s == 0) + { + cPoly.mult(aPoly, bPoly); + } + else + { + bPoly.multEq(aPoly); + cPoly.addEq(bPoly); + } + } + + // decode c[i] and store it at t64Ptr + cPoly.decode(spanCast(temp128), cache, true); + + // reduce s[i] mod (x^p - 1) and store it at cModP1[i] + modp(cModP1[i], temp128, mPrimeModulus); + } + + AlignedArray tpBuffer; + auto numBlocks = divCeil(mMessageSize, 128); + auto begin = 0 * numBlocks; + auto end = (1) * numBlocks; + for (u64 i = begin; i < end; ++i) + { + u64 j = i * tpBuffer.size(); + auto min = std::min(tpBuffer.size(), mMessageSize - j); + + for (u64 k = 0; k < tpBuffer.size(); ++k) + tpBuffer[k] = cModP1(k, i); + + transpose128(tpBuffer.data()); + + auto end2 = i * tpBuffer.size() + min; + for (u64 k = 0; j < end2; ++j, ++k) + X[j] = X[j] ^ tpBuffer[k]; + } + + } - // if (index == 0) - // setTimePoint("sender.expand.qc.transposeXor"); - //} - //brs[j++].decrementWait(); + DenseMtx getMatrix() + { + + DenseMtx mtx(mCodeSize, mMessageSize); - FFTPoly bPoly; - FFTPoly cPoly; + for (u64 i = 0; i < mCodeSize; ++i) + { + std::vector in(mCodeSize); + in[i] = oc::AllOneBlock; - AlignedUnVector temp128(2 * nBlocks); - - FFTPoly::DecodeCache cache; - for (u64 s = 1; s < mScaler; s += 1) - { - auto a64 = spanCast(temp128).subspan(n64); - PRNG pubPrng(toBlock(s)); - pubPrng.get(a64.data(), a64.size()); - //memset(a64.data(), 0, a64.size() * sizeof(u64)); - //a64[0] = 1; - - a[s - 1].encode(a64); - } - - - //auto multAddReduce = [this, nBlocks, n64, &a, &bPoly, &cPoly, &temp128, &cache](span b128, span dest) - //{ - - //}; - - for (u64 i = 0; i < rows; i += 1) - { - - for (u64 s = 0; s < mScaler-1; ++s) - { - auto& aPoly = a[s]; - auto b64 = spanCast(XT[i]).subspan(s * n64, n64); - - bPoly.encode(b64); - - if (s == 0) - { - cPoly.mult(aPoly, bPoly); - } - else - { - bPoly.multEq(aPoly); - cPoly.addEq(bPoly); - } - } - - // decode c[i] and store it at t64Ptr - cPoly.decode(spanCast(temp128), cache, true); - - //for (u64 j = 0; j < nBlocks; ++j) - // temp128[j] = temp128[j] ^ XT[i][j]; - - // reduce s[i] mod (x^p - 1) and store it at cModP1[i] - modp(cModP1[i], temp128, mP); - } - //multAddReduce(rT[i], cModP1[i]); - - //if (index == 0) - // setTimePoint("sender.expand.qc.mulAddReduce"); - - //brs[j++].decrementWait(); - - { - - AlignedArray tpBuffer; - auto numBlocks = (mP + 127) / 128; - auto begin = 0 * numBlocks / mNumThreads; - auto end = (1) * numBlocks / mNumThreads; - for (u64 i = begin; i < end; ++i) - { - u64 j = i * tpBuffer.size(); - auto min = std::min(tpBuffer.size(), mP - j); - - for (u64 k = 0; k < tpBuffer.size(); ++k) - tpBuffer[k] = cModP1(k, i); - - transpose128(tpBuffer.data()); - - auto end = i * tpBuffer.size() + min; - for (u64 k = 0; j < end; ++j, ++k) - X[j] = X[j] ^ tpBuffer[k]; - } - - //if (index == 0) - // setTimePoint("sender.expand.qc.transposeXor"); - } - }; - - //std::vector thrds(mNumThreads - 1); - //for (u64 i = 0; i < thrds.size(); ++i) - // thrds[i] = std::thread(routine, i); - - //routine(thrds.size()); - - //for (u64 i = 0; i < thrds.size(); ++i) - // thrds[i].join(); - } - - - DenseMtx getMatrix() - { - - DenseMtx mtx(mM, mP); - - - - for (u64 i = 0; i < mM; ++i) - { - std::vector in(mM); - in[i] = oc::AllOneBlock; - - dualEncode(in); - - u64 w = 0; - for (u64 j = 0; j < mP; ++j) - { - if (in[j] == oc::AllOneBlock) - { - ++w; - mtx(i, j) = 1; - } - else if (in[j] == oc::ZeroBlock) - { - } - else - throw RTE_LOC; - } - - if (std::abs((long long)(mP - w)) < mP / 2 - std::sqrt(mP)) - throw RTE_LOC; - } - - return mtx; - } - }; + dualEncode(in); + + u64 w = 0; + for (u64 j = 0; j < mMessageSize; ++j) + { + if (in[j] == oc::AllOneBlock) + { + ++w; + mtx(i, j) = 1; + } + else if (in[j] == oc::ZeroBlock) + { + } + else + throw RTE_LOC; + } + + if (std::abs((long long)(mPrimeModulus - w)) < mPrimeModulus / 2 - std::sqrt(mPrimeModulus)) + throw RTE_LOC; + } + + return mtx; + } + }; } diff --git a/libOTe/Tools/Subfield/SubfieldPprf.h b/libOTe/Tools/Subfield/SubfieldPprf.h index fdf8f0a1..91a2b0ae 100644 --- a/libOTe/Tools/Subfield/SubfieldPprf.h +++ b/libOTe/Tools/Subfield/SubfieldPprf.h @@ -119,10 +119,11 @@ namespace osuCrypto bool programPuncturedPoint, std::vector& buff, span, 2>>& sums, - span& leaf) + span& leaf, + CoeffCtx& ctx) { - u64 elementSize = CoeffCtx::byteSize(); + u64 elementSize = ctx.byteSize(); using SumType = std::array, 2>; // num of bytes they will take up. @@ -204,6 +205,8 @@ namespace osuCrypto { if (domainSize & 1) throw std::runtime_error("Pprf domain must be even. " LOCATION); + if (domainSize < 4) + throw std::runtime_error("Pprf domain must must be at least 4. " LOCATION); if (mPntCount % 8) throw std::runtime_error("pointCount must be a multiple of 8 (general case not impl). " LOCATION); @@ -248,7 +251,8 @@ namespace osuCrypto VecF& output, PprfOutputFormat oFormat, bool programPuncturedPoint, - u64 numThreads) + u64 numThreads, + CoeffCtx ctx = {}) { if (programPuncturedPoint) setValue(value); @@ -257,7 +261,7 @@ namespace osuCrypto validateExpandFormat(oFormat, output, mDomain, mPntCount); - MC_BEGIN(task<>, this, numThreads, oFormat, &output, seed, &chl, programPuncturedPoint, + MC_BEGIN(task<>, this, numThreads, oFormat, &output, seed, &chl, programPuncturedPoint, ctx, treeIndex = u64{}, tree = span>{}, levels = std::vector> >{}, @@ -289,15 +293,15 @@ namespace osuCrypto // we will use leaf level as a buffer before // copying the result to the output. leafIndex = 0; - CoeffCtx::resize(leafLevel, mDomain * 8); + ctx.resize(leafLevel, mDomain * 8); leafLevelPtr = &leafLevel; } // allocate the send buffer and partition it. - allocateExpandBuffer(mDepth - 1, programPuncturedPoint, buff, encSums, leafMsgs); + allocateExpandBuffer(mDepth - 1, programPuncturedPoint, buff, encSums, leafMsgs, ctx); // exapnd the tree - expandOne(seed, treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs); + expandOne(seed, treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs, ctx); MC_AWAIT(chl.send(std::move(buff))); @@ -347,7 +351,8 @@ namespace osuCrypto VecF& leafLevel, const u64 leafOffset, span, 2>> encSums, - span leafMsgs) + span leafMsgs, + CoeffCtx ctx) { // the first level should be size 1, the root of the tree. @@ -466,10 +471,10 @@ namespace osuCrypto // clear the sums std::array, 2> leafSums; - CoeffCtx::resize(leafSums[0], 8); - CoeffCtx::resize(leafSums[1], 8); - CoeffCtx::zero(leafSums[0].begin(), leafSums[0].end()); - CoeffCtx::zero(leafSums[1].begin(), leafSums[1].end()); + ctx.resize(leafSums[0], 8); + ctx.resize(leafSums[1], 8); + ctx.zero(leafSums[0].begin(), leafSums[0].end()); + ctx.zero(leafSums[1].begin(), leafSums[1].end()); // for the leaf nodes we need to hash both children. for (u64 parentIdx = 0, outIdx = leafOffset, childIdx = 0; parentIdx < width; ++parentIdx) @@ -491,25 +496,25 @@ namespace osuCrypto // where each half defines one of the children. gGgmAes[keep].hashBlocks<8>(parent.data(), child.data()); - CoeffCtx::fromBlock(leafLevel[outIdx + 0], child[0]); - CoeffCtx::fromBlock(leafLevel[outIdx + 1], child[1]); - CoeffCtx::fromBlock(leafLevel[outIdx + 2], child[2]); - CoeffCtx::fromBlock(leafLevel[outIdx + 3], child[3]); - CoeffCtx::fromBlock(leafLevel[outIdx + 4], child[4]); - CoeffCtx::fromBlock(leafLevel[outIdx + 5], child[5]); - CoeffCtx::fromBlock(leafLevel[outIdx + 6], child[6]); - CoeffCtx::fromBlock(leafLevel[outIdx + 7], child[7]); + ctx.fromBlock(leafLevel[outIdx + 0], child[0]); + ctx.fromBlock(leafLevel[outIdx + 1], child[1]); + ctx.fromBlock(leafLevel[outIdx + 2], child[2]); + ctx.fromBlock(leafLevel[outIdx + 3], child[3]); + ctx.fromBlock(leafLevel[outIdx + 4], child[4]); + ctx.fromBlock(leafLevel[outIdx + 5], child[5]); + ctx.fromBlock(leafLevel[outIdx + 6], child[6]); + ctx.fromBlock(leafLevel[outIdx + 7], child[7]); // leafSum += child auto& leafSum = leafSums[keep]; - CoeffCtx::plus(leafSum[0], leafSum[0], leafLevel[outIdx + 0]); - CoeffCtx::plus(leafSum[1], leafSum[1], leafLevel[outIdx + 1]); - CoeffCtx::plus(leafSum[2], leafSum[2], leafLevel[outIdx + 2]); - CoeffCtx::plus(leafSum[3], leafSum[3], leafLevel[outIdx + 3]); - CoeffCtx::plus(leafSum[4], leafSum[4], leafLevel[outIdx + 4]); - CoeffCtx::plus(leafSum[5], leafSum[5], leafLevel[outIdx + 5]); - CoeffCtx::plus(leafSum[6], leafSum[6], leafLevel[outIdx + 6]); - CoeffCtx::plus(leafSum[7], leafSum[7], leafLevel[outIdx + 7]); + ctx.plus(leafSum[0], leafSum[0], leafLevel[outIdx + 0]); + ctx.plus(leafSum[1], leafSum[1], leafLevel[outIdx + 1]); + ctx.plus(leafSum[2], leafSum[2], leafLevel[outIdx + 2]); + ctx.plus(leafSum[3], leafSum[3], leafLevel[outIdx + 3]); + ctx.plus(leafSum[4], leafSum[4], leafLevel[outIdx + 4]); + ctx.plus(leafSum[5], leafSum[5], leafLevel[outIdx + 5]); + ctx.plus(leafSum[6], leafSum[6], leafLevel[outIdx + 6]); + ctx.plus(leafSum[7], leafSum[7], leafLevel[outIdx + 7]); } } @@ -524,7 +529,7 @@ namespace osuCrypto // This will be done by sending the sums and the sums plus // delta and ensure that they can only decrypt the correct ones. CoeffCtx::template Vec leafOts; - CoeffCtx::resize(leafOts, 2); + ctx.resize(leafOts, 2); PRNG otMasker; for (u64 j = 0; j < 8; ++j) @@ -541,20 +546,20 @@ namespace osuCrypto if (k == 0) { // m0 = (s0, s1 + val) - CoeffCtx::copy(leafOts[0], leafSums[0][j]); - CoeffCtx::plus(leafOts[1], leafSums[1][j], mValue[treeIdx + j]); + ctx.copy(leafOts[0], leafSums[0][j]); + ctx.plus(leafOts[1], leafSums[1][j], mValue[treeIdx + j]); } else { // m1 = (s0+val, s1) - CoeffCtx::plus(leafOts[0], leafSums[0][j], mValue[treeIdx + j]); - CoeffCtx::copy(leafOts[1], leafSums[1][j]); + ctx.plus(leafOts[0], leafSums[0][j], mValue[treeIdx + j]); + ctx.copy(leafOts[1], leafSums[1][j]); } // copy m0 into the output buffer. - span buff = leafMsgs.subspan(0, 2 * CoeffCtx::byteSize()); + span buff = leafMsgs.subspan(0, 2 * ctx.byteSize()); leafMsgs = leafMsgs.subspan(buff.size()); - CoeffCtx::serialize(leafOts.begin(), leafOts.end(), buff.begin()); + ctx.serialize(leafOts.begin(), leafOts.end(), buff.begin()); // encrypt the output buffer. otMasker.SetSeed(mBaseOTs[treeIdx + j][0][1 ^ k], divCeil(buff.size(), sizeof(block))); @@ -567,7 +572,7 @@ namespace osuCrypto else { CoeffCtx::template Vec leafOts; - CoeffCtx::resize(leafOts, 1); + ctx.resize(leafOts, 1); PRNG otMasker; for (u64 j = 0; j < 8; ++j) @@ -575,10 +580,10 @@ namespace osuCrypto for (u64 k = 0; k < 2; ++k) { // copy the sum k into the output buffer. - CoeffCtx::copy(leafOts[0], leafSums[k][j]); - span buff = leafMsgs.subspan(0, CoeffCtx::byteSize()); + ctx.copy(leafOts[0], leafSums[k][j]); + span buff = leafMsgs.subspan(0, ctx.byteSize()); leafMsgs = leafMsgs.subspan(buff.size()); - CoeffCtx::serialize(leafOts.begin(), leafOts.end(), buff.begin()); + ctx.serialize(leafOts.begin(), leafOts.end(), buff.begin()); // encrypt the output buffer. otMasker.SetSeed(mBaseOTs[treeIdx + j][0][1 ^ k], divCeil(buff.size(), sizeof(block))); @@ -626,6 +631,11 @@ namespace osuCrypto { if (domainSize & 1) throw std::runtime_error("Pprf domain must be even. " LOCATION); + if (domainSize < 4) + throw std::runtime_error("Pprf domain must must be at least 4. " LOCATION); + if (mPntCount % 8) + throw std::runtime_error("pointCount must be a multiple of 8 (general case not impl). " LOCATION); + mDomain = domainSize; mDepth = log2ceil(mDomain); mPntCount = pointCount; @@ -636,7 +646,7 @@ namespace osuCrypto // this function sample mPntCount integers in the range // [0,domain) and returns these as the choice bits. - BitVector sampleChoiceBits(u64 modulus, PprfOutputFormat format, PRNG& prng) + BitVector sampleChoiceBits(PRNG& prng) { BitVector choices(mPntCount * mDepth); @@ -659,7 +669,7 @@ namespace osuCrypto } // choices is in the same format as the output from sampleChoiceBits. - void setChoiceBits(PprfOutputFormat format, BitVector choices) + void setChoiceBits(const BitVector& choices) { // Make sure we're given the right number of OTs. if (choices.size() != baseOtCount()) @@ -763,11 +773,12 @@ namespace osuCrypto VecF& output, PprfOutputFormat oFormat, bool programPuncturedPoint, - u64 numThreads) + u64 numThreads, + CoeffCtx ctx = {}) { validateExpandFormat(oFormat, output, mDomain, mPntCount); - MC_BEGIN(task<>, this, oFormat, &output, &chl, programPuncturedPoint, + MC_BEGIN(task<>, this, oFormat, &output, &chl, programPuncturedPoint, ctx, treeIndex = u64{}, tree = span>{}, levels = std::vector>>{}, @@ -803,17 +814,17 @@ namespace osuCrypto // we will use leaf level as a buffer before // copying the result to the output. leafIndex = 0; - CoeffCtx::resize(leafLevel, mDomain * 8); + ctx.resize(leafLevel, mDomain * 8); leafLevelPtr = &leafLevel; } // allocate the send buffer and partition it. - allocateExpandBuffer(mDepth - 1, programPuncturedPoint, buff, encSums, leafMsgs); + allocateExpandBuffer(mDepth - 1, programPuncturedPoint, buff, encSums, leafMsgs, ctx); MC_AWAIT(chl.recv(buff)); // exapnd the tree - expandOne(treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs); + expandOne(treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs, ctx); // if we aren't interleaved, we need to copy the // leaf layer to the output. @@ -849,7 +860,8 @@ namespace osuCrypto VecF& leafLevel, const u64 outputOffset, span, 2>> theirSums, - span leafMsg) + span leafMsg, + CoeffCtx ctx) { // We will process 8 trees at a time. @@ -1019,21 +1031,21 @@ namespace osuCrypto // inactiveChildValues to use the new hash and subtract // these from the leafSums CoeffCtx::template Vec temp; - CoeffCtx::resize(temp, 2); + ctx.resize(temp, 2); std::array, 2> leafSums; for (u64 k = 0; k < 2; ++k) { inactiveChildValues[k] = gGgmAes[k].hashBlock(ZeroBlock); - CoeffCtx::fromBlock(temp[k], inactiveChildValues[k]); + ctx.fromBlock(temp[k], inactiveChildValues[k]); // leafSum = -inactiveChildValues - CoeffCtx::resize(leafSums[k], 8); - CoeffCtx::zero(leafSums[k].begin(), leafSums[k].end()); - CoeffCtx::minus(leafSums[k][0], leafSums[k][0], temp[k]); + ctx.resize(leafSums[k], 8); + ctx.zero(leafSums[k].begin(), leafSums[k].end()); + ctx.minus(leafSums[k][0], leafSums[k][0], temp[k]); for (u64 i = 1; i < 8; ++i) - CoeffCtx::copy(leafSums[k][i], leafSums[k][0]); + ctx.copy(leafSums[k][i], leafSums[k][0]); } // for leaf nodes both children should be hashed. @@ -1052,26 +1064,26 @@ namespace osuCrypto // where each half defines one of the children. gGgmAes[keep].hashBlocks<8>(parent.data(), child.data()); - CoeffCtx::fromBlock(leafLevel[outputIdx + 0], child[0]); - CoeffCtx::fromBlock(leafLevel[outputIdx + 1], child[1]); - CoeffCtx::fromBlock(leafLevel[outputIdx + 2], child[2]); - CoeffCtx::fromBlock(leafLevel[outputIdx + 3], child[3]); - CoeffCtx::fromBlock(leafLevel[outputIdx + 4], child[4]); - CoeffCtx::fromBlock(leafLevel[outputIdx + 5], child[5]); - CoeffCtx::fromBlock(leafLevel[outputIdx + 6], child[6]); - CoeffCtx::fromBlock(leafLevel[outputIdx + 7], child[7]); + ctx.fromBlock(leafLevel[outputIdx + 0], child[0]); + ctx.fromBlock(leafLevel[outputIdx + 1], child[1]); + ctx.fromBlock(leafLevel[outputIdx + 2], child[2]); + ctx.fromBlock(leafLevel[outputIdx + 3], child[3]); + ctx.fromBlock(leafLevel[outputIdx + 4], child[4]); + ctx.fromBlock(leafLevel[outputIdx + 5], child[5]); + ctx.fromBlock(leafLevel[outputIdx + 6], child[6]); + ctx.fromBlock(leafLevel[outputIdx + 7], child[7]); auto& leafSum = leafSums[keep]; - CoeffCtx::plus(leafSum[0], leafSum[0], leafLevel[outputIdx + 0]); - CoeffCtx::plus(leafSum[1], leafSum[1], leafLevel[outputIdx + 1]); - CoeffCtx::plus(leafSum[2], leafSum[2], leafLevel[outputIdx + 2]); - CoeffCtx::plus(leafSum[3], leafSum[3], leafLevel[outputIdx + 3]); - CoeffCtx::plus(leafSum[4], leafSum[4], leafLevel[outputIdx + 4]); - CoeffCtx::plus(leafSum[5], leafSum[5], leafLevel[outputIdx + 5]); - CoeffCtx::plus(leafSum[6], leafSum[6], leafLevel[outputIdx + 6]); - CoeffCtx::plus(leafSum[7], leafSum[7], leafLevel[outputIdx + 7]); + ctx.plus(leafSum[0], leafSum[0], leafLevel[outputIdx + 0]); + ctx.plus(leafSum[1], leafSum[1], leafLevel[outputIdx + 1]); + ctx.plus(leafSum[2], leafSum[2], leafLevel[outputIdx + 2]); + ctx.plus(leafSum[3], leafSum[3], leafLevel[outputIdx + 3]); + ctx.plus(leafSum[4], leafSum[4], leafLevel[outputIdx + 4]); + ctx.plus(leafSum[5], leafSum[5], leafLevel[outputIdx + 5]); + ctx.plus(leafSum[6], leafSum[6], leafLevel[outputIdx + 6]); + ctx.plus(leafSum[7], leafSum[7], leafLevel[outputIdx + 7]); } } @@ -1085,7 +1097,7 @@ namespace osuCrypto // values. Two for each case (left active or right active). //timer.setTimePoint("recv.recvleaf"); VecF leafOts; - CoeffCtx::resize(leafOts, 2); + ctx.resize(leafOts, 2); PRNG otMasker; for (u64 j = 0; j < 8; ++j) @@ -1105,25 +1117,25 @@ namespace osuCrypto // decrypt the ot string - span buff = leafMsg.subspan(offset, CoeffCtx::byteSize() * 2); + span buff = leafMsg.subspan(offset, ctx.byteSize() * 2); leafMsg = leafMsg.subspan(buff.size() * 2); otMasker.SetSeed(mBaseOTs[j + treeIdx][0], divCeil(buff.size(), sizeof(block))); for (u64 i = 0; i < buff.size(); ++i) buff[i] ^= otMasker.get(); - CoeffCtx::deserialize(buff.begin(), buff.end(), leafOts.begin()); + ctx.deserialize(buff.begin(), buff.end(), leafOts.begin()); auto out0 = (activeChildIdx & ~1ull) * 8 + j + outputOffset; auto out1 = (activeChildIdx | 1ull) * 8 + j + outputOffset; - CoeffCtx::minus(leafLevel[out0], leafOts[0], leafSums[0][j]); - CoeffCtx::minus(leafLevel[out1], leafOts[1], leafSums[1][j]); + ctx.minus(leafLevel[out0], leafOts[0], leafSums[0][j]); + ctx.minus(leafLevel[out1], leafOts[1], leafSums[1][j]); } } else { VecF leafOts; - CoeffCtx::resize(leafOts, 1); + ctx.resize(leafOts, 1); PRNG otMasker; for (u64 j = 0; j < 8; ++j) @@ -1141,13 +1153,13 @@ namespace osuCrypto auto offset = CoeffCtx::template byteSize() * notAi; // decrypt the ot string - span buff = leafMsg.subspan(offset, CoeffCtx::byteSize()); + span buff = leafMsg.subspan(offset, ctx.byteSize()); leafMsg = leafMsg.subspan(buff.size() * 2); otMasker.SetSeed(mBaseOTs[j + treeIdx][0], divCeil(buff.size(), sizeof(block))); for (u64 i = 0; i < buff.size(); ++i) buff[i] ^= otMasker.get(); - CoeffCtx::deserialize(buff.begin(), buff.end(), leafOts.begin()); + ctx.deserialize(buff.begin(), buff.end(), leafOts.begin()); std::array out{ (activeChildIdx & ~1ull) * 8 + j + outputOffset, @@ -1157,8 +1169,8 @@ namespace osuCrypto auto keep = leafLevel.begin() + out[notAi]; auto zero = leafLevel.begin() + out[notAi ^ 1]; - CoeffCtx::minus(*keep, leafOts[0], leafSums[notAi][j]); - CoeffCtx::zero(zero, zero + 1); + ctx.minus(*keep, leafOts[0], leafSums[notAi][j]); + ctx.zero(zero, zero + 1); } } } diff --git a/libOTe/TwoChooseOne/ConfigureCode.cpp b/libOTe/TwoChooseOne/ConfigureCode.cpp index 7b84d6bf..2519de01 100644 --- a/libOTe/TwoChooseOne/ConfigureCode.cpp +++ b/libOTe/TwoChooseOne/ConfigureCode.cpp @@ -4,7 +4,6 @@ #include "cryptoTools/Common/Range.h" #include "libOTe/TwoChooseOne/TcoOtDefines.h" #include "libOTe/Tools/Tools.h" -#include "libOTe/Tools/LDPC/LdpcEncoder.h" #include "libOTe/Tools/QuasiCyclicCode.h" #include "libOTe/Tools/EACode/EACode.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" @@ -12,30 +11,6 @@ #include namespace osuCrypto { - //u64 secLevel(u64 scale, u64 n, u64 points) - //{ - // auto x1 = std::log2(scale * n / double(n)); - // auto x2 = std::log2(scale * n) / 2; - // return static_cast(points * x1 + x2); - //} - - //u64 getPartitions(u64 scaler, u64 n, u64 secParam) - //{ - // if (scaler < 2) - // throw std::runtime_error("scaler must be 2 or greater"); - - // u64 ret = 1; - // auto ss = secLevel(scaler, n, ret); - // while (ss < secParam) - // { - // ++ret; - // ss = secLevel(scaler, n, ret); - // if (ret > 1000) - // throw std::runtime_error("failed to find silent OT parameters"); - // } - // return roundUpTo(ret, 8); - //} - // We get e^{-2td} security against linear attacks, // with noise weigh t and minDist d. @@ -145,42 +120,31 @@ namespace osuCrypto void ExConvConfigure( - u64 numOTs, u64 secParam, + double scaler, MultType mMultType, - u64& mRequestedNumOTs, - u64& mNumPartitions, - u64& mSizePer, - u64& mN2, - u64& mN, - ExConvCode2& mEncoder - ) + u64& expanderWeight, + u64& accumulatorWeight, + double& minDist) { - u64 a = 24; - auto mScaler = 2; - u64 w; - double minDist; + if (scaler != 2) + throw RTE_LOC; switch (mMultType) { case osuCrypto::MultType::ExConv7x24: - w = 7; - minDist = 0.1; + accumulatorWeight = 24; + expanderWeight = 7; + minDist = 0.2; // psuedo min dist estimate break; case osuCrypto::MultType::ExConv21x24: - w = 21; - minDist = 0.15; + accumulatorWeight = 24; + expanderWeight = 21; + minDist = 0.25; // psuedo min dist estimate break; default: throw RTE_LOC; break; } - mRequestedNumOTs = numOTs; - mNumPartitions = getRegNoiseWeight(minDist, secParam); - mSizePer = roundUpTo((numOTs * mScaler + mNumPartitions - 1) / mNumPartitions, 8); - mN2 = mSizePer * mNumPartitions; - mN = mN2 / mScaler; - - mEncoder.config(numOTs, numOTs * mScaler, w, a, true); } diff --git a/libOTe/TwoChooseOne/ConfigureCode.h b/libOTe/TwoChooseOne/ConfigureCode.h index 4af551da..720e75d5 100644 --- a/libOTe/TwoChooseOne/ConfigureCode.h +++ b/libOTe/TwoChooseOne/ConfigureCode.h @@ -10,11 +10,7 @@ namespace osuCrypto { // https://eprint.iacr.org/2019/1159.pdf QuasiCyclic = 1, -#ifdef ENABLE_INSECURE_SILVER - // https://eprint.iacr.org/2021/1150, see https://eprint.iacr.org/2023/882 for attack. - slv5 = 2, - slv11 = 3, -#endif + // https://eprint.iacr.org/2022/1014 ExAcc7 = 4, // fast ExAcc11 = 5,// fast but more conservative @@ -33,14 +29,7 @@ namespace osuCrypto case osuCrypto::MultType::QuasiCyclic: o << "QuasiCyclic"; break; -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - o << "slv5"; - break; - case osuCrypto::MultType::slv11: - o << "slv11"; - break; -#endif + case osuCrypto::MultType::ExAcc7: o << "ExAcc7"; break; @@ -106,32 +95,14 @@ namespace osuCrypto ); - class ExConvCode2; void ExConvConfigure( - u64 numOTs, u64 secParam, + double scaler, MultType mMultType, - u64& mRequestedNumOTs, - u64& mNumPartitions, - u64& mSizePer, - u64& mN2, - u64& mN, - ExConvCode2& mEncoder + u64& expanderWeight, + u64& accumulatorWeight, + double& minDist ); -#ifdef ENABLE_INSECURE_SILVER - struct SilverEncoder; - void SilverConfigure( - u64 numOTs, u64 secParam, - MultType mMultType, - u64& mRequestedNumOTs, - u64& mNumPartitions, - u64& mSizePer, - u64& mN2, - u64& mN, - u64& gap, - SilverEncoder& mEncoder); -#endif - void QuasiCyclicConfigure( u64 numOTs, u64 secParam, u64 scaler, @@ -143,4 +114,15 @@ namespace osuCrypto u64& mN, u64& mP, u64& mScaler); + + + inline void QuasiCyclicConfigure( + double mScaler, + double& minDist) + { + if (mScaler == 2) + minDist = 0.2; // psuedo min dist + else + throw RTE_LOC; // not impl + } } \ No newline at end of file diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp index 24b890e3..321d886e 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp @@ -3,7 +3,6 @@ #include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" #include -#include #include #include diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h index ec43ff6c..0d1334ce 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include "libOTe/Tools/EACode/EACode.h" diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h index 36247933..9ad0c7b4 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h @@ -19,7 +19,6 @@ #include #include #include -#include #include #include "libOTe/Tools/EACode/EACode.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" @@ -119,10 +118,6 @@ namespace osuCrypto // The flag which controls whether the malicious check is performed. SilentSecType mMalType = SilentSecType::SemiHonest; - // The Silver encoder for MultType::slv5, MultType::slv11 -#ifdef ENABLE_INSECURE_SILVER - SilverEncoder mEncoder; -#endif ExConvCode mExConvEncoder; EACode mEAEncoder; diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.cpp b/libOTe/Vole/Silent/SilentVoleReceiver.cpp index 8d1972b8..ecd25b4e 100644 --- a/libOTe/Vole/Silent/SilentVoleReceiver.cpp +++ b/libOTe/Vole/Silent/SilentVoleReceiver.cpp @@ -4,7 +4,6 @@ #include "libOTe/Vole/Silent/SilentVoleSender.h" #include "libOTe/Vole/Noisy/NoisyVoleReceiver.h" #include -#include #include #include diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.h b/libOTe/Vole/Silent/SilentVoleReceiver.h index be968a6f..91e4c110 100644 --- a/libOTe/Vole/Silent/SilentVoleReceiver.h +++ b/libOTe/Vole/Silent/SilentVoleReceiver.h @@ -18,7 +18,6 @@ #include #include #include -#include #include #include #include diff --git a/libOTe/Vole/Silent/SilentVoleSender.cpp b/libOTe/Vole/Silent/SilentVoleSender.cpp index a78d6736..fa79e962 100644 --- a/libOTe/Vole/Silent/SilentVoleSender.cpp +++ b/libOTe/Vole/Silent/SilentVoleSender.cpp @@ -9,7 +9,6 @@ #include "libOTe/Tools/Tools.h" #include "cryptoTools/Common/Log.h" #include "cryptoTools/Crypto/RandomOracle.h" -#include "libOTe/Tools/LDPC/LdpcSampler.h" namespace osuCrypto diff --git a/libOTe/Vole/Silent/SilentVoleSender.h b/libOTe/Vole/Silent/SilentVoleSender.h index 99610316..ebd38289 100644 --- a/libOTe/Vole/Silent/SilentVoleSender.h +++ b/libOTe/Vole/Silent/SilentVoleSender.h @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -64,9 +63,6 @@ namespace osuCrypto #endif MultType mMultType = DefaultMultType; -#ifdef ENABLE_INSECURE_SILVER - SilverEncoder mEncoder; -#endif ExConvCode mExConvEncoder; EACode mEAEncoder; diff --git a/libOTe/Vole/Subfield/SilentVoleReceiver.h b/libOTe/Vole/Subfield/SilentVoleReceiver.h index fea4d61c..a9fbc31e 100644 --- a/libOTe/Vole/Subfield/SilentVoleReceiver.h +++ b/libOTe/Vole/Subfield/SilentVoleReceiver.h @@ -18,27 +18,31 @@ #include #include #include -#include #include #include #include #include #include - -namespace osuCrypto::Subfield +#include +#include "libOTe/Tools/QuasiCyclicCode.h" +namespace osuCrypto { template< typename F, typename G = F, - typename CoeffCtx = DefaultCoeffCtx + typename Ctx = DefaultCoeffCtx > class SilentSubfieldVoleReceiver : public TimerAdapter { public: static constexpr u64 mScaler = 2; + static constexpr bool MaliciousSupported = + std::is_same_v&& + std::is_same_v; + enum class State { Default, @@ -46,25 +50,29 @@ namespace osuCrypto::Subfield HasBase }; - using VecF = typename CoeffCtx::template Vec; - using VecG = typename CoeffCtx::template Vec; + using VecF = typename Ctx::template Vec; + using VecG = typename Ctx::template Vec; // The current state of the protocol State mState = State::Default; - // The number of OTs the user requested. - u64 mRequestedNumOTs = 0; + // the context used to perform F, G operations + Ctx mCtx; + + // The number of correlations the user requested. + u64 mRequestSize = 0; - // The number of OTs actually produced (at least the number requested). - u64 mN = 0; + // the LPN security parameter + u64 mSecParam = 0; // The length of the noisy vectors (2 * mN for the most codes). - u64 mN2 = 0; + u64 mNoiseVecSize = 0; // We perform regular LPN, so this is the // size of the each chunk. u64 mSizePer = 0; + // the number of noisy positions u64 mNumPartitions = 0; // The noisy coordinates. @@ -77,17 +85,15 @@ namespace osuCrypto::Subfield // the sparse vector. MultType mMultType = DefaultMultType; - ExConvCode2 mExConvEncoder; - // The multi-point punctured PRF for generating // the sparse vectors. - SilentSubfieldPprfReceiver mGen; + SilentSubfieldPprfReceiver mGen; // The internal buffers for holding the expanded vectors. - // mA + mB = mC * delta + // mA = mB + mC * delta VecF mA; - // mA + mB = mC * delta + // mA = mB + mC * delta VecG mC; u64 mNumThreads = 1; @@ -98,10 +104,11 @@ namespace osuCrypto::Subfield SilentSecType mMalType = SilentSecType::SemiHonest; - block mMalCheckSeed, mMalCheckX, mDeltaShare; + block mMalCheckSeed, mMalCheckX, mMalBaseA; - VecF mNoiseDeltaShare; - VecG mNoiseValues; + // we + VecF mBaseA; + VecG mBaseC; #ifdef ENABLE_SOFTSPOKEN_OT @@ -150,8 +157,8 @@ namespace osuCrypto::Subfield chl2 = Socket{}, prng2 = std::move(PRNG{}), noiseVals = VecG{}, - noiseDeltaShares = VecF{}, - nv = NoisySubfieldVoleReceiver{} + baseAs = VecF{}, + nv = NoisySubfieldVoleReceiver{} ); @@ -177,7 +184,7 @@ namespace osuCrypto::Subfield // other party will program the PPRF to output their share of delta * noiseVals. // noiseVals = sampleBaseVoleVals(prng); - CoeffCtx::resize(noiseDeltaShares, noiseVals.size()); + Ctx::resize(baseAs, noiseVals.size()); if (mTimer) nv.setTimer(*mTimer); @@ -202,7 +209,7 @@ namespace osuCrypto::Subfield bb); msg.resize(msg.size() - mOtExtSender.baseOtCount()); - MC_AWAIT(nv.receive(noiseVals, noiseDeltaShares, prng, mOtExtSender, chl)); + MC_AWAIT(nv.receive(noiseVals, baseAs, prng, mOtExtSender, chl)); } else { @@ -212,7 +219,7 @@ namespace osuCrypto::Subfield MC_AWAIT( macoro::when_all_ready( - nv.receive(noiseVals, noiseDeltaShares, prng2, mOtExtSender, chl2), + nv.receive(noiseVals, baseAs, prng2, mOtExtSender, chl2), mOtExtRecver.receive(choice, msg, prng, chl) )); } @@ -225,10 +232,10 @@ namespace osuCrypto::Subfield chl2 = chl.fork(); prng2.SetSeed(prng.get()); MC_AWAIT(baseOt.receive(choice, msg, prng, chl)); - MC_AWAIT(nv.receive(noiseVals, noiseDeltaShares, prng2, baseOt, chl2)); + MC_AWAIT(nv.receive(noiseVals, baseAs, prng2, baseOt, chl2)); } - setSilentBaseOts(msg, noiseDeltaShares); + setSilentBaseOts(msg, baseAs); setTimePoint("SilentVoleReceiver.genSilent.done"); MC_END(); }; @@ -238,25 +245,38 @@ namespace osuCrypto::Subfield // will be needed. These can then be ganerated for // a different OT extension or using a base OT protocol. void configure( - u64 numOTs, + u64 requestSize, SilentBaseType type = SilentBaseType::BaseExtend, - u64 secParam = 128) + u64 secParam = 128, + Ctx ctx = {}) { + mCtx = std::move(ctx); + mSecParam = secParam; + mRequestSize = requestSize; mState = State::Configured; mBaseType = type; - + double minDist = 0; switch (mMultType) { case osuCrypto::MultType::ExConv7x24: case osuCrypto::MultType::ExConv21x24: - - ExConvConfigure(numOTs, 128, mMultType, mRequestedNumOTs, mNumPartitions, mSizePer, mN2, mN, mExConvEncoder); + { + u64 _1, _2; + ExConvConfigure(mScaler, mMultType, _1, _2, minDist); + break; + } + case MultType::QuasiCyclic: + QuasiCyclicConfigure(mScaler, minDist); break; default: throw RTE_LOC; break; } + mNumPartitions = getRegNoiseWeight(minDist, secParam); + mSizePer = std::max(4, roundUpTo(divCeil(mRequestSize * mScaler, mNumPartitions), 2)); + mNoiseVecSize = mSizePer * mNumPartitions; + mGen.configure(mSizePer, mNumPartitions); } @@ -283,7 +303,7 @@ namespace osuCrypto::Subfield if (isConfigured() == false) throw std::runtime_error("configure(...) must be called first"); - auto choice = mGen.sampleChoiceBits(mN2, getPprfFormat(), prng); + auto choice = mGen.sampleChoiceBits(prng); return choice; } @@ -294,59 +314,45 @@ namespace osuCrypto::Subfield throw RTE_LOC; // sample the values of the noisy coordinate of c - // and perform a noicy vole to get x+y = mD * c - auto w = mNumPartitions; - std::vector seeds(w); - CoeffCtx::resize(mNoiseValues, w); - prng.get(seeds.data(), seeds.size()); - for (size_t i = 0; i < w; i++) { - CoeffCtx::fromBlock(mNoiseValues[i], seeds[i]); - } + // and perform a noicy vole to get a = b + mD * c + + Ctx::resize(mBaseC, mNumPartitions + (mMalType == SilentSecType::Malicious)); + for (size_t i = 0; i < mNumPartitions; i++) + Ctx::fromBlock(mBaseC[i], prng.get()); mS.resize(mNumPartitions); - mGen.getPoints(mS, getPprfFormat()); + mGen.getPoints(mS, PprfOutputFormat::Interleaved); - // if (mMalType == SilentSecType::Malicious) - // { - // - // mMalCheckSeed = prng.get(); - // mMalCheckX = ZeroBlock; - // auto yIter = mNoiseValues.begin(); - // - // for (u64 i = 0; i < mNumPartitions; ++i) - // { - // auto s = mS[i]; - // auto xs = mMalCheckSeed.gf128Pow(s + 1); - // mMalCheckX = mMalCheckX ^ xs.gf128Mul(*yIter); - // ++yIter; - // } - // - // auto sIter = mS.begin() + mNumPartitions; - // for (u64 i = 0; i < mGapBaseChoice.size(); ++i) - // { - // if (mGapBaseChoice[i]) - // { - // auto s = *sIter; - // auto xs = mMalCheckSeed.gf128Pow(s + 1); - // mMalCheckX = mMalCheckX ^ xs.gf128Mul(*yIter); - // ++sIter; - // } - // ++yIter; - // } - // - // - // std::vector y(mNoiseValues.begin(), mNoiseValues.end()); - // y.push_back(mMalCheckX); - // return y; - // } + if (mMalType == SilentSecType::Malicious) + { + if constexpr (MaliciousSupported) + { + mMalCheckSeed = prng.get(); - return mNoiseValues; + auto yIter = mBaseC.begin(); + mCtx.zero(mBaseC.end() - 1, mBaseC.end()); + for (u64 i = 0; i < mNumPartitions; ++i) + { + auto s = mS[i]; + auto xs = mMalCheckSeed.gf128Pow(s + 1); + mBaseC[mNumPartitions] = mBaseC[mNumPartitions] ^ xs.gf128Mul(*yIter); + ++yIter; + } + } + else + { + throw std::runtime_error("malicious is currently only supported for GF128 block. " LOCATION); + } + } + + return mBaseC; } // Set the externally generated base OTs. This choice // bits must be the one return by sampleBaseChoiceBits(...). - void setSilentBaseOts(span recvBaseOts, - span noiseDeltaShare) + void setSilentBaseOts( + span recvBaseOts, + VecF& baseA) { if (isConfigured() == false) throw std::runtime_error("configure(...) must be called first."); @@ -356,15 +362,8 @@ namespace osuCrypto::Subfield mGen.setBase(recvBaseOts); - // if (mMalType == SilentSecType::Malicious) - // { - // mDeltaShare = noiseDeltaShare.back(); - // noiseDeltaShare = noiseDeltaShare.subspan(0, noiseDeltaShare.size() - 1); - // } - - CoeffCtx::resize(mNoiseDeltaShare, noiseDeltaShare.size()); - CoeffCtx::copy(noiseDeltaShare.begin(), noiseDeltaShare.end(), mNoiseDeltaShare.begin()); - + Ctx::resize(mBaseA, baseA.size()); + Ctx::copy(baseA.begin(), baseA.end(), mBaseA.begin()); mState = State::HasBase; } @@ -373,19 +372,19 @@ namespace osuCrypto::Subfield // this function is non-interactive. Otherwise // the silent base OTs will automatically be performed. task<> silentReceive( - span c, - span a, + VecG& c, + VecF& a, PRNG& prng, Socket& chl) { - MC_BEGIN(task<>, this, c, a, &prng, &chl); + MC_BEGIN(task<>, this, &c, &a, &prng, &chl); if (c.size() != a.size()) throw RTE_LOC; MC_AWAIT(silentReceiveInplace(c.size(), prng, chl)); - CoeffCtx::copy(mC.begin(), mC.begin() + c.size(), c.begin()); - CoeffCtx::copy(mA.begin(), mA.begin() + a.size(), a.begin()); + Ctx::copy(mC.begin(), mC.begin() + c.size(), c.begin()); + Ctx::copy(mA.begin(), mA.begin() + a.size(), a.begin()); clear(); MC_END(); @@ -410,10 +409,9 @@ namespace osuCrypto::Subfield { // first generate 128 normal base OTs configure(n, SilentBaseType::BaseExtend); - // configure(n, SilentBaseType::Base); } - if (mRequestedNumOTs != n) + if (mRequestSize != n) throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); if (hasSilentBaseOts() == false) @@ -422,20 +420,38 @@ namespace osuCrypto::Subfield } // allocate mA - CoeffCtx::resize(mA, 0); - CoeffCtx::resize(mA, mN2); + Ctx::resize(mA, 0); + Ctx::resize(mA, mNoiseVecSize); setTimePoint("SilentVoleReceiver.alloc"); // allocate the space for mC - CoeffCtx::resize(mC, 0); - CoeffCtx::resize(mC, mN2); - CoeffCtx::zero(mC.begin(), mC.end()); + Ctx::resize(mC, 0); + Ctx::resize(mC, mNoiseVecSize); + Ctx::zero(mC.begin(), mC.end()); setTimePoint("SilentVoleReceiver.alloc.zero"); if (mTimer) mGen.setTimer(*mTimer); - // expand the seeds into mA + + // As part of the setup, we have generated + // + // mBaseA + mBaseB = mBaseC * mDelta + // + // We have mBaseA, mBaseC, + // they have mBaseB, mDelta + // This was done with a small (noisy) vole. + // + // We use the Pprf to expand as + // + // mA' = mB + mS(mBaseB) + // = mB + mS(mBaseC * mDelta - mBaseA) + // = mB + mS(mBaseC * mDelta) - mS(mBaseA) + // + // Therefore if we add mS(mBaseA) to mA' we will get + // + // mA = mB + mS(mBaseC * mDelta) + // MC_AWAIT(mGen.expand(chl, mA, PprfOutputFormat::Interleaved, true, mNumThreads)); setTimePoint("SilentVoleReceiver.expand.pprf_transpose"); @@ -445,8 +461,8 @@ namespace osuCrypto::Subfield for (u64 i = 0; i < mNumPartitions; ++i) { auto pnt = mS[i]; - CoeffCtx::copy(mC[pnt], mNoiseValues[i]); - CoeffCtx::minus(mA[pnt], mA[pnt], mNoiseDeltaShare[i]); + Ctx::copy(mC[pnt], mBaseC[i]); + Ctx::plus(mA[pnt], mA[pnt], mBaseA[i]); } if (mDebug) @@ -456,43 +472,68 @@ namespace osuCrypto::Subfield } - // if (mMalType == SilentSecType::Malicious) - // { - // MC_AWAIT(chl.send(std::move(mMalCheckSeed))); - // - // myHash = ferretMalCheck(mDeltaShare, mNoiseValues); - // - // MC_AWAIT(chl.recv(theirHash)); - // - // if (theirHash != myHash) - // throw RTE_LOC; - // } + if (mMalType == SilentSecType::Malicious) + { + MC_AWAIT(chl.send(std::move(mMalCheckSeed))); + + if constexpr (MaliciousSupported) + myHash = ferretMalCheck(); + else + throw std::runtime_error("malicious is currently only supported for GF128 block. " LOCATION); + + MC_AWAIT(chl.recv(theirHash)); + + if (theirHash != myHash) + throw RTE_LOC; + } switch (mMultType) { case osuCrypto::MultType::ExConv7x24: case osuCrypto::MultType::ExConv21x24: - if (mTimer) { - mExConvEncoder.setTimer(getTimer()); - } + { + u64 expanderWeight, accumulatorWeight; + double _; + ExConvConfigure(mScaler, mMultType, expanderWeight, accumulatorWeight, _); + ExConvCode2 encoder; + encoder.config(mRequestSize, mNoiseVecSize, expanderWeight, accumulatorWeight); - mExConvEncoder.dualEncode2( + if (mTimer) + encoder.setTimer(getTimer()); + + encoder.dualEncode2( mA.begin(), mC.begin() ); - break; + } + case osuCrypto::MultType::QuasiCyclic: + { + if constexpr ( + std::is_same_v && + std::is_same_v && + std::is_same_v) + { + QuasiCyclicCode encoder; + encoder.init2(mRequestSize, mNoiseVecSize); + encoder.dualEncode(mA); + encoder.dualEncode(mC); + } + else + throw std::runtime_error("QuasiCyclic is only supported for GF128, i.e. block. " LOCATION); + break; + } default: - throw RTE_LOC; + throw std::runtime_error("Code is not supported. " LOCATION); break; } // resize the buffers down to only contain the real elements. - CoeffCtx::resize(mA, mRequestedNumOTs); - CoeffCtx::resize(mC, mRequestedNumOTs); + Ctx::resize(mA, mRequestSize); + Ctx::resize(mC, mRequestSize); - mNoiseValues = {}; - mNoiseDeltaShare = {}; + mBaseC = {}; + mBaseA = {}; // make the protocol as done and that // mA,mC are ready to be consumed. @@ -507,115 +548,136 @@ namespace osuCrypto::Subfield task<> checkRT(Socket& chl) const { MC_BEGIN(task<>, this, &chl, - B = typename CoeffCtx::Vec{}, - sparseNoiseDelta = typename CoeffCtx::Vec{}, - noiseDeltaShare2 = typename CoeffCtx::Vec{}, - delta = typename CoeffCtx::Vec{}, - tempF = typename CoeffCtx::Vec{}, - tempG = typename CoeffCtx::Vec{}, + B = typename Ctx::Vec{}, + sparseNoiseDelta = typename Ctx::Vec{}, + baseB = typename Ctx::Vec{}, + delta = typename Ctx::Vec{}, + tempF = typename Ctx::Vec{}, + tempG = typename Ctx::Vec{}, buffer = std::vector{} ); // recv delta - buffer.resize(CoeffCtx::byteSize()); + buffer.resize(Ctx::byteSize()); + Ctx::resize(delta, 1); MC_AWAIT(chl.recv(buffer)); - CoeffCtx::deserialize(buffer.begin(), buffer.end(), delta.begin()); + Ctx::deserialize(buffer.begin(), buffer.end(), delta.begin()); // recv B - buffer.resize(CoeffCtx::byteSize() * mA.size()); + buffer.resize(Ctx::byteSize() * mA.size()); + Ctx::resize(B, mA.size()); MC_AWAIT(chl.recv(buffer)); - CoeffCtx::deserialize(buffer.begin(), buffer.end(), B.begin()); + Ctx::deserialize(buffer.begin(), buffer.end(), B.begin()); // recv the noisy values. - buffer.resize(CoeffCtx::byteSize() * mNoiseDeltaShare.size()); + buffer.resize(Ctx::byteSize() * mBaseA.size()); + Ctx::resize(baseB, mBaseA.size()); MC_AWAIT(chl.recvResize(buffer)); - CoeffCtx::deserialize(buffer.begin(), buffer.end(), noiseDeltaShare2.begin()); + Ctx::deserialize(buffer.begin(), buffer.end(), baseB.begin()); - //check that at locations mS[0],...,mS[..] - // that we hold a sharing mA, mB of - // - // delta * mC = delta * (00000 noiseDeltaShare2[0] 0000 .... 0000 noiseDeltaShare2[m] 0000) - // - // where noiseDeltaShare2[i] is at position mS[i] of mC - // - // That is, I hold mA, mC s.t. + // it shoudl hold that + // + // mBaseA = baseB + mBaseC * mDelta // + // and + // // mA = mB + mC * mDelta // - - CoeffCtx::resize(tempF, 2); - CoeffCtx::resize(tempG, 1); - CoeffCtx::zero(tempG.begin(), tempG.end()); - - for (auto i : rng(mNoiseDeltaShare.size())) { - // temp[0] = mNoiseDeltaShare[i] + noiseDeltaShare2[i] - CoeffCtx::plus(tempF[0], mNoiseDeltaShare[i], noiseDeltaShare2[i]); + bool verbose = false; + bool failed = false; + std::vector index(mS.size()); + std::iota(index.begin(), index.end(), 0); + std::sort(index.begin(), index.end(), + [&](std::size_t i, std::size_t j) { return mS[i] < mS[j]; }); + + Ctx::resize(tempF, 2); + Ctx::resize(tempG, 1); + Ctx::zero(tempG.begin(), tempG.end()); + + + // check the correlation that + // + // mBaseA + mBaseB = mBaseC * mDelta + for (auto i : rng(mBaseA.size())) + { + // temp[0] = baseB[i] + mBaseA[i] + Ctx::plus(tempF[0], baseB[i], mBaseA[i]); - // temp[1] = mNoiseValues[i] * delta[0] - CoeffCtx::mul(tempF[1], delta[0], mNoiseValues[i]); + // temp[1] = mBaseC[i] * delta[0] + Ctx::mul(tempF[1], delta[0], mBaseC[i]); - if (!CoeffCtx::eq(tempF[0], tempF[1])) - throw RTE_LOC; - } + if (!Ctx::eq(tempF[0], tempF[1])) + throw RTE_LOC; - { - - for (auto i : rng(mNumPartitions* mSizePer)) - { - auto iter = std::find(mS.begin(), mS.end(), i); - if (iter != mS.end()) + if (i < mNumPartitions) { - auto d = iter - mS.begin(); + //auto idx = index[i]; + auto point = mS[i]; + if (!Ctx::eq(mBaseC[i], mC[point])) + throw RTE_LOC; - if (!CoeffCtx::eq(mC[i], mNoiseValues[d])) + if (i && mS[index[i - 1]] >= mS[index[i]]) throw RTE_LOC; + } + } - // temp[0] = A[i] + B[i] - CoeffCtx::plus(tempF[0], mA[i], B[i]); - // temp[1] = mNoiseValues[d] * delta[0] - CoeffCtx::mul(tempF[1], delta[0], mNoiseValues[d]); + auto iIter = index.begin(); + auto leafIdx = mS[*iIter]; + F act = tempF[0]; + G zero = tempG[0]; + Ctx::zero(tempG.begin(), tempG.end()); + for (u64 j = 0; j < mA.size(); ++j) + { + Ctx::mul(act, delta[0], mC[j]); + Ctx::plus(act, act, B[j]); - if (!CoeffCtx::eq(tempF[0], tempF[1])) - { - std::cout << "bad vole base noisy correlation, mA[i] + mB[i] != mC[i] * delta" << std::endl; - std::cout << "i " << i << std::endl; - //std::cout << "mA[i] " << mA[i] << std::endl; - //std::cout << "mB[i] " << B[i] << std::endl; - //std::cout << "mC[i] " << mC[i] << std::endl; - //std::cout << "delta " << delta << std::endl; - //std::cout << "mA[i] + mB[i] " << (mA[i] ^ B[i]) << std::endl; - //std::cout << "mC[i] * delta " << (mC[i].gf128Mul(delta)) << std::endl; + bool active = false; + if (j == leafIdx) + { + active = true; + } + else if (!Ctx::eq(zero, mC[j])) + throw RTE_LOC; - throw RTE_LOC; - } + if (mA[j] != act) + { + failed = true; + if (verbose) + std::cout << Color::Red; } - else + + if (verbose) { - if (!CoeffCtx::eq(mA[i], B[i])) - { - std::cout << "bad vole base non-noisy correlation, mA[i] + mB[i] != 0" << std::endl; - //std::cout << mA[i] << " " << B[i] << std::endl; - throw RTE_LOC; - } + std::cout << j << " act " << Ctx::str(act) + << " a " << Ctx::str(mA[j]) << " b " << Ctx::str(B[j]); - if (!CoeffCtx::eq(mC[i], tempG[0])) + if (active) + std::cout << " < " << Ctx::str(delta[0]); + + std::cout << std::endl << Color::Default; + } + + if (j == leafIdx) + { + ++iIter; + if (iIter != index.end()) { - std::cout << "bad vole base non-noisy correlation, mC[i] != 0" << std::endl; - throw RTE_LOC; + leafIdx = mS[*iIter]; } } } + + if (failed) + throw RTE_LOC; } MC_END(); } - std::array ferretMalCheck( - block deltaShare, - span y) + std::array ferretMalCheck() { block xx = mMalCheckSeed; @@ -634,20 +696,17 @@ namespace osuCrypto::Subfield // xx = mMalCheckSeed^{i+1} xx = xx.gf128Mul(mMalCheckSeed); } + + // = < block mySum = sum0.gf128Reduce(sum1); std::array myHash; RandomOracle ro(32); - ro.Update(mySum ^ deltaShare); + ro.Update(mySum ^ mBaseA.back()); ro.Final(myHash); return myHash; } - PprfOutputFormat getPprfFormat() - { - return PprfOutputFormat::Interleaved; - } - void clear() { mS = {}; diff --git a/libOTe/Vole/Subfield/SilentVoleSender.h b/libOTe/Vole/Subfield/SilentVoleSender.h index bbe38f97..e3185477 100644 --- a/libOTe/Vole/Subfield/SilentVoleSender.h +++ b/libOTe/Vole/Subfield/SilentVoleSender.h @@ -30,13 +30,17 @@ namespace osuCrypto template< typename F, typename G = F, - typename CoeffCtx = DefaultCoeffCtx + typename Ctx = DefaultCoeffCtx > class SilentSubfieldVoleSender : public TimerAdapter { public: static constexpr u64 mScaler = 2; + static constexpr bool MaliciousSupported = + std::is_same_v&& + std::is_same_v; + enum class State { Default, @@ -44,37 +48,69 @@ namespace osuCrypto HasBase }; - using VecF = typename CoeffCtx::template Vec; - using VecG = typename CoeffCtx::template Vec; + using VecF = typename Ctx::template Vec; + using VecG = typename Ctx::template Vec; State mState = State::Default; - SilentSubfieldPprfSender mGen; + // the context used to perform F, G operations + Ctx mCtx; + + // the pprf used to generate the noise vector. + SilentSubfieldPprfSender mGen; + + // the number of correlations requested. + u64 mRequestSize = 0; - u64 mRequestedNumOTs = 0; - u64 mN2 = 0; - u64 mN = 0; + // the length of the noisy vector. + u64 mNoiseVecSize = 0; + + // the weight of the nosy vector u64 mNumPartitions = 0; + + // the size of each regular, weight 1, subvector + // of the noisy vector. mNoiseVecSize = mNumPartions * mSizePer u64 mSizePer = 0; - u64 mNumThreads = 1; - SilentBaseType mBaseType; - VecF mNoiseDeltaShares; + // the lpn security parameters + u64 mSecParam = 0; + + // the type of base OT OT that should be performed. + // Base requires more work but less communication. + SilentBaseType mBaseType = SilentBaseType::BaseExtend; + + // the base Vole correlation. To generate the silent vole, + // we must first create a small vole + // mBaseA + mBaseB = mBaseC * mDelta. + // These will be used to initialize the non-zeros of the noisy + // vector. mBaseB is the b in this corrlations. + VecF mBaseB; + + // the full sized noisy vector. This will initalially be + // sparse with the corrlations + // mA = mB + mC * mDelta + // before it is compressed. + VecF mB; + + // determines if the malicious checks are performed. SilentSecType mMalType = SilentSecType::SemiHonest; + // A flag to specify the linear code to use + MultType mMultType = DefaultMultType; + + + block mDeltaShare; + #ifdef ENABLE_SOFTSPOKEN_OT SoftSpokenMalOtSender mOtExtSender; SoftSpokenMalOtReceiver mOtExtRecver; #endif - MultType mMultType = DefaultMultType; - - ExConvCode2 mExConvEncoder; - VecF mB; - u64 baseVoleCount() const { + u64 baseVoleCount() const + { return mNumPartitions + 1 * (mMalType == SilentSecType::Malicious); } @@ -82,32 +118,28 @@ namespace osuCrypto // base OTs are set then we do an IKNP extend, // otherwise we perform a base OT protocol to // generate the needed OTs. - task<> genSilentBaseOts(PRNG& prng, Socket& chl, cp::optional delta = {}) + task<> genSilentBaseOts(PRNG& prng, Socket& chl, F delta) { using BaseOT = DefaultBaseOT; - MC_BEGIN(task<>, this, delta, &prng, &chl, msg = AlignedUnVector>(silentBaseOtCount()), baseOt = BaseOT{}, prng2 = std::move(PRNG{}), xx = BitVector{}, chl2 = Socket{}, - nv = NoisySubfieldVoleSender{}, - noiseDeltaShares = std::vector{} + nv = NoisySubfieldVoleSender{}, + b = VecF{} ); setTimePoint("SilentVoleSender.genSilent.begin"); if (isConfigured() == false) throw std::runtime_error("configure must be called first"); - if(!delta) - CoeffCtx::fromBlock(*delta, prng.get()); - - xx = CoeffCtx::binaryDecomposition(*delta); + xx = mCtx.binaryDecomposition(delta); // compute the correlation for the noisy coordinates. - noiseDeltaShares.resize(baseVoleCount()); + b.resize(baseVoleCount()); if (mBaseType == SilentBaseType::BaseExtend) @@ -125,7 +157,7 @@ namespace osuCrypto mOtExtRecver.baseOtCount())); msg.resize(msg.size() - mOtExtRecver.baseOtCount()); - MC_AWAIT(nv.send(*delta, noiseDeltaShares, prng, mOtExtRecver, chl)); + MC_AWAIT(nv.send(delta, b, prng, mOtExtRecver, chl)); } else { @@ -134,7 +166,7 @@ namespace osuCrypto MC_AWAIT( macoro::when_all_ready( - nv.send(*delta, noiseDeltaShares, prng2, mOtExtRecver, chl2), + nv.send(delta, b, prng2, mOtExtRecver, chl2), mOtExtSender.send(msg, prng, chl))); } #else @@ -145,16 +177,16 @@ namespace osuCrypto { chl2 = chl.fork(); prng2.SetSeed(prng.get()); - MC_AWAIT(baseOt.send(msg, prng, chl)); - MC_AWAIT(nv.send(*delta, noiseDeltaShares, prng2, baseOt, chl2)); - // MC_AWAIT( - // macoro::when_all_ready( - // nv.send(*delta, noiseDeltaShares, prng2, baseOt, chl2), - // baseOt.send(msg, prng, chl))); + //MC_AWAIT(baseOt.send(msg, prng, chl)); + //MC_AWAIT(nv.send(delta, b, prng2, baseOt, chl2)); + MC_AWAIT( + macoro::when_all_ready( + nv.send(delta, b, prng2, baseOt, chl2), + baseOt.send(msg, prng, chl))); } - setSilentBaseOts(msg, noiseDeltaShares); + setSilentBaseOts(msg, b); setTimePoint("SilentVoleSender.genSilent.done"); MC_END(); } @@ -164,27 +196,40 @@ namespace osuCrypto // will be needed. These can then be ganerated for // a different OT extension or using a base OT protocol. void configure( - u64 numOTs, + u64 requestSize, SilentBaseType type = SilentBaseType::BaseExtend, - u64 secParam = 128) + u64 secParam = 128, + Ctx ctx = {}) { + mCtx = std::move(ctx); + mSecParam = secParam; + mRequestSize = requestSize; + mState = State::Configured; mBaseType = type; + double minDist = 0; switch (mMultType) { case osuCrypto::MultType::ExConv7x24: case osuCrypto::MultType::ExConv21x24: - - ExConvConfigure(numOTs, 128, mMultType, mRequestedNumOTs, mNumPartitions, mSizePer, mN2, mN, mExConvEncoder); + { + u64 _1, _2; + ExConvConfigure(mScaler, mMultType, _1, _2, minDist); + break; + } + case MultType::QuasiCyclic: + QuasiCyclicConfigure(mScaler, minDist); break; default: throw RTE_LOC; break; } - mGen.configure(mSizePer, mNumPartitions); + mNumPartitions = getRegNoiseWeight(minDist, secParam); + mSizePer = std::max(4, roundUpTo(divCeil(mRequestSize * mScaler, mNumPartitions), 2)); + mNoiseVecSize = mSizePer * mNumPartitions; - mState = State::Configured; + mGen.configure(mSizePer, mNumPartitions); } // return true if this instance has been configured. @@ -204,17 +249,30 @@ namespace osuCrypto // bits must be the one return by sampleBaseChoiceBits(...). void setSilentBaseOts( span> sendBaseOts, - span noiseDeltaShares) + const VecF& b) { if ((u64)sendBaseOts.size() != silentBaseOtCount()) throw RTE_LOC; - if (noiseDeltaShares.size() != baseVoleCount()) + if (b.size() != baseVoleCount()) throw RTE_LOC; mGen.setBase(sendBaseOts); - mNoiseDeltaShares.resize(noiseDeltaShares.size()); - std::copy(noiseDeltaShares.begin(), noiseDeltaShares.end(), mNoiseDeltaShares.begin()); + + // we store the negative of b. This is because + // we need the correlation + // + // mBaseA + mBaseB = mBaseC * delta + // + // for the pprf to expand correctly but the + // input correlation is a vole: + // + // mBaseA = b + mBaseC * delta + // + mCtx.resize(mBaseB, b.size()); + mCtx.zero(mBaseB.begin(), mBaseB.end()); + for (u64 i = 0; i < mBaseB.size(); ++i) + mCtx.minus(mBaseB[i], mBaseB[i], b[i]); } // The native OT extension interface of silent @@ -224,16 +282,16 @@ namespace osuCrypto // send(...) interface for the normal behavior. task<> silentSend( F delta, - span b, + VecF& b, PRNG& prng, Socket& chl) { - MC_BEGIN(task<>, this, delta, b, &prng, &chl); + MC_BEGIN(task<>, this, delta, &b, &prng, &chl); MC_AWAIT(silentSendInplace(delta, b.size(), prng, chl)); - CoeffCtx::copy(mB.begin(), mB.begin() + b.size(), b.begin()); - //std::memcpy(b.data(), mB.data(), b.size() * CoeffCtx::bytesF); + mCtx.copy(mB.begin(), mB.begin() + b.size(), b.begin()); + //std::memcpy(b.data(), mB.data(), b.size() * mCtx.bytesF); clear(); setTimePoint("SilentVoleSender.expand.ldpc.msgCpy"); @@ -254,7 +312,8 @@ namespace osuCrypto MC_BEGIN(task<>, this, delta, n, &prng, &chl, deltaShare = block{}, X = block{}, - hash = std::array{} + hash = std::array{}, + baseB = VecF{} ); setTimePoint("SilentVoleSender.ot.enter"); @@ -263,10 +322,9 @@ namespace osuCrypto { // first generate 128 normal base OTs configure(n, SilentBaseType::BaseExtend); - // configure(n, SilentBaseType::Base); } - if (mRequestedNumOTs != n) + if (mRequestSize != n) throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); if (mGen.hasBaseOts() == false) @@ -278,25 +336,24 @@ namespace osuCrypto setTimePoint("SilentVoleSender.start"); //gTimer.setTimePoint("SilentVoleSender.iknp.base2"); - //if (mMalType == SilentSecType::Malicious) - //{ - // deltaShare = mNoiseDeltaShares.back(); - // mNoiseDeltaShares.pop_back(); - //} - // allocate B - CoeffCtx::resize(mB, 0); - CoeffCtx::resize(mB, mN2); + mCtx.resize(mB, 0); + mCtx.resize(mB, mNoiseVecSize); if (mTimer) mGen.setTimer(*mTimer); + // extract just the first mNumPartitions value of mBaseB. + // the last is for the malicious check (if present). + mCtx.resize(baseB, mNumPartitions); + mCtx.copy(mBaseB.begin(), mBaseB.begin() + mNumPartitions, baseB.begin()); + // program the output the PPRF to be secret shares of // our secret share of delta * noiseVals. The receiver // can then manually add their shares of this to the // output of the PPRF at the correct locations. - MC_AWAIT(mGen.expand(chl, mNoiseDeltaShares, prng.get(), mB, - PprfOutputFormat::Interleaved, true, mNumThreads)); + MC_AWAIT(mGen.expand(chl, baseB, prng.get(), mB, + PprfOutputFormat::Interleaved, true, 1)); setTimePoint("SilentVoleSender.expand.pprf"); if (mDebug) @@ -305,32 +362,59 @@ namespace osuCrypto setTimePoint("SilentVoleSender.expand.checkRT"); } - //if (mMalType == SilentSecType::Malicious) - //{ - // MC_AWAIT(chl.recv(X)); - // hash = ferretMalCheck(X, deltaShare); - // MC_AWAIT(chl.send(std::move(hash))); - //} + if (mMalType == SilentSecType::Malicious) + { + MC_AWAIT(chl.recv(X)); + + if constexpr (MaliciousSupported) + hash = ferretMalCheck(X); + else + throw std::runtime_error("malicious is currently only supported for GF128 block. " LOCATION); + + MC_AWAIT(chl.send(std::move(hash))); + } switch (mMultType) { case osuCrypto::MultType::ExConv7x24: case osuCrypto::MultType::ExConv21x24: - if (mTimer) { - mExConvEncoder.setTimer(getTimer()); + { + ExConvCode2 encoder; + u64 expanderWeight, accumulatorWeight; + double _1; + ExConvConfigure(mScaler, mMultType, expanderWeight, accumulatorWeight, _1); + encoder.config(mRequestSize, mNoiseVecSize, expanderWeight, accumulatorWeight); + if (mTimer) + encoder.setTimer(getTimer()); + encoder.dualEncode(mB.begin()); + break; + } + case MultType::QuasiCyclic: + { + if constexpr ( + std::is_same_v && + std::is_same_v && + std::is_same_v) + { + QuasiCyclicCode encoder; + encoder.init2(mRequestSize, mNoiseVecSize); + encoder.dualEncode(mB); } - mExConvEncoder.dualEncode(mB.begin()); + else + throw std::runtime_error("QuasiCyclic is only supported for GF128, i.e. block. " LOCATION); + break; + } default: - throw RTE_LOC; + throw std::runtime_error("Code is not supported. " LOCATION); break; } - CoeffCtx::resize(mB, mRequestedNumOTs); + mCtx.resize(mB, mRequestSize); mState = State::Default; - mNoiseDeltaShares.clear(); + mBaseB.clear(); MC_END(); } @@ -342,11 +426,11 @@ namespace osuCrypto MC_BEGIN(task<>, this, &chl, delta); MC_AWAIT(chl.send(delta)); MC_AWAIT(chl.send(mB)); - MC_AWAIT(chl.send(mNoiseDeltaShares)); + MC_AWAIT(chl.send(mBaseB)); MC_END(); } - std::array ferretMalCheck(block X, block deltaShare) + std::array ferretMalCheck(block X) { auto xx = X; @@ -366,7 +450,7 @@ namespace osuCrypto std::array myHash; RandomOracle ro(32); - ro.Update(mySum ^ deltaShare); + ro.Update(mySum ^ mBaseB.back()); ro.Final(myHash); return myHash; diff --git a/libOTe_Tests/Pprf_Tests.cpp b/libOTe_Tests/Pprf_Tests.cpp index 0d4fd4e6..6d4a5a62 100644 --- a/libOTe_Tests/Pprf_Tests.cpp +++ b/libOTe_Tests/Pprf_Tests.cpp @@ -29,7 +29,7 @@ void Tools_Pprf_expandOne_test_impl(u64 domain, bool program) auto numOTs = sender.baseOtCount(); std::vector> sendOTs(numOTs); std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain * pntCount, format, prng); + BitVector recvBits = recver.sampleChoiceBits(prng); prng.get(sendOTs.data(), sendOTs.size()); @@ -66,13 +66,14 @@ void Tools_Pprf_expandOne_test_impl(u64 domain, bool program) Ctx::Vec rLeafLevel(8ull << depth); u64 leafOffset = 0; - allocateExpandBuffer(depth - 1, program, sBuff, sSums, sLast); + Ctx ctx; + allocateExpandBuffer(depth - 1, program, sBuff, sSums, sLast, ctx); recver.mPoints.resize(roundUpTo(recver.mPntCount, 8)); recver.getPoints(recver.mPoints, PprfOutputFormat::ByLeafIndex); - sender.expandOne(seed, 0, program, sLevels, sLeafLevel, leafOffset, sSums, sLast); - recver.expandOne(0, program, rLevels, rLeafLevel, leafOffset, sSums, sLast); + sender.expandOne(seed, 0, program, sLevels, sLeafLevel, leafOffset, sSums, sLast, ctx); + recver.expandOne(0, program, rLevels, rLeafLevel, leafOffset, sSums, sLast, ctx); bool failed = false; for (u64 i = 0; i < pntCount; ++i) @@ -148,7 +149,7 @@ void Tools_Pprf_expandOne_test(const oc::CLP& cmd) #if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - for (u64 domain : { 8, 128, 4522}) for (bool program : {true, false}) + for (u64 domain : { 4, 128, 4522}) for (bool program : {true, false}) { Tools_Pprf_expandOne_test_impl(domain, program); @@ -193,7 +194,7 @@ void Tools_Pprf_test_impl( auto numOTs = sender.baseOtCount(); std::vector> sendOTs(numOTs); std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain * numPoints, format, prng); + BitVector recvBits = recver.sampleChoiceBits(prng); prng.get(sendOTs.data(), sendOTs.size()); for (u64 i = 0; i < numOTs; ++i) @@ -399,7 +400,7 @@ void Tools_Pprf_ByLeafIndex_test(const CLP& cmd) auto f = PprfOutputFormat::ByLeafIndex; auto v = cmd.isSet("v"); - for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true, false }) + for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true/*, false */}) { Tools_Pprf_test_impl(d, n, p, f, v); Tools_Pprf_test_impl(d, n, p, f, v); @@ -418,7 +419,7 @@ void Tools_Pprf_ByTreeIndex_test(const oc::CLP& cmd) auto f = PprfOutputFormat::ByTreeIndex; auto v = cmd.isSet("v"); - for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true, false }) + for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true/*, false*/ }) { Tools_Pprf_test_impl(d, n, p, f, v); Tools_Pprf_test_impl(d, n, p, f, v); @@ -437,7 +438,7 @@ void Tools_Pprf_callback_test(const oc::CLP& cmd) auto f = PprfOutputFormat::Callback; auto v = cmd.isSet("v"); - for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true, false }) + for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true/*, false */}) { Tools_Pprf_test_impl(d, n, p, f, v); Tools_Pprf_test_impl(d, n, p, f, v); diff --git a/libOTe_Tests/Subfield_Test.h b/libOTe_Tests/Subfield_Test.h index 8df6c24f..2aa90901 100644 --- a/libOTe_Tests/Subfield_Test.h +++ b/libOTe_Tests/Subfield_Test.h @@ -1,7 +1,7 @@ #include "cryptoTools/Common/CLP.h" -namespace osuCrypto::Subfield +namespace osuCrypto { diff --git a/libOTe_Tests/Subfield_Tests.cpp b/libOTe_Tests/Subfield_Tests.cpp index 95377320..3a7ea0c0 100644 --- a/libOTe_Tests/Subfield_Tests.cpp +++ b/libOTe_Tests/Subfield_Tests.cpp @@ -8,73 +8,13 @@ #include "Common.h" -namespace osuCrypto::Subfield +namespace osuCrypto { static_assert(std::is_trivially_copyable_v>); static_assert(std::is_trivially_copyable_v); using tests_libOTe::eval; - void Subfield_Tools_Pprf_test(const oc::CLP& cmd) { -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - - //{ - // u64 domain = cmd.getOr("d", 16); - // auto threads = cmd.getOr("t", 1); - // u64 numPoints = cmd.getOr("s", 1) * 8; - - // PRNG prng(ZeroBlock); - - // auto sockets = cp::LocalAsyncSocket::makePair(); - - // auto format = PprfOutputFormat::Interleaved; - // SilentSubfieldPprfSender sender; - // SilentSubfieldPprfReceiver recver; - - // sender.configure(domain, numPoints); - // recver.configure(domain, numPoints); - - // auto numOTs = sender.baseOtCount(); - // std::vector> sendOTs(numOTs); - // std::vector recvOTs(numOTs); - // BitVector recvBits = recver.sampleChoiceBits(domain * numPoints, format, prng); - // //recvBits.randomize(prng); - - // //recvBits[16] = 1; - // prng.get(sendOTs.data(), sendOTs.size()); - // for (u64 i = 0; i < numOTs; ++i) { - // //recvBits[i] = 0; - // recvOTs[i] = sendOTs[i][recvBits[i]]; - // } - // sender.setBase(sendOTs); - // recver.setBase(recvOTs); - - // //auto cols = (numPoints * domain + 127) / 128; - // Matrix sOut2(numPoints * domain, 1); - // Matrix rOut2(numPoints * domain, 1); - // std::vector points(numPoints); - // recver.getPoints(points, format); - - // std::vector arr(numPoints); - // prng.get(arr.data(), arr.size()); - // auto p0 = sender.expand(sockets[0], arr, prng, sOut2, format, true, threads); - // auto p1 = recver.expand(sockets[1], prng, rOut2, format, true, threads); - - // eval(p0, p1); - // for (u64 i = 0; i < numPoints; i++) { - // u64 point = points[i]; - // auto exp = sOut2(point) + arr[i]; - // if (exp != rOut2(point)) { - // throw RTE_LOC; - // } - // } - //} - -#else - throw UnitTestSkipped("ENABLE_SILENTOT not defined."); -#endif - } - template void subfield_vole_test(u64 n) { @@ -133,7 +73,7 @@ namespace osuCrypto::Subfield } void Subfield_Silent_Vole_test(const oc::CLP& cmd) { - using namespace oc::Subfield; + using namespace oc; #if defined(ENABLE_SILENTOT) Timer timer; timer.setTimePoint("start"); @@ -141,141 +81,140 @@ namespace osuCrypto::Subfield u64 nt = cmd.getOr("nt", std::thread::hardware_concurrency()); block seed = block(0, cmd.getOr("seed", 0)); - { - PRNG prng(seed); - u64 x = prng.get(); - std::vector c(n), z0(n), z1(n); + //{ + // PRNG prng(seed); + // u64 x = prng.get(); + // std::vector c(n), z0(n), z1(n); - SilentSubfieldVoleReceiver recv; - SilentSubfieldVoleSender send; + // SilentSubfieldVoleReceiver recv; + // SilentSubfieldVoleSender send; - recv.mMultType = MultType::ExConv7x24; - send.mMultType = MultType::ExConv7x24; + // recv.mMultType = MultType::ExConv7x24; + // send.mMultType = MultType::ExConv7x24; - recv.setTimer(timer); - send.setTimer(timer); + // recv.setTimer(timer); + // send.setTimer(timer); - // recv.mDebug = true; - // send.mDebug = true; + // // recv.mDebug = true; + // // send.mDebug = true; - auto chls = cp::LocalAsyncSocket::makePair(); + // auto chls = cp::LocalAsyncSocket::makePair(); - timer.setTimePoint("net"); + // timer.setTimePoint("net"); - timer.setTimePoint("ot"); - // fakeBase(n, nt, prng, delta, recv, send); + // timer.setTimePoint("ot"); + // // fakeBase(n, nt, prng, delta, recv, send); - auto p0 = send.silentSend(x, span(z0), prng, chls[0]); - auto p1 = recv.silentReceive(span(c), span(z1), prng, chls[1]); + // auto p0 = send.silentSend(x, span(z0), prng, chls[0]); + // auto p1 = recv.silentReceive(span(c), span(z1), prng, chls[1]); - eval(p0, p1); - timer.setTimePoint("send"); - for (u64 i = 0; i < n; ++i) { - u64 left = c[i] * x; - u64 right = z1[i] - z0[i]; - if (left != right) { - std::cout << "bad " << i << "\n c[i] " << c[i] << " * x " << x << " = " << left << std::endl; - std::cout << "z0[i] " << z0[i] << " - z1 " << z1[i] << " = " << right << std::endl; - throw RTE_LOC; - } - } - } + // eval(p0, p1); + // timer.setTimePoint("send"); + // for (u64 i = 0; i < n; ++i) { + // u64 left = c[i] * x; + // u64 right = z1[i] - z0[i]; + // if (left != right) { + // std::cout << "bad " << i << "\n c[i] " << c[i] << " * x " << x << " = " << left << std::endl; + // std::cout << "z0[i] " << z0[i] << " - z1 " << z1[i] << " = " << right << std::endl; + // throw RTE_LOC; + // } + // } + //} - { - PRNG prng(seed); - constexpr size_t N = 10; - using G = u32; - using F = std::array; - using CoeffCtx = CoeffCtxArray; - F x; - CoeffCtx::fromBlock(x, prng.get()); - std::vector c(n); - std::vector a(n), b(n); - - SilentSubfieldVoleReceiver recv; - SilentSubfieldVoleSender send; - - recv.mMultType = MultType::ExConv7x24; - send.mMultType = MultType::ExConv7x24; - - recv.setTimer(timer); - send.setTimer(timer); - - // recv.mDebug = true; - // send.mDebug = true; - - auto chls = cp::LocalAsyncSocket::makePair(); - - timer.setTimePoint("net"); - - timer.setTimePoint("ot"); - // fakeBase(n, nt, prng, delta, recv, send); - - auto p0 = send.silentSend(x, span(b), prng, chls[0]); - auto p1 = recv.silentReceive(span(c), span(a), prng, chls[1]); - - eval(p0, p1); - // std::cout << "transferred " << (chls[0].bytesSent() + chls[0].bytesReceived()) << std::endl; - timer.setTimePoint("verify"); - - timer.setTimePoint("send"); - for (u64 i = 0; i < n; i++) { - for (u64 j = 0; j < N; j++) { - throw RTE_LOC;// fix this - // c = a delta + b - // c - b = a delta - //G left = a[i] * delta[j]; - //G right = c[i][j] - b[i][j]; - //if (left != right) { - // std::cout << "bad " << i << "\n a[i] " << a[i] << " * delta[j] " << delta[j] << " = " << left << std::endl; - // std::cout << "c[i][j] " << c[i][j] << " - b " << b[i][j] << " = " << right << std::endl; - // throw RTE_LOC; - //} - } - } - } + //{ + // PRNG prng(seed); + // constexpr size_t N = 10; + // using G = u32; + // using F = std::array; + // using CoeffCtx = CoeffCtxArray; + // F x; + // CoeffCtx::fromBlock(x, prng.get()); + // std::vector c(n); + // std::vector a(n), b(n); - { - PRNG prng(seed); - block x = prng.get(); - std::vector c(n), z0(n), z1(n); + // SilentSubfieldVoleReceiver recv; + // SilentSubfieldVoleSender send; - SilentSubfieldVoleReceiver recv; - SilentSubfieldVoleSender send; + // recv.mMultType = MultType::ExConv7x24; + // send.mMultType = MultType::ExConv7x24; - recv.mMultType = MultType::ExConv7x24; - send.mMultType = MultType::ExConv7x24; + // recv.setTimer(timer); + // send.setTimer(timer); - recv.setTimer(timer); - send.setTimer(timer); + // // recv.mDebug = true; + // // send.mDebug = true; - // recv.mDebug = true; - // send.mDebug = true; + // auto chls = cp::LocalAsyncSocket::makePair(); - auto chls = cp::LocalAsyncSocket::makePair(); + // timer.setTimePoint("net"); - timer.setTimePoint("net"); + // timer.setTimePoint("ot"); + // // fakeBase(n, nt, prng, delta, recv, send); - timer.setTimePoint("ot"); - // fakeBase(n, nt, prng, delta, recv, send); + // auto p0 = send.silentSend(x, span(b), prng, chls[0]); + // auto p1 = recv.silentReceive(span(c), span(a), prng, chls[1]); - auto p0 = send.silentSend(x, span(z0), prng, chls[0]); - auto p1 = recv.silentReceive(span(c), span(z1), prng, chls[1]); + // eval(p0, p1); + // // std::cout << "transferred " << (chls[0].bytesSent() + chls[0].bytesReceived()) << std::endl; + // timer.setTimePoint("verify"); + + // timer.setTimePoint("send"); + // for (u64 i = 0; i < n; i++) { + // for (u64 j = 0; j < N; j++) { + // throw RTE_LOC;// fix this + // // c = a delta + b + // // c - b = a delta + // //G left = a[i] * delta[j]; + // //G right = c[i][j] - b[i][j]; + // //if (left != right) { + // // std::cout << "bad " << i << "\n a[i] " << a[i] << " * delta[j] " << delta[j] << " = " << left << std::endl; + // // std::cout << "c[i][j] " << c[i][j] << " - b " << b[i][j] << " = " << right << std::endl; + // // throw RTE_LOC; + // //} + // } + // } + //} - eval(p0, p1); - timer.setTimePoint("send"); - for (u64 i = 0; i < n; ++i) { - block left = x.gf128Mul(c[i]); - block right = z1[i] ^ z0[i]; - if (left != right) { - std::cout << "bad " << i << "\n c[i] " << c[i] << " * x " << x << " = " << left << std::endl; - std::cout << "z0[i] " << z0[i] << " - z1 " << z1[i] << " = " << right << std::endl; - throw RTE_LOC; - } - } - } + //{ + // PRNG prng(seed); + // block x = prng.get(); + // std::vector c(n), z0(n), z1(n); + + // SilentSubfieldVoleReceiver recv; + // SilentSubfieldVoleSender send; + + // recv.mMultType = MultType::ExConv7x24; + // send.mMultType = MultType::ExConv7x24; + + // recv.setTimer(timer); + // send.setTimer(timer); + + // // recv.mDebug = true; + // // send.mDebug = true; + + // auto chls = cp::LocalAsyncSocket::makePair(); + + // timer.setTimePoint("net"); - timer.setTimePoint("done"); + // timer.setTimePoint("ot"); + // // fakeBase(n, nt, prng, delta, recv, send); + + // auto p0 = send.silentSend(x, span(z0), prng, chls[0]); + // auto p1 = recv.silentReceive(span(c), span(z1), prng, chls[1]); + + // eval(p0, p1); + // timer.setTimePoint("send"); + // for (u64 i = 0; i < n; ++i) { + // block left = x.gf128Mul(c[i]); + // block right = z1[i] ^ z0[i]; + // if (left != right) { + // std::cout << "bad " << i << "\n c[i] " << c[i] << " * x " << x << " = " << left << std::endl; + // std::cout << "z0[i] " << z0[i] << " - z1 " << z1[i] << " = " << right << std::endl; + // throw RTE_LOC; + // } + // } + //} + //timer.setTimePoint("done"); // std::cout << timer << std::endl; #else throw UnitTestSkipped("not defined." LOCATION); diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 5011d77c..3864bd9b 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -11,8 +11,6 @@ #include "libOTe_Tests/SoftSpoken_Tests.h" #include "libOTe_Tests/bitpolymul_Tests.h" #include "libOTe_Tests/Vole_Tests.h" -#include "libOTe/Tools/LDPC/LdpcDecoder.h" -#include "libOTe/Tools/LDPC/LdpcEncoder.h" #include "libOTe_Tests/ExConvCode_Tests.h" #include "libOTe_Tests/EACode_Tests.h" #include "libOTe/Tools/LDPC/Mtx.h" @@ -42,18 +40,6 @@ namespace tests_libOTe tc.add("Mtx_add_test ", tests::Mtx_add_test); tc.add("Mtx_mult_test ", tests::Mtx_mult_test); tc.add("Mtx_invert_test ", tests::Mtx_invert_test); -#ifdef ENABLE_LDPC - tc.add("ldpc_Mtx_block_test ", tests::Mtx_block_test); - tc.add("LdpcDecode_pb_test ", tests::LdpcDecode_pb_test); - tc.add("LdpcEncoder_diagonalSolver_test ", tests::LdpcEncoder_diagonalSolver_test); - tc.add("LdpcEncoder_encode_test ", tests::LdpcEncoder_encode_test); - tc.add("LdpcEncoder_encode_g0_test ", tests::LdpcEncoder_encode_g0_test); - tc.add("LdpcS1Encoder_encode_test ", tests::LdpcS1Encoder_encode_test); - tc.add("LdpcS1Encoder_encode_Trans_test ", tests::LdpcS1Encoder_encode_Trans_test); - tc.add("LdpcComposit_RegRepDiagBand_encode_test ", tests::LdpcComposit_RegRepDiagBand_encode_test); - tc.add("LdpcComposit_RegRepDiagBand_Trans_test ", tests::LdpcComposit_RegRepDiagBand_Trans_test); -#endif - tc.add("EACode_encode_basic_test ", EACode_encode_basic_test); tc.add("ExConvCode_encode_basic_test ", ExConvCode_encode_basic_test); @@ -65,10 +51,6 @@ namespace tests_libOTe tc.add("Tools_Pprf_callback_test ", Tools_Pprf_callback_test); - tc.add("Subfield_Tools_Pprf_test ", Subfield::Subfield_Tools_Pprf_test); - tc.add("Subfield_Noisy_Vole_test ", Subfield::Subfield_Noisy_Vole_test); - tc.add("Subfield_Silent_Vole_test ", Subfield::Subfield_Silent_Vole_test); - tc.add("Bot_Simplest_Test ", Bot_Simplest_Test); tc.add("Bot_Simplest_asm_Test ", Bot_Simplest_asm_Test); @@ -109,10 +91,10 @@ namespace tests_libOTe tc.add("OtExt_SoftSpokenMalicious21_Test ", OtExt_SoftSpokenMalicious21_Test); tc.add("OtExt_SoftSpokenMalicious21_Split_Test ", OtExt_SoftSpokenMalicious21_Split_Test); tc.add("DotExt_SoftSpokenMaliciousLeaky_Test ", DotExt_SoftSpokenMaliciousLeaky_Test); - + + tc.add("Subfield_Silent_Vole_test ", Subfield_Silent_Vole_test); tc.add("Vole_Noisy_test ", Vole_Noisy_test); tc.add("Vole_Silent_QuasiCyclic_test ", Vole_Silent_QuasiCyclic_test); - tc.add("Vole_Silent_Silver_test ", Vole_Silent_Silver_test); tc.add("Vole_Silent_paramSweep_test ", Vole_Silent_paramSweep_test); tc.add("Vole_Silent_baseOT_test ", Vole_Silent_baseOT_test); tc.add("Vole_Silent_mal_test ", Vole_Silent_mal_test); diff --git a/libOTe_Tests/Vole_Tests.cpp b/libOTe_Tests/Vole_Tests.cpp index b287607d..2fa78fa4 100644 --- a/libOTe_Tests/Vole_Tests.cpp +++ b/libOTe_Tests/Vole_Tests.cpp @@ -15,82 +15,79 @@ using namespace oc; #include -using namespace tests_libOTe; - +#include "libOTe/Tools/Subfield/Subfield.h" +#include "libOTe/Vole/Subfield/SilentVoleSender.h" +#include "libOTe/Vole/Subfield/SilentVoleReceiver.h" -#if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) +using namespace tests_libOTe; -void Vole_Noisy_test(const oc::CLP& cmd) +template +void Vole_Noisy_test_impl(u64 n) { - Timer timer; - timer.setTimePoint("start"); - u64 n = cmd.getOr("n", 123); - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); - - block x = prng.get(); - std::vector y(n), z0(n), z1(n); - prng.get(y); - - NoisyVoleReceiver recv; - NoisyVoleSender send; + PRNG prng(CCBlock); - recv.setTimer(timer); - send.setTimer(timer); + F delta = prng.get(); + std::vector c(n); + std::vector a(n), b(n); + prng.get(c.data(), c.size()); - //IOService ios; - //auto chl1 = Session(ios, "localhost:1212", SessionMode::Server).addChannel(); - //auto chl0 = Session(ios, "localhost:1212", SessionMode::Client).addChannel(); + NoisySubfieldVoleReceiver recv; + NoisySubfieldVoleSender send; auto chls = cp::LocalAsyncSocket::makePair(); - timer.setTimePoint("net"); - - BitVector recvChoice((u8*)&x, 128); - std::vector otRecvMsg(128); - std::vector> otSendMsg(128); + BitVector recvChoice = Trait::binaryDecomposition(delta); + std::vector otRecvMsg(recvChoice.size()); + std::vector> otSendMsg(recvChoice.size()); prng.get>(otSendMsg); - for (u64 i = 0; i < 128; ++i) + for (u64 i = 0; i < recvChoice.size(); ++i) otRecvMsg[i] = otSendMsg[i][recvChoice[i]]; - timer.setTimePoint("ot"); - auto p0 = recv.receive(y, z0, prng, otSendMsg, chls[0]); - auto p1 = send.send(x, z1, prng, otRecvMsg, chls[1]); + // compute a,b such that + // + // a = b + c * delta + // + auto p0 = recv.receive(c, a, prng, otSendMsg, chls[0]); + auto p1 = send.send(delta, b, prng, otRecvMsg, chls[1]); eval(p0, p1); for (u64 i = 0; i < n; ++i) { - if (y[i].gf128Mul(x) != (z0[i] ^ z1[i])) + F prod, sum; + + Trait::mul(prod, delta, c[i]); + Trait::minus(sum, a[i], b[i]); + + if (prod != sum) { throw RTE_LOC; } } - timer.setTimePoint("done"); - - //std::cout << timer << std::endl; - } -#else void Vole_Noisy_test(const oc::CLP& cmd) { - throw UnitTestSkipped( - "ENABLE_SILENT_VOLE not defined. " - ); + for (u64 n : {1, 8, 433}) + { + Vole_Noisy_test_impl(n); + Vole_Noisy_test_impl(n); + Vole_Noisy_test_impl(n); + Vole_Noisy_test_impl, u32, CoeffCtxArray>(n); + } } -#endif -#ifdef ENABLE_SILENT_VOLE +namespace +{ -namespace { + template void fakeBase( u64 n, - u64 threads, PRNG& prng, - block delta, - SilentVoleReceiver& recver, - SilentVoleSender& sender) + F delta, + R& recver, + S& sender, + Ctx ctx) { sender.configure(n, SilentBaseType::Base, 128); recver.configure(n, SilentBaseType::Base, 128); @@ -112,281 +109,116 @@ namespace { for (auto i : rng(msg.size())) msg[i] = msg2[i][choices[i]]; - auto y = recver.sampleBaseVoleVals(prng);; - std::vector c(y.size()), b(y.size()); - prng.get(c.data(), c.size()); - for (auto i : rng(y.size())) + // a = b + c * d + // the sender gets b, d + // the recver gets a, c + auto c = recver.sampleBaseVoleVals(prng); + Ctx::template Vec a(c.size()), b(c.size()); + + prng.get(b.data(), b.size()); + for (auto i : rng(c.size())) { - b[i] = delta.gf128Mul(y[i]) ^ c[i]; + ctx.mul(a[i], delta, c[i]); + ctx.plus(a[i], b[i], a[i]); } sender.setSilentBaseOts(msg2, b); - - // fake base OTs. - recver.setSilentBaseOts(msg, c); + recver.setSilentBaseOts(msg, a); } } -void Vole_Silent_QuasiCyclic_test(const oc::CLP& cmd) -{ -#if defined(ENABLE_SILENTOT) && defined(ENABLE_BITPOLYMUL) - Timer timer; - timer.setTimePoint("start"); - u64 n = cmd.getOr("n", 102043); - u64 nt = cmd.getOr("nt", std::thread::hardware_concurrency()); - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); - - block x = prng.get(); - std::vector c(n), z0(n), z1(n); - - SilentVoleReceiver recv; - SilentVoleSender send; - recv.mMultType = MultType::QuasiCyclic; - send.mMultType = MultType::QuasiCyclic; - - recv.setTimer(timer); - send.setTimer(timer); +template +void Vole_Silent_test_impl(u64 n, MultType type, bool debug, bool doFakeBase, bool mal) +{ + using VecF = Ctx::Vec; + using VecG = Ctx::Vec; + Ctx ctx; - recv.mDebug = true; - send.mDebug = true; + block seed = CCBlock; + PRNG prng(seed); auto chls = cp::LocalAsyncSocket::makePair(); - timer.setTimePoint("net"); + SilentSubfieldVoleReceiver recv; + SilentSubfieldVoleSender send; + recv.mMultType = type; + send.mMultType = type; + recv.mDebug = debug; + send.mDebug = debug; + if (mal) + { + recv.mMalType = SilentSecType::Malicious; + send.mMalType = SilentSecType::Malicious; + } - timer.setTimePoint("ot"); - fakeBase(n, nt, prng, x, recv, send); + VecF a(n), b(n); + VecG c(n); + F d = prng.get(); - // c * x = z + m + if(doFakeBase) + fakeBase(n, prng, d, recv, send, ctx); - auto p0 = recv.silentReceive(c, z0, prng, chls[0]); - auto p1 = send.silentSend(x, z1, prng, chls[1]); + auto p0 = recv.silentReceive(c, a, prng, chls[0]); + auto p1 = send.silentSend(d, b, prng, chls[1]); eval(p0, p1); - timer.setTimePoint("send"); + for (u64 i = 0; i < n; ++i) { - if (c[i].gf128Mul(x) != (z0[i] ^ z1[i])) + // a = b + c * d + F exp; + ctx.mul(exp, d, c[i]); + ctx.plus(exp, exp, b[i]); + + if (a[i] != exp) { - std::cout << "bad " << i << "\n c[i] " << c[i] << " * x " << x << " = " << c[i].gf128Mul(x) << std::endl; - std::cout << " z0[i] " << z0[i] << " ^ z1 " << z1[i] << " = " << (z0[i] ^ z1[i]) << std::endl; throw RTE_LOC; } } - timer.setTimePoint("done"); -#else - throw UnitTestSkipped("ENABLE_BITPOLYMUL not defined." LOCATION); -#endif } + void Vole_Silent_paramSweep_test(const oc::CLP& cmd) { - - Timer timer; - timer.setTimePoint("start"); - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); - - block x = prng.get(); - u64 threads = 0; - - - auto chls = cp::LocalAsyncSocket::makePair(); - timer.setTimePoint("net"); - - timer.setTimePoint("ot"); - - //recv.mDebug = true; - //send.mDebug = true; - - SilentVoleReceiver recv; - SilentVoleSender send; - // c * x = z + m - - //for (u64 n = 5000; n < 10000; ++n) - for (u64 n : {12,/* 123,465,*/1642,/*4356,34254,*/93425}) + auto debug = cmd.isSet("debug"); + for (u64 n : {128, 45364}) { - std::vector c(n), z0(n), z1(n); - - fakeBase(n, threads, prng, x, recv, send); - - recv.setTimer(timer); - send.setTimer(timer); - - auto p0 = recv.silentReceive(c, z0, prng, chls[0]); - auto p1 = send.silentSend(x, z1, prng, chls[1]); - timer.setTimePoint("send"); - - eval(p0, p1); - - for (u64 i = 0; i < n; ++i) - { - if (c[i].gf128Mul(x) != (z0[i] ^ z1[i])) - { - throw RTE_LOC; - } - } - timer.setTimePoint("done"); + Vole_Silent_test_impl(n, DefaultMultType, debug, false, false); + Vole_Silent_test_impl(n, DefaultMultType, debug, false, false); + Vole_Silent_test_impl, u32, CoeffCtxArray>(n, DefaultMultType, debug, false, false); } } -void Vole_Silent_Silver_test(const oc::CLP& cmd) +void Vole_Silent_QuasiCyclic_test(const oc::CLP& cmd) { - -#ifdef ENABLE_INSECURE_SILVER - Timer timer; - timer.setTimePoint("start"); - u64 n = cmd.getOr("n", 102043); - u64 nt = cmd.getOr("nt", std::thread::hardware_concurrency()); - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); - - block x = prng.get(); - std::vector c(n), z0(n), z1(n); - - SilentVoleReceiver recv; - SilentVoleSender send; - - recv.mMultType = MultType::slv5; - send.mMultType = MultType::slv5; - - recv.setTimer(timer); - send.setTimer(timer); - - recv.mDebug = false; - send.mDebug = false; - - auto chls = cp::LocalAsyncSocket::makePair(); - - timer.setTimePoint("net"); - - timer.setTimePoint("ot"); - fakeBase(n, nt, prng, x, recv, send); - - // c * x = z + m - - auto p0 = recv.silentReceive(c, z0, prng, chls[0]); - auto p1 = send.silentSend(x, z1, prng, chls[1]); - - eval(p0, p1); - timer.setTimePoint("send"); - for (u64 i = 0; i < n; ++i) - { - if (c[i].gf128Mul(x) != (z0[i] ^ z1[i])) - { - std::cout << "bad " << i << "\n c[i] " << c[i] << " * x " << x << " = " << c[i].gf128Mul(x) << std::endl; - std::cout << " z0[i] " << z0[i] << " ^ z1 " << z1[i] << " = " << (z0[i] ^ z1[i]) << std::endl; - throw RTE_LOC; - } - } - timer.setTimePoint("done"); +#if defined(ENABLE_SILENTOT) && defined(ENABLE_BITPOLYMUL) + auto debug = cmd.isSet("debug"); + for (u64 n : {128, 333}) + Vole_Silent_test_impl(n, MultType::QuasiCyclic, debug, false, false); +#else + throw UnitTestSkipped("ENABLE_BITPOLYMUL not defined." LOCATION); #endif } - void Vole_Silent_baseOT_test(const oc::CLP& cmd) { - - Timer timer; - timer.setTimePoint("start"); - u64 n = 123; - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); - - block x = prng.get(); - - - - auto chls = cp::LocalAsyncSocket::makePair(); - - timer.setTimePoint("net"); - - timer.setTimePoint("ot"); - - //recv.mDebug = true; - //send.mDebug = true; - - SilentVoleReceiver recv; - SilentVoleSender send; - // c * x = z + m - - //for (u64 n = 5000; n < 10000; ++n) - { - std::vector c(n), z0(n), z1(n); - - - recv.setTimer(timer); - send.setTimer(timer); - auto p0 = recv.silentReceive(c, z0, prng, chls[0]); - auto p1 = send.silentSend(x, z1, prng, chls[1]); - - eval(p0, p1); - - for (u64 i = 0; i < n; ++i) - { - if (c[i].gf128Mul(x) != (z0[i] ^ z1[i])) - { - throw RTE_LOC; - } - } - timer.setTimePoint("done"); - } + auto debug = cmd.isSet("debug"); + u64 n = 128; + Vole_Silent_test_impl(n, DefaultMultType, debug, true, false); + Vole_Silent_test_impl(n, DefaultMultType, debug, true, false); + Vole_Silent_test_impl, u32, CoeffCtxArray>(n, DefaultMultType, debug, true, false); } void Vole_Silent_mal_test(const oc::CLP& cmd) { - - Timer timer; - timer.setTimePoint("start"); - u64 n = 12343; - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); - - block x = prng.get(); - - - - auto chls = cp::LocalAsyncSocket::makePair(); - timer.setTimePoint("net"); - - timer.setTimePoint("ot"); - - //recv.mDebug = true; - //send.mDebug = true; - - SilentVoleReceiver recv; - SilentVoleSender send; - - send.mMalType = SilentSecType::Malicious; - recv.mMalType = SilentSecType::Malicious; - // c * x = z + m - - //for (u64 n = 5000; n < 10000; ++n) + auto debug = cmd.isSet("debug"); + for (u64 n : {45364}) { - std::vector c(n), z0(n), z1(n); - - - recv.setTimer(timer); - send.setTimer(timer); - auto p0 = recv.silentReceive(c, z0, prng, chls[0]); - auto p1 = send.silentSend(x, z1, prng, chls[1]); - - - eval(p0, p1); - - for (u64 i = 0; i < n; ++i) - { - if (c[i].gf128Mul(x) != (z0[i] ^ z1[i])) - { - throw RTE_LOC; - } - } - timer.setTimePoint("done"); + Vole_Silent_test_impl(n, DefaultMultType, debug, false, true); } } @@ -395,7 +227,6 @@ inline u64 eval( macoro::task<>& t1, macoro::task<>& t0, cp::BufferingSocket& s1, cp::BufferingSocket& s0) { - std::cout << "begin " << std::endl; auto e = macoro::make_eager(macoro::when_all_ready(std::move(t0), std::move(t1))); u64 rounds = 0; @@ -405,8 +236,6 @@ inline u64 eval( { s0.processInbound(*b1); ++rounds; - - std::cout << "round " << rounds << std::endl; } } @@ -429,10 +258,7 @@ inline u64 eval( s0.processInbound(*b1); } - ++rounds; - std::cout << "round " << rounds << std::endl; - ++idx; } @@ -488,7 +314,7 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) auto rounds = eval(p0, p1, chls[1], chls[0]); if (rounds != expRound) - throw std::runtime_error(std::to_string(rounds) + "!="+std::to_string(expRound)+". " +COPROTO_LOCATION); + throw std::runtime_error(std::to_string(rounds) + "!=" + std::to_string(expRound) + ". " + COPROTO_LOCATION); for (u64 i = 0; i < n; ++i) @@ -530,24 +356,20 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) timer.setTimePoint("done"); } } - -#else - -namespace { - void throwDisabled() - { - throw UnitTestSkipped( - "ENABLE_SILENT_VOLE not defined. " - ); - } -} - - -void Vole_Silent_QuasiCyclic_test(const oc::CLP& cmd) { throwDisabled(); } -void Vole_Silent_Silver_test(const oc::CLP& cmd) { throwDisabled(); } -void Vole_Silent_paramSweep_test(const oc::CLP& cmd) { throwDisabled(); } -void Vole_Silent_baseOT_test(const oc::CLP& cmd) { throwDisabled(); } -void Vole_Silent_mal_test(const oc::CLP& cmd) { throwDisabled(); } -void Vole_Silent_Rounds_test(const oc::CLP& cmd) { throwDisabled(); } - -#endif +// +//namespace { +// void throwDisabled() +// { +// throw UnitTestSkipped( +// "ENABLE_SILENT_VOLE not defined. " +// ); +// } +//} +// +// +//void Vole_Silent_QuasiCyclic_test(const oc::CLP& cmd) { throwDisabled(); } +//void Vole_Silent_Silver_test(const oc::CLP& cmd) { throwDisabled(); } +//void Vole_Silent_paramSweep_test(const oc::CLP& cmd) { throwDisabled(); } +//void Vole_Silent_baseOT_test(const oc::CLP& cmd) { throwDisabled(); } +//void Vole_Silent_mal_test(const oc::CLP& cmd) { throwDisabled(); } +//void Vole_Silent_Rounds_test(const oc::CLP& cmd) { throwDisabled(); } diff --git a/libOTe_Tests/Vole_Tests.h b/libOTe_Tests/Vole_Tests.h index 5ef787ad..664ace30 100644 --- a/libOTe_Tests/Vole_Tests.h +++ b/libOTe_Tests/Vole_Tests.h @@ -10,9 +10,7 @@ void Vole_Noisy_test(const oc::CLP& cmd); void Vole_Silent_QuasiCyclic_test(const oc::CLP& cmd); -void Vole_Silent_Silver_test(const oc::CLP& cmd); void Vole_Silent_paramSweep_test(const oc::CLP& cmd); -void Vole_Noisy_test(const oc::CLP& cmd); void Vole_Silent_baseOT_test(const oc::CLP& cmd); void Vole_Silent_mal_test(const oc::CLP& cmd); void Vole_Silent_Rounds_test(const oc::CLP& cmd); From d4689d5b0607ea72c548e88ae10f9809feb69c44 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Sun, 21 Jan 2024 20:00:06 -0800 Subject: [PATCH 12/23] optmize and noise non-zero odd --- CMakePresets.json | 4 +- frontend/benchmark.h | 183 ++++++++++++++++++++++ frontend/main.cpp | 4 + libOTe/Tools/EACode/Util.h | 8 - libOTe/Tools/ExConvCode/Expander2.h | 38 ++--- libOTe/Tools/Subfield/Subfield.h | 39 ++++- libOTe/Vole/Subfield/SilentVoleReceiver.h | 128 +++++++++------ libOTe/Vole/Subfield/SilentVoleSender.h | 8 +- libOTe_Tests/Vole_Tests.cpp | 12 +- 9 files changed, 330 insertions(+), 94 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index 3f921bdd..69e16ea3 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -85,8 +85,8 @@ "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo", "LIBOTE_STD_VER": "17", - "ENABLE_ALL_OT": false, - "ENABLE_RELIC": false, + "ENABLE_ALL_OT": true, + "ENABLE_RELIC": true, "ENABLE_SODIUM": false, "ENABLE_BOOST": false, "ENABLE_OPENSSL": false, diff --git a/frontend/benchmark.h b/frontend/benchmark.h index 6d04550d..5cb0a761 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -10,7 +10,11 @@ #include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" #include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" +#include "libOTe/Vole/Silent/SilentVoleSender.h" +#include "libOTe/Vole/Silent/SilentVoleReceiver.h" +#include "libOTe/Vole/Subfield/SilentVoleSender.h" +#include "libOTe/Vole/Subfield/SilentVoleReceiver.h" namespace osuCrypto { @@ -362,6 +366,185 @@ namespace osuCrypto } + t0.join(); + std::cout << sTimer << std::endl; + std::cout << rTimer << std::endl; + + std::cout << sock[0].bytesReceived() / trials << " " << sock[1].bytesReceived() / trials << " bytes per " << std::endl; + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + } +#endif + } + + + + inline void VoleBench(const CLP& cmd) + { +#ifdef ENABLE_SILENTOT + + try + { + + SilentVoleSender sender; + SilentVoleReceiver recver; + + u64 trials = cmd.getOr("t", 10); + + u64 n = cmd.getOr("n", 1ull << cmd.getOr("nn", 20)); + MultType multType = (MultType)cmd.getOr("m", (int)MultType::ExConv7x24); + std::cout << multType << std::endl; + + recver.mMultType = multType; + sender.mMultType = multType; + + PRNG prng0(ZeroBlock), prng1(ZeroBlock); + block delta = prng0.get(); + + auto sock = coproto::LocalAsyncSocket::makePair(); + + Timer sTimer; + Timer rTimer; + sTimer.setTimePoint("start"); + rTimer.setTimePoint("start"); + + auto t0 = std::thread([&] { + for (u64 t = 0; t < trials; ++t) + { + auto p0 = sender.silentSendInplace(delta, n, prng0, sock[0]); + + char c; + + coproto::sync_wait(sock[0].send(std::move(c))); + coproto::sync_wait(sock[0].recv(c)); + sTimer.setTimePoint("__"); + coproto::sync_wait(sock[0].send(std::move(c))); + coproto::sync_wait(sock[0].recv(c)); + sTimer.setTimePoint("s start"); + coproto::sync_wait(p0); + sTimer.setTimePoint("s done"); + } + }); + + + for (u64 t = 0; t < trials; ++t) + { + auto p1 = recver.silentReceiveInplace(n, prng1, sock[1]); + char c; + coproto::sync_wait(sock[1].send(std::move(c))); + coproto::sync_wait(sock[1].recv(c)); + + rTimer.setTimePoint("__"); + coproto::sync_wait(sock[1].send(std::move(c))); + coproto::sync_wait(sock[1].recv(c)); + + rTimer.setTimePoint("r start"); + coproto::sync_wait(p1); + rTimer.setTimePoint("r done"); + + } + + + t0.join(); + std::cout << sTimer << std::endl; + std::cout << rTimer << std::endl; + + std::cout << sock[0].bytesReceived() / trials << " " << sock[1].bytesReceived() / trials << " bytes per " << std::endl; + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + } +#endif + } + + + + + inline void VoleBench2(const CLP& cmd) + { +#ifdef ENABLE_SILENTOT + + try + { + + SilentSubfieldVoleSender sender; + SilentSubfieldVoleReceiver recver; + + u64 trials = cmd.getOr("t", 10); + + u64 n = cmd.getOr("n", 1ull << cmd.getOr("nn", 20)); + MultType multType = (MultType)cmd.getOr("m", (int)MultType::ExConv7x24); + std::cout << multType << std::endl; + + recver.mMultType = multType; + sender.mMultType = multType; + + std::vector> baseSend(128); + std::vector baseRecv(128); + BitVector baseChoice(128); + PRNG prng(CCBlock); + baseChoice.randomize(prng); + for (u64 i = 0; i < 128; ++i) + { + baseSend[i] = prng.get(); + baseRecv[i] = baseSend[i][baseChoice[i]]; + } + + sender.mOtExtRecver.setBaseOts(baseSend); + recver.mOtExtRecver.setBaseOts(baseSend); + sender.mOtExtSender.setBaseOts(baseRecv, baseChoice); + recver.mOtExtSender.setBaseOts(baseRecv, baseChoice); + + PRNG prng0(ZeroBlock), prng1(ZeroBlock); + block delta = prng0.get(); + + auto sock = coproto::LocalAsyncSocket::makePair(); + + Timer sTimer; + Timer rTimer; + sTimer.setTimePoint("start"); + rTimer.setTimePoint("start"); + + auto t0 = std::thread([&] { + for (u64 t = 0; t < trials; ++t) + { + auto p0 = sender.silentSendInplace(delta, n, prng0, sock[0]); + + char c; + + coproto::sync_wait(sock[0].send(std::move(c))); + coproto::sync_wait(sock[0].recv(c)); + sTimer.setTimePoint("__"); + coproto::sync_wait(sock[0].send(std::move(c))); + coproto::sync_wait(sock[0].recv(c)); + sTimer.setTimePoint("s start"); + coproto::sync_wait(p0); + sTimer.setTimePoint("s done"); + } + }); + + + for (u64 t = 0; t < trials; ++t) + { + auto p1 = recver.silentReceiveInplace(n, prng1, sock[1]); + char c; + coproto::sync_wait(sock[1].send(std::move(c))); + coproto::sync_wait(sock[1].recv(c)); + + rTimer.setTimePoint("__"); + coproto::sync_wait(sock[1].send(std::move(c))); + coproto::sync_wait(sock[1].recv(c)); + + rTimer.setTimePoint("r start"); + coproto::sync_wait(p1); + rTimer.setTimePoint("r done"); + + } + + t0.join(); std::cout << sTimer << std::endl; std::cout << rTimer << std::endl; diff --git a/frontend/main.cpp b/frontend/main.cpp index 5fe2b603..2ae01973 100644 --- a/frontend/main.cpp +++ b/frontend/main.cpp @@ -119,6 +119,10 @@ int main(int argc, char** argv) QCCodeBench(cmd); else if (cmd.isSet("silent")) SilentOtBench(cmd); + else if (cmd.isSet("vole")) + VoleBench(cmd); + else if (cmd.isSet("vole2")) + VoleBench2(cmd); else if (cmd.isSet("ea")) EACodeBench(cmd); else diff --git a/libOTe/Tools/EACode/Util.h b/libOTe/Tools/EACode/Util.h index c92d8b4b..8d08fe0b 100644 --- a/libOTe/Tools/EACode/Util.h +++ b/libOTe/Tools/EACode/Util.h @@ -86,14 +86,6 @@ namespace osuCrypto b[6] = AES::roundEnc(b[6], k[6]); b[7] = AES::roundEnc(b[7], k[7]); - b[0] = b[0] ^ k[0]; - b[1] = b[1] ^ k[1]; - b[2] = b[2] ^ k[2]; - b[3] = b[3] ^ k[3]; - b[4] = b[4] ^ k[4]; - b[5] = b[5] ^ k[5]; - b[6] = b[6] ^ k[6]; - b[7] = b[7] ^ k[7]; } auto src = prng.mBuffer.data(); diff --git a/libOTe/Tools/ExConvCode/Expander2.h b/libOTe/Tools/ExConvCode/Expander2.h index fc83a1fe..b2a35ba2 100644 --- a/libOTe/Tools/ExConvCode/Expander2.h +++ b/libOTe/Tools/ExConvCode/Expander2.h @@ -14,19 +14,6 @@ namespace osuCrypto { - template - auto getRestrictPtr(Iter& c) - { - //if constexpr (coproto::has_data_member_func::value) - //{ - // return (decltype(c.data())__restrict) c.data(); - //} - //else - { - return c; - } - } - // The encoder for the expander matrix B. // B has mMessageSize rows and mCodeSize columns. It is sampled uniformly // with fixed row weight mExpanderWeight. @@ -114,17 +101,14 @@ namespace osuCrypto DstIter&& output, CoeffCtx ctx) const { - //using P = std::pair; - //expandMany( - // std::tuple

{ P{input, output}}, - // ctx - //); - (void)*(input + (mCodeSize - 1)); (void)*(output + (mMessageSize - 1)); detail::ExpanderModd prng(mSeed, mCodeSize); + auto rInput = ctx.restrictPtr(input); + auto rOutput = ctx.restrictPtr(output); + auto main = mMessageSize / 8 * 8; u64 i = 0; @@ -147,14 +131,14 @@ namespace osuCrypto rr[6] = prng.get(); rr[7] = prng.get(); - ctx.plus(*(output + 0), *(output + 0), *(input + rr[0])); - ctx.plus(*(output + 1), *(output + 1), *(input + rr[1])); - ctx.plus(*(output + 2), *(output + 2), *(input + rr[2])); - ctx.plus(*(output + 3), *(output + 3), *(input + rr[3])); - ctx.plus(*(output + 4), *(output + 4), *(input + rr[4])); - ctx.plus(*(output + 5), *(output + 5), *(input + rr[5])); - ctx.plus(*(output + 6), *(output + 6), *(input + rr[6])); - ctx.plus(*(output + 7), *(output + 7), *(input + rr[7])); + ctx.plus(*(rOutput + 0), *(rOutput + 0), *(rInput + rr[0])); + ctx.plus(*(rOutput + 1), *(rOutput + 1), *(rInput + rr[1])); + ctx.plus(*(rOutput + 2), *(rOutput + 2), *(rInput + rr[2])); + ctx.plus(*(rOutput + 3), *(rOutput + 3), *(rInput + rr[3])); + ctx.plus(*(rOutput + 4), *(rOutput + 4), *(rInput + rr[4])); + ctx.plus(*(rOutput + 5), *(rOutput + 5), *(rInput + rr[5])); + ctx.plus(*(rOutput + 6), *(rOutput + 6), *(rInput + rr[6])); + ctx.plus(*(rOutput + 7), *(rOutput + 7), *(rInput + rr[7])); } } diff --git a/libOTe/Tools/Subfield/Subfield.h b/libOTe/Tools/Subfield/Subfield.h index c5218446..77dde299 100644 --- a/libOTe/Tools/Subfield/Subfield.h +++ b/libOTe/Tools/Subfield/Subfield.h @@ -30,6 +30,16 @@ namespace osuCrypto { return lhs == rhs; } + // is F a field? + template + static OC_FORCEINLINE bool isField() { + return false; // default. + } + + + + + // the bit size require to prepresent F // the protocol will perform binary decomposition @@ -205,7 +215,11 @@ namespace osuCrypto { deserialize(begin, end, dstBegin); } - + template + static F* __restrict restrictPtr(Iter iter) + { + return &*iter; + } // fill the range [begin,..., end) with zeros. @@ -229,7 +243,21 @@ namespace osuCrypto { template static void one(Iter begin, Iter end) { - std::fill(begin, end, 1); + using F = std::remove_reference_t; + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + + + if (begin != end) + { + auto n = std::distance(begin, end); + assert(n > 0); + memset(&*begin, 0, n * sizeof(F)); + while (begin != end) + { + auto& v = *begin++; + *(u8*)&v = 1; + } + } } @@ -261,6 +289,13 @@ namespace osuCrypto { static OC_FORCEINLINE void minus(F& ret, const F& lhs, const F& rhs) { ret = lhs ^ rhs; } + + // is F a field? + template + static OC_FORCEINLINE bool isField() { + return true; // default. + } + }; // block does not use operator* diff --git a/libOTe/Vole/Subfield/SilentVoleReceiver.h b/libOTe/Vole/Subfield/SilentVoleReceiver.h index a9fbc31e..7f3552b2 100644 --- a/libOTe/Vole/Subfield/SilentVoleReceiver.h +++ b/libOTe/Vole/Subfield/SilentVoleReceiver.h @@ -184,7 +184,7 @@ namespace osuCrypto // other party will program the PPRF to output their share of delta * noiseVals. // noiseVals = sampleBaseVoleVals(prng); - Ctx::resize(baseAs, noiseVals.size()); + mCtx.resize(baseAs, noiseVals.size()); if (mTimer) nv.setTimer(*mTimer); @@ -231,8 +231,12 @@ namespace osuCrypto { chl2 = chl.fork(); prng2.SetSeed(prng.get()); - MC_AWAIT(baseOt.receive(choice, msg, prng, chl)); - MC_AWAIT(nv.receive(noiseVals, baseAs, prng2, baseOt, chl2)); + + MC_AWAIT( + macoro::when_all_ready( + baseOt.receive(choice, msg, prng, chl), + nv.receive(noiseVals, baseAs, prng2, baseOt, chl2)) + ); } setSilentBaseOts(msg, baseAs); @@ -277,6 +281,8 @@ namespace osuCrypto mSizePer = std::max(4, roundUpTo(divCeil(mRequestSize * mScaler, mNumPartitions), 2)); mNoiseVecSize = mSizePer * mNumPartitions; + //std::cout << "n " << mRequestSize << " -> " << mNoiseVecSize << " = " << mSizePer << " * " << mNumPartitions << std::endl; + mGen.configure(mSizePer, mNumPartitions); } @@ -316,9 +322,29 @@ namespace osuCrypto // sample the values of the noisy coordinate of c // and perform a noicy vole to get a = b + mD * c - Ctx::resize(mBaseC, mNumPartitions + (mMalType == SilentSecType::Malicious)); + + VecG zero, one; + mCtx.resize(zero, 1); + mCtx.zero(zero.begin(), zero.end()); + mCtx.one(one.begin(), one.end()); + mCtx.resize(mBaseC, mNumPartitions + (mMalType == SilentSecType::Malicious)); for (size_t i = 0; i < mNumPartitions; i++) - Ctx::fromBlock(mBaseC[i], prng.get()); + { + mCtx.fromBlock(mBaseC[i], prng.get()); + + // must not be zero. + while(mCtx.eq(zero[0], mBaseC[i])) + mCtx.fromBlock(mBaseC[i], prng.get()); + + // if we are not a field, then the noise should be odd. + if (mCtx.isField() == false) + { + auto odd = mCtx.binaryDecomposition(mBaseC[i])[0]; + if (odd) + mCtx.plus(mBaseC[i], mBaseC[i], one[0]); + } + } + mS.resize(mNumPartitions); mGen.getPoints(mS, PprfOutputFormat::Interleaved); @@ -362,8 +388,8 @@ namespace osuCrypto mGen.setBase(recvBaseOts); - Ctx::resize(mBaseA, baseA.size()); - Ctx::copy(baseA.begin(), baseA.end(), mBaseA.begin()); + mCtx.resize(mBaseA, baseA.size()); + mCtx.copy(baseA.begin(), baseA.end(), mBaseA.begin()); mState = State::HasBase; } @@ -383,8 +409,8 @@ namespace osuCrypto MC_AWAIT(silentReceiveInplace(c.size(), prng, chl)); - Ctx::copy(mC.begin(), mC.begin() + c.size(), c.begin()); - Ctx::copy(mA.begin(), mA.begin() + a.size(), a.begin()); + mCtx.copy(mC.begin(), mC.begin() + c.size(), c.begin()); + mCtx.copy(mA.begin(), mA.begin() + a.size(), a.begin()); clear(); MC_END(); @@ -420,15 +446,15 @@ namespace osuCrypto } // allocate mA - Ctx::resize(mA, 0); - Ctx::resize(mA, mNoiseVecSize); + mCtx.resize(mA, 0); + mCtx.resize(mA, mNoiseVecSize); setTimePoint("SilentVoleReceiver.alloc"); // allocate the space for mC - Ctx::resize(mC, 0); - Ctx::resize(mC, mNoiseVecSize); - Ctx::zero(mC.begin(), mC.end()); + mCtx.resize(mC, 0); + mCtx.resize(mC, mNoiseVecSize); + mCtx.zero(mC.begin(), mC.end()); setTimePoint("SilentVoleReceiver.alloc.zero"); if (mTimer) @@ -461,8 +487,8 @@ namespace osuCrypto for (u64 i = 0; i < mNumPartitions; ++i) { auto pnt = mS[i]; - Ctx::copy(mC[pnt], mBaseC[i]); - Ctx::plus(mA[pnt], mA[pnt], mBaseA[i]); + mCtx.copy(mC[pnt], mBaseC[i]); + mCtx.plus(mA[pnt], mA[pnt], mBaseA[i]); } if (mDebug) @@ -496,7 +522,9 @@ namespace osuCrypto double _; ExConvConfigure(mScaler, mMultType, expanderWeight, accumulatorWeight, _); ExConvCode2 encoder; - encoder.config(mRequestSize, mNoiseVecSize, expanderWeight, accumulatorWeight); + if (mScaler * mRequestSize > mNoiseVecSize) + throw RTE_LOC; + encoder.config(mRequestSize, mScaler * mRequestSize, expanderWeight, accumulatorWeight); if (mTimer) encoder.setTimer(getTimer()); @@ -509,6 +537,7 @@ namespace osuCrypto } case osuCrypto::MultType::QuasiCyclic: { +#ifdef ENABLE_BITPOLYMUL if constexpr ( std::is_same_v && std::is_same_v && @@ -521,6 +550,9 @@ namespace osuCrypto } else throw std::runtime_error("QuasiCyclic is only supported for GF128, i.e. block. " LOCATION); +#else + throw std::runtime_error("QuasiCyclic requires ENABLE_BITPOLYMUL = true. " LOCATION); +#endif break; } default: @@ -529,8 +561,8 @@ namespace osuCrypto } // resize the buffers down to only contain the real elements. - Ctx::resize(mA, mRequestSize); - Ctx::resize(mC, mRequestSize); + mCtx.resize(mA, mRequestSize); + mCtx.resize(mC, mRequestSize); mBaseC = {}; mBaseA = {}; @@ -548,32 +580,32 @@ namespace osuCrypto task<> checkRT(Socket& chl) const { MC_BEGIN(task<>, this, &chl, - B = typename Ctx::Vec{}, - sparseNoiseDelta = typename Ctx::Vec{}, - baseB = typename Ctx::Vec{}, - delta = typename Ctx::Vec{}, - tempF = typename Ctx::Vec{}, - tempG = typename Ctx::Vec{}, + B = VecF{}, + sparseNoiseDelta = VecF{}, + baseB = VecF{}, + delta = VecF{}, + tempF = VecF{}, + tempG = VecG{}, buffer = std::vector{} ); // recv delta - buffer.resize(Ctx::byteSize()); - Ctx::resize(delta, 1); + buffer.resize(mCtx.byteSize()); + mCtx.resize(delta, 1); MC_AWAIT(chl.recv(buffer)); - Ctx::deserialize(buffer.begin(), buffer.end(), delta.begin()); + mCtx.deserialize(buffer.begin(), buffer.end(), delta.begin()); // recv B - buffer.resize(Ctx::byteSize() * mA.size()); - Ctx::resize(B, mA.size()); + buffer.resize(mCtx.byteSize() * mA.size()); + mCtx.resize(B, mA.size()); MC_AWAIT(chl.recv(buffer)); - Ctx::deserialize(buffer.begin(), buffer.end(), B.begin()); + mCtx.deserialize(buffer.begin(), buffer.end(), B.begin()); // recv the noisy values. - buffer.resize(Ctx::byteSize() * mBaseA.size()); - Ctx::resize(baseB, mBaseA.size()); + buffer.resize(mCtx.byteSize() * mBaseA.size()); + mCtx.resize(baseB, mBaseA.size()); MC_AWAIT(chl.recvResize(buffer)); - Ctx::deserialize(buffer.begin(), buffer.end(), baseB.begin()); + mCtx.deserialize(buffer.begin(), buffer.end(), baseB.begin()); // it shoudl hold that // @@ -591,9 +623,9 @@ namespace osuCrypto std::sort(index.begin(), index.end(), [&](std::size_t i, std::size_t j) { return mS[i] < mS[j]; }); - Ctx::resize(tempF, 2); - Ctx::resize(tempG, 1); - Ctx::zero(tempG.begin(), tempG.end()); + mCtx.resize(tempF, 2); + mCtx.resize(tempG, 1); + mCtx.zero(tempG.begin(), tempG.end()); // check the correlation that @@ -602,19 +634,19 @@ namespace osuCrypto for (auto i : rng(mBaseA.size())) { // temp[0] = baseB[i] + mBaseA[i] - Ctx::plus(tempF[0], baseB[i], mBaseA[i]); + mCtx.plus(tempF[0], baseB[i], mBaseA[i]); // temp[1] = mBaseC[i] * delta[0] - Ctx::mul(tempF[1], delta[0], mBaseC[i]); + mCtx.mul(tempF[1], delta[0], mBaseC[i]); - if (!Ctx::eq(tempF[0], tempF[1])) + if (!mCtx.eq(tempF[0], tempF[1])) throw RTE_LOC; if (i < mNumPartitions) { //auto idx = index[i]; auto point = mS[i]; - if (!Ctx::eq(mBaseC[i], mC[point])) + if (!mCtx.eq(mBaseC[i], mC[point])) throw RTE_LOC; if (i && mS[index[i - 1]] >= mS[index[i]]) @@ -627,19 +659,19 @@ namespace osuCrypto auto leafIdx = mS[*iIter]; F act = tempF[0]; G zero = tempG[0]; - Ctx::zero(tempG.begin(), tempG.end()); + mCtx.zero(tempG.begin(), tempG.end()); for (u64 j = 0; j < mA.size(); ++j) { - Ctx::mul(act, delta[0], mC[j]); - Ctx::plus(act, act, B[j]); + mCtx.mul(act, delta[0], mC[j]); + mCtx.plus(act, act, B[j]); bool active = false; if (j == leafIdx) { active = true; } - else if (!Ctx::eq(zero, mC[j])) + else if (!mCtx.eq(zero, mC[j])) throw RTE_LOC; if (mA[j] != act) @@ -651,11 +683,11 @@ namespace osuCrypto if (verbose) { - std::cout << j << " act " << Ctx::str(act) - << " a " << Ctx::str(mA[j]) << " b " << Ctx::str(B[j]); + std::cout << j << " act " << mCtx.str(act) + << " a " << mCtx.str(mA[j]) << " b " << mCtx.str(B[j]); if (active) - std::cout << " < " << Ctx::str(delta[0]); + std::cout << " < " << mCtx.str(delta[0]); std::cout << std::endl << Color::Default; } diff --git a/libOTe/Vole/Subfield/SilentVoleSender.h b/libOTe/Vole/Subfield/SilentVoleSender.h index e3185477..dd2496b0 100644 --- a/libOTe/Vole/Subfield/SilentVoleSender.h +++ b/libOTe/Vole/Subfield/SilentVoleSender.h @@ -383,7 +383,9 @@ namespace osuCrypto u64 expanderWeight, accumulatorWeight; double _1; ExConvConfigure(mScaler, mMultType, expanderWeight, accumulatorWeight, _1); - encoder.config(mRequestSize, mNoiseVecSize, expanderWeight, accumulatorWeight); + if (mScaler * mRequestSize > mNoiseVecSize) + throw RTE_LOC; + encoder.config(mRequestSize, mScaler * mRequestSize, expanderWeight, accumulatorWeight); if (mTimer) encoder.setTimer(getTimer()); encoder.dualEncode(mB.begin()); @@ -391,6 +393,7 @@ namespace osuCrypto } case MultType::QuasiCyclic: { +#ifdef ENABLE_BITPOLYMUL if constexpr ( std::is_same_v && std::is_same_v && @@ -402,6 +405,9 @@ namespace osuCrypto } else throw std::runtime_error("QuasiCyclic is only supported for GF128, i.e. block. " LOCATION); +#else + throw std::runtime_error("QuasiCyclic requires ENABLE_BITPOLYMUL = true. " LOCATION); +#endif break; } diff --git a/libOTe_Tests/Vole_Tests.cpp b/libOTe_Tests/Vole_Tests.cpp index 2fa78fa4..155b2f9b 100644 --- a/libOTe_Tests/Vole_Tests.cpp +++ b/libOTe_Tests/Vole_Tests.cpp @@ -77,7 +77,7 @@ void Vole_Noisy_test(const oc::CLP& cmd) } } -namespace +namespace { template @@ -156,7 +156,7 @@ void Vole_Silent_test_impl(u64 n, MultType type, bool debug, bool doFakeBase, bo VecG c(n); F d = prng.get(); - if(doFakeBase) + if (doFakeBase) fakeBase(n, prng, d, recv, send, ctx); auto p0 = recv.silentReceive(c, a, prng, chls[0]); @@ -274,7 +274,7 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) Timer timer; timer.setTimePoint("start"); - u64 n = 12343; + u64 n = 1233; block seed = block(0, cmd.getOr("seed", 0)); PRNG prng(seed); @@ -283,8 +283,8 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) cp::BufferingSocket chls[2]; - SilentVoleReceiver recv; - SilentVoleSender send; + SilentSubfieldVoleReceiver recv; + SilentSubfieldVoleSender send; send.mMalType = SilentSecType::SemiHonest; recv.mMalType = SilentSecType::SemiHonest; @@ -302,7 +302,7 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) send.setTimer(timer); if (jj) { - std::vector c(n), z0(n), z1(n); + AlignedUnVector c(n), z0(n), z1(n); auto p0 = recv.silentReceive(c, z0, prng, chls[0]); auto p1 = send.silentSend(x, z1, prng, chls[1]); #if (defined ENABLE_MRR_TWIST && defined ENABLE_SSE) || \ From 6f775615442f280d93165c13bcbfb58ab8be4457 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Sun, 21 Jan 2024 23:39:34 -0800 Subject: [PATCH 13/23] removed old impl --- frontend/benchmark.h | 89 +- frontend/main.cpp | 2 - libOTe/CMakeLists.txt | 2 +- .../Tools/{Subfield/Subfield.h => CoeffCtx.h} | 74 +- libOTe/Tools/EACode/EACode.cpp | 142 --- libOTe/Tools/EACode/EACode.h | 157 ++- libOTe/Tools/EACode/EACodeInstantiations.cpp | 15 - libOTe/Tools/EACode/Expander.cpp | 433 -------- libOTe/Tools/EACode/Expander.h | 98 -- .../Tools/EACode/ExpanderInstantiations.cpp | 26 - libOTe/Tools/ExConvCode/ExConvCode.cpp | 610 ++--------- libOTe/Tools/ExConvCode/ExConvCode.h | 538 ++++++++-- libOTe/Tools/ExConvCode/ExConvCode2.cpp | 117 -- libOTe/Tools/ExConvCode/ExConvCode2.h | 626 ----------- libOTe/Tools/ExConvCode/ExConvCode2Impl.h | 3 - .../ExConvCode/ExConvCodeInstantiations.cpp | 17 - .../ExConvCode/{Expander2.h => Expander.h} | 15 +- libOTe/Tools/Pprf/PprfUtil.h | 266 +++++ .../SubfieldPprf.cpp => Pprf/RegularPprf.cpp} | 0 .../SubfieldPprf.h => Pprf/RegularPprf.h} | 678 +++++------- libOTe/Tools/SilentPprf.cpp | 995 ------------------ libOTe/Tools/SilentPprf.h | 318 ------ libOTe/TwoChooseOne/ConfigureCode.cpp | 2 +- .../Silent/SilentOtExtReceiver.cpp | 127 +-- .../TwoChooseOne/Silent/SilentOtExtReceiver.h | 11 +- .../TwoChooseOne/Silent/SilentOtExtSender.cpp | 120 +-- .../TwoChooseOne/Silent/SilentOtExtSender.h | 9 +- libOTe/TwoChooseOne/Silent/SilentOtExtUtil.h | 23 + libOTe/Vole/Noisy/NoisyVoleReceiver.cpp | 88 -- libOTe/Vole/Noisy/NoisyVoleReceiver.h | 151 ++- libOTe/Vole/Noisy/NoisyVoleSender.cpp | 82 -- libOTe/Vole/Noisy/NoisyVoleSender.h | 138 ++- libOTe/Vole/Silent/SilentVoleReceiver.cpp | 821 --------------- libOTe/Vole/Silent/SilentVoleReceiver.h | 695 ++++++++++-- libOTe/Vole/Silent/SilentVoleSender.cpp | 435 -------- libOTe/Vole/Silent/SilentVoleSender.h | 439 ++++++-- libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp | 41 +- libOTe/Vole/SoftSpokenOT/SmallFieldVole.h | 18 +- libOTe/Vole/Subfield/NoisyVoleReceiver.h | 148 --- libOTe/Vole/Subfield/NoisyVoleSender.h | 141 --- libOTe/Vole/Subfield/SilentVoleReceiver.h | 751 ------------- libOTe/Vole/Subfield/SilentVoleSender.h | 475 --------- libOTe_Tests/EACode_Tests.cpp | 92 +- libOTe_Tests/ExConvCode_Tests.cpp | 34 +- libOTe_Tests/Pprf_Tests.cpp | 65 +- libOTe_Tests/SilentOT_Tests.cpp | 1 - libOTe_Tests/Subfield_Test.h | 12 - libOTe_Tests/Subfield_Tests.cpp | 224 ---- libOTe_Tests/UnitTests.cpp | 3 +- libOTe_Tests/Vole_Tests.cpp | 21 +- 50 files changed, 2713 insertions(+), 7675 deletions(-) rename libOTe/Tools/{Subfield/Subfield.h => CoeffCtx.h} (81%) delete mode 100644 libOTe/Tools/EACode/EACode.cpp delete mode 100644 libOTe/Tools/EACode/EACodeInstantiations.cpp delete mode 100644 libOTe/Tools/EACode/Expander.cpp delete mode 100644 libOTe/Tools/EACode/Expander.h delete mode 100644 libOTe/Tools/EACode/ExpanderInstantiations.cpp delete mode 100644 libOTe/Tools/ExConvCode/ExConvCode2.cpp delete mode 100644 libOTe/Tools/ExConvCode/ExConvCode2.h delete mode 100644 libOTe/Tools/ExConvCode/ExConvCode2Impl.h delete mode 100644 libOTe/Tools/ExConvCode/ExConvCodeInstantiations.cpp rename libOTe/Tools/ExConvCode/{Expander2.h => Expander.h} (97%) create mode 100644 libOTe/Tools/Pprf/PprfUtil.h rename libOTe/Tools/{Subfield/SubfieldPprf.cpp => Pprf/RegularPprf.cpp} (100%) rename libOTe/Tools/{Subfield/SubfieldPprf.h => Pprf/RegularPprf.h} (60%) delete mode 100644 libOTe/Tools/SilentPprf.cpp delete mode 100644 libOTe/Tools/SilentPprf.h create mode 100644 libOTe/TwoChooseOne/Silent/SilentOtExtUtil.h delete mode 100644 libOTe/Vole/Noisy/NoisyVoleReceiver.cpp delete mode 100644 libOTe/Vole/Noisy/NoisyVoleSender.cpp delete mode 100644 libOTe/Vole/Silent/SilentVoleReceiver.cpp delete mode 100644 libOTe/Vole/Silent/SilentVoleSender.cpp delete mode 100644 libOTe/Vole/Subfield/NoisyVoleReceiver.h delete mode 100644 libOTe/Vole/Subfield/NoisyVoleSender.h delete mode 100644 libOTe/Vole/Subfield/SilentVoleReceiver.h delete mode 100644 libOTe/Vole/Subfield/SilentVoleSender.h delete mode 100644 libOTe_Tests/Subfield_Test.h delete mode 100644 libOTe_Tests/Subfield_Tests.cpp diff --git a/frontend/benchmark.h b/frontend/benchmark.h index 5cb0a761..bd50c7b2 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -13,8 +13,6 @@ #include "libOTe/Vole/Silent/SilentVoleSender.h" #include "libOTe/Vole/Silent/SilentVoleReceiver.h" -#include "libOTe/Vole/Subfield/SilentVoleSender.h" -#include "libOTe/Vole/Subfield/SilentVoleReceiver.h" namespace osuCrypto { @@ -101,7 +99,7 @@ namespace osuCrypto timer.setTimePoint("_____________________"); for (u64 i = 0; i < trials; ++i) { - code.dualEncode(x, y); + code.dualEncode(x, y, {}); timer.setTimePoint("encode"); } @@ -156,10 +154,7 @@ namespace osuCrypto timer.setTimePoint("_____________________"); for (u64 i = 0; i < trials; ++i) { - if (sys) - code.dualEncode(x); - else - code.dualEncode(x, y); + code.dualEncode(x.begin(), {}); timer.setTimePoint("encode"); } @@ -381,86 +376,6 @@ namespace osuCrypto - inline void VoleBench(const CLP& cmd) - { -#ifdef ENABLE_SILENTOT - - try - { - - SilentVoleSender sender; - SilentVoleReceiver recver; - - u64 trials = cmd.getOr("t", 10); - - u64 n = cmd.getOr("n", 1ull << cmd.getOr("nn", 20)); - MultType multType = (MultType)cmd.getOr("m", (int)MultType::ExConv7x24); - std::cout << multType << std::endl; - - recver.mMultType = multType; - sender.mMultType = multType; - - PRNG prng0(ZeroBlock), prng1(ZeroBlock); - block delta = prng0.get(); - - auto sock = coproto::LocalAsyncSocket::makePair(); - - Timer sTimer; - Timer rTimer; - sTimer.setTimePoint("start"); - rTimer.setTimePoint("start"); - - auto t0 = std::thread([&] { - for (u64 t = 0; t < trials; ++t) - { - auto p0 = sender.silentSendInplace(delta, n, prng0, sock[0]); - - char c; - - coproto::sync_wait(sock[0].send(std::move(c))); - coproto::sync_wait(sock[0].recv(c)); - sTimer.setTimePoint("__"); - coproto::sync_wait(sock[0].send(std::move(c))); - coproto::sync_wait(sock[0].recv(c)); - sTimer.setTimePoint("s start"); - coproto::sync_wait(p0); - sTimer.setTimePoint("s done"); - } - }); - - - for (u64 t = 0; t < trials; ++t) - { - auto p1 = recver.silentReceiveInplace(n, prng1, sock[1]); - char c; - coproto::sync_wait(sock[1].send(std::move(c))); - coproto::sync_wait(sock[1].recv(c)); - - rTimer.setTimePoint("__"); - coproto::sync_wait(sock[1].send(std::move(c))); - coproto::sync_wait(sock[1].recv(c)); - - rTimer.setTimePoint("r start"); - coproto::sync_wait(p1); - rTimer.setTimePoint("r done"); - - } - - - t0.join(); - std::cout << sTimer << std::endl; - std::cout << rTimer << std::endl; - - std::cout << sock[0].bytesReceived() / trials << " " << sock[1].bytesReceived() / trials << " bytes per " << std::endl; - } - catch (std::exception& e) - { - std::cout << e.what() << std::endl; - } -#endif - } - - inline void VoleBench2(const CLP& cmd) diff --git a/frontend/main.cpp b/frontend/main.cpp index 2ae01973..b3c64c07 100644 --- a/frontend/main.cpp +++ b/frontend/main.cpp @@ -119,8 +119,6 @@ int main(int argc, char** argv) QCCodeBench(cmd); else if (cmd.isSet("silent")) SilentOtBench(cmd); - else if (cmd.isSet("vole")) - VoleBench(cmd); else if (cmd.isSet("vole2")) VoleBench2(cmd); else if (cmd.isSet("ea")) diff --git a/libOTe/CMakeLists.txt b/libOTe/CMakeLists.txt index 0e690f9f..6e1b8c27 100644 --- a/libOTe/CMakeLists.txt +++ b/libOTe/CMakeLists.txt @@ -4,7 +4,7 @@ file(GLOB_RECURSE SRCS *.cpp *.c) set(SRCS "${SRCS}") -add_library(libOTe STATIC ${SRCS} "Tools/EACode/EACodeInstantiations.cpp") +add_library(libOTe STATIC ${SRCS} ) # make projects that include libOTe use this as an include folder diff --git a/libOTe/Tools/Subfield/Subfield.h b/libOTe/Tools/CoeffCtx.h similarity index 81% rename from libOTe/Tools/Subfield/Subfield.h rename to libOTe/Tools/CoeffCtx.h index 77dde299..5de952ef 100644 --- a/libOTe/Tools/Subfield/Subfield.h +++ b/libOTe/Tools/CoeffCtx.h @@ -12,27 +12,27 @@ namespace osuCrypto { struct CoeffCtxInteger { template - static OC_FORCEINLINE void plus(R&& ret, F1&& lhs, F2&& rhs) { + OC_FORCEINLINE void plus(R&& ret, F1&& lhs, F2&& rhs) { ret = lhs + rhs; } template - static OC_FORCEINLINE void minus(R&& ret, F1&& lhs, F2&& rhs) { + OC_FORCEINLINE void minus(R&& ret, F1&& lhs, F2&& rhs) { ret = lhs - rhs; } template - static OC_FORCEINLINE void mul(R&& ret, F1&& lhs, F2&& rhs) { + OC_FORCEINLINE void mul(R&& ret, F1&& lhs, F2&& rhs) { ret = lhs * rhs; } template - static OC_FORCEINLINE bool eq(F&& lhs, F&& rhs) { + OC_FORCEINLINE bool eq(F&& lhs, F&& rhs) { return lhs == rhs; } // is F a field? template - static OC_FORCEINLINE bool isField() { + OC_FORCEINLINE bool isField() { return false; // default. } @@ -45,7 +45,7 @@ namespace osuCrypto { // the protocol will perform binary decomposition // of F using this many bits template - static u64 bitSize() + u64 bitSize() { return sizeof(F) * 8; } @@ -56,15 +56,15 @@ namespace osuCrypto { // x = sum_{i = 0,...,n} 2^i * binaryDecomposition(x)[i] // template - static OC_FORCEINLINE BitVector binaryDecomposition(F& x) { - static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + OC_FORCEINLINE BitVector binaryDecomposition(F& x) { + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); return { (u8*)&x, sizeof(F) * 8 }; } // sample an F using the randomness b. template - static OC_FORCEINLINE void fromBlock(F& ret, const block& b) { - static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + OC_FORCEINLINE void fromBlock(F& ret, const block& b) { + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); if constexpr (sizeof(F) <= sizeof(block)) { @@ -84,8 +84,8 @@ namespace osuCrypto { // return the F element with value 2^power template - static OC_FORCEINLINE void powerOfTwo(F& ret, u64 power) { - static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + OC_FORCEINLINE void powerOfTwo(F& ret, u64 power) { + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); memset(&ret, 0, sizeof(F)); *BitIterator((u8*)&ret, power) = 1; } @@ -102,22 +102,22 @@ namespace osuCrypto { using Vec = AlignedUnVector; // resize Vec - template - static void resize(Vec& f, u64 size) + template + void resize(VecF& f, u64 size) { f.resize(size); } // the size of F when serialized. template - static u64 byteSize() + u64 byteSize() { return sizeof(F); } // copy a single F element. template - static OC_FORCEINLINE void copy(F& dst, const F& src) + OC_FORCEINLINE void copy(F& dst, const F& src) { dst = src; } @@ -125,7 +125,7 @@ namespace osuCrypto { // copy [begin,...,end) into [dstBegin, ...) // the iterators will point to the same types, i.e. F template - static OC_FORCEINLINE void copy( + OC_FORCEINLINE void copy( SrcIter begin, SrcIter end, DstIter dstBegin) @@ -142,7 +142,7 @@ namespace osuCrypto { // begin will be a byte pointer/iterator. // dstBegin will be an F pointer/iterator template - static void deserialize(SrcIter&& begin, SrcIter&& end, DstIter&& dstBegin) + void deserialize(SrcIter&& begin, SrcIter&& end, DstIter&& dstBegin) { // as written this function is a bit more general than strictly neccessary // due to serialize(...) redirecting here. @@ -153,7 +153,7 @@ namespace osuCrypto { #if __cplusplus >= 202002L //std::contiguous_iterator<> - // static_assert contigous iter in cpp20 + // _assert contigous iter in cpp20 #endif @@ -208,7 +208,7 @@ namespace osuCrypto { // begin will be an F pointer/iterator // dstBegin will be a byte pointer/iterator. template - static void serialize(SrcIter&& begin, SrcIter&& end, DstIter&& dstBegin) + void serialize(SrcIter&& begin, SrcIter&& end, DstIter&& dstBegin) { // for primitive types serialization and deserializaion // are the same, a memcpy. @@ -216,7 +216,7 @@ namespace osuCrypto { } template - static F* __restrict restrictPtr(Iter iter) + F* __restrict restrictPtr(Iter iter) { return &*iter; } @@ -225,7 +225,7 @@ namespace osuCrypto { // fill the range [begin,..., end) with zeros. // begin will be an F pointer/iterator. template - static void zero(Iter begin, Iter end) + void zero(Iter begin, Iter end) { using F = std::remove_reference_t; static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); @@ -241,7 +241,7 @@ namespace osuCrypto { // fill the range [begin,..., end) with ones. // begin will be an F pointer/iterator. template - static void one(Iter begin, Iter end) + void one(Iter begin, Iter end) { using F = std::remove_reference_t; static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); @@ -263,7 +263,7 @@ namespace osuCrypto { // convert F into a string template - static std::string str(F&& f) + std::string str(F&& f) { std::stringstream ss; if constexpr (std::is_same_v, u8>) @@ -282,17 +282,17 @@ namespace osuCrypto { { template - static OC_FORCEINLINE void plus(F& ret, const F& lhs, const F& rhs) { + OC_FORCEINLINE void plus(F& ret, const F& lhs, const F& rhs) { ret = lhs ^ rhs; } template - static OC_FORCEINLINE void minus(F& ret, const F& lhs, const F& rhs) { + OC_FORCEINLINE void minus(F& ret, const F& lhs, const F& rhs) { ret = lhs ^ rhs; } // is F a field? template - static OC_FORCEINLINE bool isField() { + OC_FORCEINLINE bool isField() { return true; // default. } @@ -301,7 +301,7 @@ namespace osuCrypto { // block does not use operator* struct CoeffCtxGFBlock : CoeffCtxGF { - static OC_FORCEINLINE void mul(block& ret, const block& lhs, const block& rhs) { + OC_FORCEINLINE void mul(block& ret, const block& lhs, const block& rhs) { ret = lhs.gf128Mul(rhs); } }; @@ -312,35 +312,35 @@ namespace osuCrypto { { using F = std::array; - static OC_FORCEINLINE void plus(F& ret, const F& lhs, const F& rhs) { + OC_FORCEINLINE void plus(F& ret, const F& lhs, const F& rhs) { for (u64 i = 0; i < lhs.size(); ++i) { ret[i] = lhs[i] + rhs[i]; } } - static OC_FORCEINLINE void plus(G& ret, const G& lhs, const G& rhs) { + OC_FORCEINLINE void plus(G& ret, const G& lhs, const G& rhs) { ret = lhs + rhs; } - static OC_FORCEINLINE void minus(F& ret, const F& lhs, const F& rhs) + OC_FORCEINLINE void minus(F& ret, const F& lhs, const F& rhs) { for (u64 i = 0; i < lhs.size(); ++i) { ret[i] = lhs[i] - rhs[i]; } } - static OC_FORCEINLINE void minus(G& ret, const G& lhs, const G& rhs) { + OC_FORCEINLINE void minus(G& ret, const G& lhs, const G& rhs) { ret = lhs - rhs; } - static OC_FORCEINLINE void mul(F& ret, const F& lhs, const G& rhs) + OC_FORCEINLINE void mul(F& ret, const F& lhs, const G& rhs) { for (u64 i = 0; i < lhs.size(); ++i) { ret[i] = lhs[i] * rhs; } } - static OC_FORCEINLINE bool eq(const F& lhs, const F& rhs) + OC_FORCEINLINE bool eq(const F& lhs, const F& rhs) { for (u64 i = 0; i < lhs.size(); ++i) { if (lhs[i] != rhs[i]) @@ -349,13 +349,13 @@ namespace osuCrypto { return true; } - static OC_FORCEINLINE bool eq(const G& lhs, const G& rhs) + OC_FORCEINLINE bool eq(const G& lhs, const G& rhs) { return lhs == rhs; } // convert F into a string - static std::string str(const F& f) + std::string str(const F& f) { auto delim = "{ "; std::stringstream ss; @@ -375,7 +375,7 @@ namespace osuCrypto { } // convert G into a string - static std::string str(const G& g) + std::string str(const G& g) { std::stringstream ss; if constexpr (std::is_same_v, u8>) diff --git a/libOTe/Tools/EACode/EACode.cpp b/libOTe/Tools/EACode/EACode.cpp deleted file mode 100644 index 3ae76cd9..00000000 --- a/libOTe/Tools/EACode/EACode.cpp +++ /dev/null @@ -1,142 +0,0 @@ - -#include "EACode.h" - -namespace osuCrypto -{ - - // Compute w = G * e. - template - void EACode::dualEncode(span e, span w) - { - if (mCodeSize == 0) - throw RTE_LOC; - if (e.size() != mCodeSize) - throw RTE_LOC; - if (w.size() != mMessageSize) - throw RTE_LOC; - - - setTimePoint("EACode.encode.begin"); - - accumulate(e); - - setTimePoint("EACode.encode.accumulate"); - - expand(e, w); - setTimePoint("EACode.encode.expand"); - } - - - - - template - void EACode::dualEncode2( - span e0, - span w0, - span e1, - span w1 - ) - { - if (mCodeSize == 0) - throw RTE_LOC; - if (e0.size() != mCodeSize) - throw RTE_LOC; - if (e1.size() != mCodeSize) - throw RTE_LOC; - if (w0.size() != mMessageSize) - throw RTE_LOC; - if (w1.size() != mMessageSize) - throw RTE_LOC; - - setTimePoint("EACode.encode.begin"); - - accumulate(e0, e1); - //accumulate(e1); - - setTimePoint("EACode.encode.accumulate"); - - expand(e0, w0); - expand(e1, w1); - setTimePoint("EACode.encode.expand"); - } - - - - template - void EACode::accumulate(span x) - { - if (x.size() != mCodeSize) - throw RTE_LOC; - auto main = (u64)std::max(0, mCodeSize - 1); - T* __restrict xx = x.data(); - - for (u64 i = 0; i < main; ++i) - { - auto xj = xx[i + 1] ^ xx[i]; - xx[i + 1] = xj; - } - } - - template - void EACode::accumulate(span x, span x2) - { - if (x.size() != mCodeSize) - throw RTE_LOC; - if (x2.size() != mCodeSize) - throw RTE_LOC; - - - auto main = (u64)std::max(0, mCodeSize - 1); - T* __restrict xx1 = x.data(); - T2* __restrict xx2 = x2.data(); - - for (u64 i = 0; i < main; ++i) - { - auto x1j = xx1[i + 1] ^ xx1[i]; - auto x2j = xx2[i + 1] ^ xx2[i]; - xx1[i + 1] = x1j; - xx2[i + 1] = x2j; - } - } - -#ifndef EACODE_INSTANTIONATIONS - - SparseMtx EACode::getAPar() const - { - PointList AP(mCodeSize, mCodeSize);; - for (u64 i = 0; i < mCodeSize; ++i) - { - AP.push_back(i, i); - if (i + 1 < mCodeSize) - AP.push_back(i + 1, i); - } - return AP; - } - - SparseMtx EACode::getA() const - { - auto APar = getAPar(); - - auto A = DenseMtx::Identity(mCodeSize); - - for (u64 i = 0; i < mCodeSize; ++i) - { - for (auto y : APar.col(i)) - { - //std::cout << y << " "; - if (y != i) - { - auto ay = A.row(y); - auto ai = A.row(i); - ay ^= ai; - - } - } - - //std::cout << "\n" << A << std::endl; - } - - return A.sparse(); - } -#endif -} \ No newline at end of file diff --git a/libOTe/Tools/EACode/EACode.h b/libOTe/Tools/EACode/EACode.h index bbe76ac4..77022721 100644 --- a/libOTe/Tools/EACode/EACode.h +++ b/libOTe/Tools/EACode/EACode.h @@ -8,16 +8,11 @@ #include "cryptoTools/Common/Defines.h" #include "cryptoTools/Common/Timer.h" -#include "Expander.h" +#include "libOTe/Tools/ExConvCode/Expander.h" #include "Util.h" - namespace osuCrypto { -#if __cplusplus >= 201703L -#define EA_CONSTEXPR constexpr -#else -#define EA_CONSTEXPR -#endif + // The encoder for the generator matrix G = B * A. // B is the expander while A is the accumulator. // @@ -30,27 +25,25 @@ namespace osuCrypto { public: using ExpanderCode::config; - using ExpanderCode::getB; - // Compute w = G * e. - template - void dualEncode(span e, span w); + template + void dualEncode(span e, span w, Ctx ctx); - template + template void dualEncode2( span e0, span w0, span e1, - span w1 + span w1, Ctx ctx ); - template - void accumulate(span x); - template - void accumulate(span x, span x2); + template + void accumulate(span x, Ctx ctx); + template + void accumulate(span x, span x2, Ctx ctx); // Get the parity check version of the accumulator SparseMtx getAPar() const; @@ -58,5 +51,133 @@ namespace osuCrypto SparseMtx getA() const; }; -#undef EA_CONSTEXPR + // Compute w = G * e. + template + void EACode::dualEncode(span e, span w, Ctx ctx) + { + if (mCodeSize == 0) + throw RTE_LOC; + if (e.size() != mCodeSize) + throw RTE_LOC; + if (w.size() != mMessageSize) + throw RTE_LOC; + + + setTimePoint("EACode.encode.begin"); + + accumulate(e, ctx); + + setTimePoint("EACode.encode.accumulate"); + + expand(e.begin(), w.begin(), ctx); + setTimePoint("EACode.encode.expand"); + } + + + + + template + void EACode::dualEncode2( + span e0, + span w0, + span e1, + span w1, Ctx ctx + ) + { + if (mCodeSize == 0) + throw RTE_LOC; + if (e0.size() != mCodeSize) + throw RTE_LOC; + if (e1.size() != mCodeSize) + throw RTE_LOC; + if (w0.size() != mMessageSize) + throw RTE_LOC; + if (w1.size() != mMessageSize) + throw RTE_LOC; + + setTimePoint("EACode.encode.begin"); + + accumulate(e0, e1, ctx); + //accumulate(e1); + + setTimePoint("EACode.encode.accumulate"); + + expand(e0.begin(), w0.begin(), ctx); + expand(e1.begin(), w1.begin(), ctx); + setTimePoint("EACode.encode.expand"); + } + + + + template + void EACode::accumulate(span x, Ctx ctx) + { + if (x.size() != mCodeSize) + throw RTE_LOC; + auto main = (u64)std::max(0, mCodeSize - 1); + T* __restrict xx = x.data(); + + for (u64 i = 0; i < main; ++i) + { + ctx.plus(xx[i + 1], xx[i + 1], xx[i]); + } + } + + template + void EACode::accumulate(span x, span x2, Ctx ctx) + { + if (x.size() != mCodeSize) + throw RTE_LOC; + if (x2.size() != mCodeSize) + throw RTE_LOC; + + auto main = (u64)std::max(0, mCodeSize - 1); + T* __restrict xx1 = x.data(); + T2* __restrict xx2 = x2.data(); + + for (u64 i = 0; i < main; ++i) + { + ctx.plus(xx1[i + 1], xx1[i + 1], xx1[i]); + ctx.plus(xx2[i + 1], xx2[i + 1], xx2[i]); + } + } + + + inline SparseMtx EACode::getAPar() const + { + PointList AP(mCodeSize, mCodeSize);; + for (u64 i = 0; i < mCodeSize; ++i) + { + AP.push_back(i, i); + if (i + 1 < mCodeSize) + AP.push_back(i + 1, i); + } + return AP; + } + + inline SparseMtx EACode::getA() const + { + auto APar = getAPar(); + + auto A = DenseMtx::Identity(mCodeSize); + + for (u64 i = 0; i < mCodeSize; ++i) + { + for (auto y : APar.col(i)) + { + //std::cout << y << " "; + if (y != i) + { + auto ay = A.row(y); + auto ai = A.row(i); + ay ^= ai; + + } + } + + //std::cout << "\n" << A << std::endl; + } + + return A.sparse(); + } } diff --git a/libOTe/Tools/EACode/EACodeInstantiations.cpp b/libOTe/Tools/EACode/EACodeInstantiations.cpp deleted file mode 100644 index 502398c3..00000000 --- a/libOTe/Tools/EACode/EACodeInstantiations.cpp +++ /dev/null @@ -1,15 +0,0 @@ - -#define EACODE_INSTANTIONATIONS -#include "EACode.cpp" - -namespace osuCrypto -{ - template void EACode::dualEncode(span e, span w); - template void EACode::dualEncode(span e, span w); - template void EACode::accumulate(span e); - template void EACode::accumulate(span e); - - template void EACode::dualEncode2(span e, span w, span e2, span w2); - template void EACode::dualEncode2(span e, span w, span e2, span w2); - template void EACode::accumulate(spane0, span e); -} \ No newline at end of file diff --git a/libOTe/Tools/EACode/Expander.cpp b/libOTe/Tools/EACode/Expander.cpp deleted file mode 100644 index 98bbb99e..00000000 --- a/libOTe/Tools/EACode/Expander.cpp +++ /dev/null @@ -1,433 +0,0 @@ - -#include "Expander.h" -#include "cryptoTools/Common/Range.h" - -namespace osuCrypto -{ - template - typename std::enable_if::type - ExpanderCode::expandOne(const T* __restrict ee, detail::ExpanderModd& prng) const - { - auto r = prng.get(); - return ee[r]; - } - - template - typename std::enable_if::type - ExpanderCode::expandOne( - const T* __restrict ee1, - const T2* __restrict ee2, - T* __restrict y1, - T2* __restrict y2, - detail::ExpanderModd& prng) const - { - auto r = prng.get(); - - if (Add) - { - *y1 = *y1 ^ ee1[r]; - *y2 = *y2 ^ ee2[r]; - } - else - { - - *y1 = ee1[r]; - *y2 = ee2[r]; - } - } - - - template - OC_FORCEINLINE typename std::enable_if<(count > 1), T>::type - ExpanderCode::expandOne(const T* __restrict ee, detail::ExpanderModd& prng)const - { - if constexpr (count >= 8) - { - u64 rr[8]; - T w[8]; - rr[0] = prng.get(); - rr[1] = prng.get(); - rr[2] = prng.get(); - rr[3] = prng.get(); - rr[4] = prng.get(); - rr[5] = prng.get(); - rr[6] = prng.get(); - rr[7] = prng.get(); - - w[0] = ee[rr[0]]; - w[1] = ee[rr[1]]; - w[2] = ee[rr[2]]; - w[3] = ee[rr[3]]; - w[4] = ee[rr[4]]; - w[5] = ee[rr[5]]; - w[6] = ee[rr[6]]; - w[7] = ee[rr[7]]; - - auto ww = - w[0] ^ - w[1] ^ - w[2] ^ - w[3] ^ - w[4] ^ - w[5] ^ - w[6] ^ - w[7]; - - if constexpr (count > 8) - ww = ww ^ expandOne(ee, prng); - return ww; - } - else - { - - auto r = prng.get(); - auto ww = expandOne(ee, prng); - return ww ^ ee[r]; - } - } - - - template - OC_FORCEINLINE typename std::enable_if<(count > 1)>::type - ExpanderCode::expandOne( - const T* __restrict ee1, - const T2* __restrict ee2, - T* __restrict y1, - T2* __restrict y2, - detail::ExpanderModd& prng)const - { - if constexpr (count >= 8) - { - u64 rr[8]; - T w1[8]; - T2 w2[8]; - rr[0] = prng.get(); - rr[1] = prng.get(); - rr[2] = prng.get(); - rr[3] = prng.get(); - rr[4] = prng.get(); - rr[5] = prng.get(); - rr[6] = prng.get(); - rr[7] = prng.get(); - - w1[0] = ee1[rr[0]]; - w1[1] = ee1[rr[1]]; - w1[2] = ee1[rr[2]]; - w1[3] = ee1[rr[3]]; - w1[4] = ee1[rr[4]]; - w1[5] = ee1[rr[5]]; - w1[6] = ee1[rr[6]]; - w1[7] = ee1[rr[7]]; - - w2[0] = ee2[rr[0]]; - w2[1] = ee2[rr[1]]; - w2[2] = ee2[rr[2]]; - w2[3] = ee2[rr[3]]; - w2[4] = ee2[rr[4]]; - w2[5] = ee2[rr[5]]; - w2[6] = ee2[rr[6]]; - w2[7] = ee2[rr[7]]; - - auto ww1 = - w1[0] ^ - w1[1] ^ - w1[2] ^ - w1[3] ^ - w1[4] ^ - w1[5] ^ - w1[6] ^ - w1[7]; - auto ww2 = - w2[0] ^ - w2[1] ^ - w2[2] ^ - w2[3] ^ - w2[4] ^ - w2[5] ^ - w2[6] ^ - w2[7]; - - if constexpr (count > 8) - { - T yy1; - T2 yy2; - expandOne(ee1, ee2, &yy1, &yy2, prng); - ww1 = ww1 ^ yy1; - ww2 = ww2 ^ yy2; - } - - if constexpr (Add) - { - *y1 = *y1 ^ ww1; - *y2 = *y2 ^ ww2; - } - else - { - *y1 = ww1; - *y2 = ww2; - } - - } - else - { - - auto r = prng.get(); - if constexpr (Add) - { - auto w1 = ee1[r]; - auto w2 = ee2[r]; - expandOne(ee1, ee2, y1, y2, prng); - *y1 = *y1 ^ w1; - *y2 = *y2 ^ w2; - - } - else - { - - T yy1; - T2 yy2; - expandOne(ee1, ee2, &yy1, &yy2, prng); - *y1 = ee1[r] ^ yy1; - *y2 = ee2[r] ^ yy2; - } - } - } - - - - template - void ExpanderCode::expand( - span e, - span w) const - { - assert(w.size() == mMessageSize); - assert(e.size() == mCodeSize); - detail::ExpanderModd prng(mSeed, mCodeSize); - - const T* __restrict ee = e.data(); - T* __restrict ww = w.data(); - - auto main = mMessageSize / 8 * 8; - u64 i = 0; - - for (; i < main; i += 8) - { -#define CASE(I) \ - case I:\ - if constexpr(Add)\ - {\ - ww[i + 0] = ww[i + 0] ^ expandOne(ee, prng);\ - ww[i + 1] = ww[i + 1] ^ expandOne(ee, prng);\ - ww[i + 2] = ww[i + 2] ^ expandOne(ee, prng);\ - ww[i + 3] = ww[i + 3] ^ expandOne(ee, prng);\ - ww[i + 4] = ww[i + 4] ^ expandOne(ee, prng);\ - ww[i + 5] = ww[i + 5] ^ expandOne(ee, prng);\ - ww[i + 6] = ww[i + 6] ^ expandOne(ee, prng);\ - ww[i + 7] = ww[i + 7] ^ expandOne(ee, prng);\ - }\ - else\ - {\ - ww[i + 0] = expandOne(ee, prng);\ - ww[i + 1] = expandOne(ee, prng);\ - ww[i + 2] = expandOne(ee, prng);\ - ww[i + 3] = expandOne(ee, prng);\ - ww[i + 4] = expandOne(ee, prng);\ - ww[i + 5] = expandOne(ee, prng);\ - ww[i + 6] = expandOne(ee, prng);\ - ww[i + 7] = expandOne(ee, prng);\ - }\ - break - - switch (mExpanderWeight) - { - CASE(5); - CASE(7); - CASE(9); - CASE(11); - CASE(21); - CASE(40); - default: - for (u64 jj = 0; jj < 8; ++jj) - { - auto r = prng.get(); - auto wv = ee[r]; - - for (auto j = 1ull; j < mExpanderWeight; ++j) - { - r = prng.get(); - wv = wv ^ ee[r]; - } - if constexpr (Add) - ww[i + jj] = ww[i + jj] ^ wv; - else - ww[i + jj] = wv; - - } - } -#undef CASE - } - - for (; i < mMessageSize; ++i) - { - auto wv = ee[prng.get()]; - for (auto j = 1ull; j < mExpanderWeight; ++j) - wv = wv ^ ee[prng.get()]; - - if constexpr (Add) - ww[i] = ww[i] ^ wv; - else - ww[i] = wv; - } - } - - - - template - void ExpanderCode::expand( - span e1, - span e2, - span w1, - span w2 - ) const - { - assert(w1.size() == mMessageSize); - assert(w2.size() == mMessageSize); - assert(e1.size() == mCodeSize); - assert(e2.size() == mCodeSize); - detail::ExpanderModd prng(mSeed, mCodeSize); - - const T* __restrict ee1 = e1.data(); - const T2* __restrict ee2 = e2.data(); - T* __restrict ww1 = w1.data(); - T2* __restrict ww2 = w2.data(); - - auto main = mMessageSize / 8 * 8; - u64 i = 0; - - for (; i < main; i += 8) - { -#define CASE(I) \ - case I:\ - expandOne(ee1, ee2, &ww1[i + 0], &ww2[i + 0], prng);\ - expandOne(ee1, ee2, &ww1[i + 1], &ww2[i + 1], prng);\ - expandOne(ee1, ee2, &ww1[i + 2], &ww2[i + 2], prng);\ - expandOne(ee1, ee2, &ww1[i + 3], &ww2[i + 3], prng);\ - expandOne(ee1, ee2, &ww1[i + 4], &ww2[i + 4], prng);\ - expandOne(ee1, ee2, &ww1[i + 5], &ww2[i + 5], prng);\ - expandOne(ee1, ee2, &ww1[i + 6], &ww2[i + 6], prng);\ - expandOne(ee1, ee2, &ww1[i + 7], &ww2[i + 7], prng);\ - break - - switch (mExpanderWeight) - { - CASE(5); - CASE(7); - CASE(9); - CASE(11); - CASE(21); - CASE(40); - default: - for (u64 jj = 0; jj < 8; ++jj) - { - auto r = prng.get(); - auto wv1 = ee1[r]; - auto wv2 = ee2[r]; - - for (auto j = 1ull; j < mExpanderWeight; ++j) - { - r = prng.get(); - wv1 = wv1 ^ ee1[r]; - wv2 = wv2 ^ ee2[r]; - } - if constexpr (Add) - { - ww1[i + jj] = ww1[i + jj] ^ wv1; - ww2[i + jj] = ww2[i + jj] ^ wv2; - } - else - { - - ww1[i + jj] = wv1; - ww2[i + jj] = wv2; - } - } - } -#undef CASE - } - - for (; i < mMessageSize; ++i) - { - auto r = prng.get(); - auto wv1 = ee1[r]; - auto wv2 = ee2[r]; - for (auto j = 1ull; j < mExpanderWeight; ++j) - { - r = prng.get(); - wv1 = wv1 ^ ee1[r]; - wv2 = wv2 ^ ee2[r]; - - } - if constexpr (Add) - { - ww1[i] = ww1[i] ^ wv1; - ww2[i] = ww2[i] ^ wv2; - } - else - { - ww1[i] = wv1; - ww2[i] = wv2; - } - } - } - -#ifndef EXPANDER_INSTANTATIONS - SparseMtx ExpanderCode::getB() const - { - //PRNG prng(mSeed); - detail::ExpanderModd prng(mSeed, mCodeSize); - PointList points(mMessageSize, mCodeSize); - - std::vector row(mExpanderWeight); - - { - - for (auto i : rng(mMessageSize)) - { - row[0] = prng.get(); - //points.push_back(i, row[0]); - for (auto j : rng(1, mExpanderWeight)) - { - //do { - row[j] = prng.get(); - //} while - auto iter = std::find(row.data(), row.data() + j, row[j]); - if (iter != row.data() + j) - { - row[j] = ~0ull; - *iter = ~0ull; - } - //throw RTE_LOC; - - } - for (auto j : rng(mExpanderWeight)) - { - - if (row[j] != ~0ull) - { - //std::cout << row[j] << " "; - points.push_back(i, row[j]); - } - else - { - //std::cout << "* "; - } - } - //std::cout << std::endl; - } - } - - return points; - } -#endif - -} diff --git a/libOTe/Tools/EACode/Expander.h b/libOTe/Tools/EACode/Expander.h deleted file mode 100644 index 5cd0e737..00000000 --- a/libOTe/Tools/EACode/Expander.h +++ /dev/null @@ -1,98 +0,0 @@ -// © 2023 Peter Rindal. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -#pragma once - -#include "cryptoTools/Common/Defines.h" -#include "libOTe/Tools/LDPC/Mtx.h" -#include "Util.h" - -namespace osuCrypto -{ - - // The encoder for the expander matrix B. - // B has mMessageSize rows and mCodeSize columns. It is sampled uniformly - // with fixed row weight mExpanderWeight. - class ExpanderCode - { - public: - - void config( - u64 messageSize, - u64 codeSize = 0 /* default is 5* messageSize */, - u64 expanderWeight = 21, - block seed = block(33333, 33333)) - { - mMessageSize = messageSize; - mCodeSize = codeSize; - mExpanderWeight = expanderWeight; - mSeed = seed; - - } - - // the seed that generates the code. - block mSeed = block(0, 0); - - // The message size of the code. K. - u64 mMessageSize = 0; - - // The codeword size of the code. n. - u64 mCodeSize = 0; - - // The row weight of the B matrix. - u64 mExpanderWeight = 0; - - u64 parityRows() const { return mCodeSize - mMessageSize; } - u64 parityCols() const { return mCodeSize; } - - u64 generatorRows() const { return mMessageSize; } - u64 generatorCols() const { return mCodeSize; } - - - - template - typename std::enable_if<(count > 1), T>::type - expandOne(const T* __restrict ee, detail::ExpanderModd& prng)const; - - template - typename std::enable_if<(count > 1)>::type - expandOne( - const T* __restrict ee1, - const T2* __restrict ee2, - T* __restrict y1, - T2* __restrict y2, - detail::ExpanderModd& prng)const; - - template - typename std::enable_if::type - expandOne(const T* __restrict ee, detail::ExpanderModd& prng) const; - - template - typename std::enable_if::type - expandOne( - const T* __restrict ee1, - const T2* __restrict ee2, - T* __restrict y1, - T2* __restrict y2, - detail::ExpanderModd& prng) const; - - template - void expand( - span e, - span w) const; - - template - void expand( - span e1, - span e2, - span w1, - span w2 - ) const; - - SparseMtx getB() const; - - }; -} diff --git a/libOTe/Tools/EACode/ExpanderInstantiations.cpp b/libOTe/Tools/EACode/ExpanderInstantiations.cpp deleted file mode 100644 index 130e095d..00000000 --- a/libOTe/Tools/EACode/ExpanderInstantiations.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#define EXPANDER_INSTANTATIONS -#include "Expander.cpp" - -namespace osuCrypto -{ - - template void ExpanderCode::expand(span e, span w) const; - template void ExpanderCode::expand(span e, span w) const; - template void ExpanderCode::expand(span e, span w) const; - template void ExpanderCode::expand(span e, span w) const; - - template void ExpanderCode::expand( - span e, span e2, - span w, span w2) const; - template void ExpanderCode::expand( - span e, span e2, - span w, span w2) const; - - - template void ExpanderCode::expand( - span e, span e2, - span w, span w2) const; - //template void ExpanderCode::expand( - // span e, span e2, - // span w, span w2) const; -} \ No newline at end of file diff --git a/libOTe/Tools/ExConvCode/ExConvCode.cpp b/libOTe/Tools/ExConvCode/ExConvCode.cpp index c28cefed..7f843ecb 100644 --- a/libOTe/Tools/ExConvCode/ExConvCode.cpp +++ b/libOTe/Tools/ExConvCode/ExConvCode.cpp @@ -1,557 +1,85 @@ #include "ExConvCode.h" - namespace osuCrypto { -#ifdef ENABLE_SSE - - using My__m128 = __m128; - -#else - using My__m128 = block; - - inline My__m128 _mm_load_ps(float* b) { return *(block*)b; } - - // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_blendv_ps&ig_expand=557 - inline My__m128 _mm_blendv_ps(My__m128 a, My__m128 b, My__m128 mask) - { - My__m128 dst; - for (u64 j = 0; j < 4; ++j) - { - if (mask.get(j) < 0) - dst.set(j, b.get(j)); - else - dst.set(j, a.get(j)); - } - return dst; - } - - - inline My__m128 _mm_setzero_ps() { return ZeroBlock; } -#endif - - // Compute e = G * e. - template - void ExConvCode::dualEncode(span e) - { - if (e.size() != mCodeSize) - throw RTE_LOC; - - if (mSystematic) - { - auto d = e.subspan(mMessageSize); - setTimePoint("ExConv.encode.begin"); - accumulate(d); - setTimePoint("ExConv.encode.accumulate"); - mExpander.expand(d, e.subspan(0, mMessageSize)); - setTimePoint("ExConv.encode.expand"); - } - else - { - oc::AlignedUnVector w(mMessageSize); - dualEncode(e, w); - memcpy(e.data(), w.data(), w.size() * sizeof(T)); - setTimePoint("ExConv.encode.memcpy"); - - } - } - - - // Compute e = G * e. - template - void ExConvCode::dualEncode2(span e0, span e1) - { - if (e0.size() != mCodeSize) - throw RTE_LOC; - if (e1.size() != mCodeSize) - throw RTE_LOC; - - if (mSystematic) - { - auto d0 = e0.subspan(mMessageSize); - auto d1 = e1.subspan(mMessageSize); - setTimePoint("ExConv.encode.begin"); - accumulate(d0, d1); - setTimePoint("ExConv.encode.accumulate"); - mExpander.expand( - d0, d1, - e0.subspan(0, mMessageSize), - e1.subspan(0, mMessageSize)); - setTimePoint("ExConv.encode.expand"); - } - else - { - //oc::AlignedUnVector w0(mMessageSize); - //dualEncode(e, w); - //memcpy(e.data(), w.data(), w.size() * sizeof(T)); - //setTimePoint("ExConv.encode.memcpy"); - - // not impl. - throw RTE_LOC; - - } - } - - // Compute w = G * e. - template - void ExConvCode::dualEncode(span e, span w) - { - if (e.size() != mCodeSize) - throw RTE_LOC; - - if (w.size() != mMessageSize) - throw RTE_LOC; - - if (mSystematic) - { - dualEncode(e); - memcpy(w.data(), e.data(), w.size() * sizeof(T)); - setTimePoint("ExConv.encode.memcpy"); - } - else - { - - setTimePoint("ExConv.encode.begin"); - - accumulate(e); - - setTimePoint("ExConv.encode.accumulate"); - - mExpander.expand(e, w); - setTimePoint("ExConv.encode.expand"); - } - } - - inline void refill(PRNG& prng) - { - assert(prng.mBuffer.size() == 256); - //block b[8]; - for (u64 i = 0; i < 256; i += 8) - { - //auto idx = mPrng.mBuffer[i].get(); - block* __restrict b = prng.mBuffer.data() + i; - block* __restrict k = prng.mBuffer.data() + (u8)(i - 8); - //for (u64 j = 0; j < 8; ++j) - //{ - // b = b ^ mPrng.mBuffer.data()[idx[j]]; - //} - b[0] = AES::roundEnc(b[0], k[0]); - b[1] = AES::roundEnc(b[1], k[1]); - b[2] = AES::roundEnc(b[2], k[2]); - b[3] = AES::roundEnc(b[3], k[3]); - b[4] = AES::roundEnc(b[4], k[4]); - b[5] = AES::roundEnc(b[5], k[5]); - b[6] = AES::roundEnc(b[6], k[6]); - b[7] = AES::roundEnc(b[7], k[7]); - - b[0] = b[0] ^ k[0]; - b[1] = b[1] ^ k[1]; - b[2] = b[2] ^ k[2]; - b[3] = b[3] ^ k[3]; - b[4] = b[4] ^ k[4]; - b[5] = b[5] ^ k[5]; - b[6] = b[6] ^ k[6]; - b[7] = b[7] ^ k[7]; - } - } - -#ifndef EXCONVCODE_INSTANTIATIONS - - void ExConvCode::accOne( - PointList& pl, - u64 i, - u8* __restrict& ptr, - PRNG& prng, - block& rnd, - u64& q, - u64 qe, - u64 size) const - { - u64 j = i + 1; - pl.push_back(i, i); - - if (q + mAccumulatorSize > qe) - { - refill(prng); - ptr = (u8*)prng.mBuffer.data(); - q = 0; - } - - - for (u64 k = 0; k < mAccumulatorSize; k += 8, q += 8, j += 8) - { - assert(ptr < (u8*)(prng.mBuffer.data() + prng.mBuffer.size())); - rnd = block::allSame(*ptr); - ++ptr; - - //std::cout << "r " << rnd << std::endl; - auto b0 = rnd; - auto b1 = rnd.slli_epi32<1>(); - auto b2 = rnd.slli_epi32<2>(); - auto b3 = rnd.slli_epi32<3>(); - auto b4 = rnd.slli_epi32<4>(); - auto b5 = rnd.slli_epi32<5>(); - auto b6 = rnd.slli_epi32<6>(); - auto b7 = rnd.slli_epi32<7>(); - //rnd = rnd.mm_slli_epi32<8>(); - - if (j + 0 < size && b0.get(0) < 0) pl.push_back(j + 0, i); - if (j + 1 < size && b1.get(0) < 0) pl.push_back(j + 1, i); - if (j + 2 < size && b2.get(0) < 0) pl.push_back(j + 2, i); - if (j + 3 < size && b3.get(0) < 0) pl.push_back(j + 3, i); - if (j + 4 < size && b4.get(0) < 0) pl.push_back(j + 4, i); - if (j + 5 < size && b5.get(0) < 0) pl.push_back(j + 5, i); - if (j + 6 < size && b6.get(0) < 0) pl.push_back(j + 6, i); - if (j + 7 < size && b7.get(0) < 0) pl.push_back(j + 7, i); - } - - - //if (mWrapping) - { - if (j < size) - pl.push_back(j, i); - ++j; - } - - } -#endif - - - template - OC_FORCEINLINE void accOneHelper( - T* __restrict xx, - My__m128 xii, - u64 j, u64 i, u64 size, - block* b - ) - { - My__m128 Zero = _mm_setzero_ps(); - - if constexpr (std::is_same::value) - { - My__m128 bb[8]; - bb[0] = _mm_load_ps((float*)&b[0]); - bb[1] = _mm_load_ps((float*)&b[1]); - bb[2] = _mm_load_ps((float*)&b[2]); - bb[3] = _mm_load_ps((float*)&b[3]); - bb[4] = _mm_load_ps((float*)&b[4]); - bb[5] = _mm_load_ps((float*)&b[5]); - bb[6] = _mm_load_ps((float*)&b[6]); - bb[7] = _mm_load_ps((float*)&b[7]); - - - bb[0] = _mm_blendv_ps(Zero, xii, bb[0]); - bb[1] = _mm_blendv_ps(Zero, xii, bb[1]); - bb[2] = _mm_blendv_ps(Zero, xii, bb[2]); - bb[3] = _mm_blendv_ps(Zero, xii, bb[3]); - bb[4] = _mm_blendv_ps(Zero, xii, bb[4]); - bb[5] = _mm_blendv_ps(Zero, xii, bb[5]); - bb[6] = _mm_blendv_ps(Zero, xii, bb[6]); - bb[7] = _mm_blendv_ps(Zero, xii, bb[7]); - block tt[8]; - memcpy(tt, bb, 8 * 16); - if (!rangeCheck || j + 0 < size) xx[j + 0] = xx[j + 0] ^ tt[0]; - if (!rangeCheck || j + 1 < size) xx[j + 1] = xx[j + 1] ^ tt[1]; - if (!rangeCheck || j + 2 < size) xx[j + 2] = xx[j + 2] ^ tt[2]; - if (!rangeCheck || j + 3 < size) xx[j + 3] = xx[j + 3] ^ tt[3]; - if (!rangeCheck || j + 4 < size) xx[j + 4] = xx[j + 4] ^ tt[4]; - if (!rangeCheck || j + 5 < size) xx[j + 5] = xx[j + 5] ^ tt[5]; - if (!rangeCheck || j + 6 < size) xx[j + 6] = xx[j + 6] ^ tt[6]; - if (!rangeCheck || j + 7 < size) xx[j + 7] = xx[j + 7] ^ tt[7]; - } - else - { - auto bb0 = xx[i] * (b[0].get(0) < 0); - auto bb1 = xx[i] * (b[1].get(0) < 0); - auto bb2 = xx[i] * (b[2].get(0) < 0); - auto bb3 = xx[i] * (b[3].get(0) < 0); - auto bb4 = xx[i] * (b[4].get(0) < 0); - auto bb5 = xx[i] * (b[5].get(0) < 0); - auto bb6 = xx[i] * (b[6].get(0) < 0); - auto bb7 = xx[i] * (b[7].get(0) < 0); + // configure the code. The default parameters are choses to balance security and performance. + // For additional parameter choices see the paper. - if (!rangeCheck || j + 0 < size) xx[j + 0] = xx[j + 0] ^ bb0; - if (!rangeCheck || j + 1 < size) xx[j + 1] = xx[j + 1] ^ bb1; - if (!rangeCheck || j + 2 < size) xx[j + 2] = xx[j + 2] ^ bb2; - if (!rangeCheck || j + 3 < size) xx[j + 3] = xx[j + 3] ^ bb3; - if (!rangeCheck || j + 4 < size) xx[j + 4] = xx[j + 4] ^ bb4; - if (!rangeCheck || j + 5 < size) xx[j + 5] = xx[j + 5] ^ bb5; - if (!rangeCheck || j + 6 < size) xx[j + 6] = xx[j + 6] ^ bb6; - if (!rangeCheck || j + 7 < size) xx[j + 7] = xx[j + 7] ^ bb7; - } - } + //// get the expander matrix + //SparseMtx ExConvCode::getB() const + //{ + // throw RTE_LOC; + // //if (mSystematic) + // //{ + // // PointList R(mMessageSize, mCodeSize); + // // auto B = mExpander.getB().points(); - template - OC_FORCEINLINE void ExConvCode::accOne( - T* __restrict xx, - u64 i, - u8*& ptr, - PRNG& prng, - u64& q, - u64 qe, - u64 size) - { - u64 j = i + 1; - if (width) - { - My__m128 xii; - if constexpr (std::is_same::value) - xii = _mm_load_ps((float*)(xx + i)); - else - xii = _mm_setzero_ps(); + // // for (auto p : B) + // // { + // // R.push_back(p.mRow, mMessageSize + p.mCol); + // // } + // // for (u64 i = 0; i < mMessageSize; ++i) + // // R.push_back(i, i); - if (q + width > qe) - { - refill(prng); - ptr = (u8*)prng.mBuffer.data(); - q = 0; + // // return R; + // //} + // //else + // //{ + // // return mExpander.getB(); + // //} + //} - } - q += width; - - for (u64 k = 0; k < width; ++k, j += 8) - { - assert(ptr < (u8*)(prng.mBuffer.data() + prng.mBuffer.size())); - block rnd = block::allSame(*(u8*)ptr++); - - - block b[8]; - b[0] = rnd; - b[1] = rnd.slli_epi32<1>(); - b[2] = rnd.slli_epi32<2>(); - b[3] = rnd.slli_epi32<3>(); - b[4] = rnd.slli_epi32<4>(); - b[5] = rnd.slli_epi32<5>(); - b[6] = rnd.slli_epi32<6>(); - b[7] = rnd.slli_epi32<7>(); - - accOneHelper(xx, xii, j, i, size, b); - } - } - - if (!rangeCheck || j < size) - { - auto xj = xx[j] ^ xx[i]; - xx[j] = xj; - } - } - - template - OC_FORCEINLINE void ExConvCode::accOne( - T0* __restrict xx0, - T1* __restrict xx1, - u64 i, - u8*& ptr, - PRNG& prng, - u64& q, - u64 qe, - u64 size) - { - u64 j = i + 1; - if (width) - { - My__m128 xii0, xii1; - if constexpr (std::is_same::value) - xii0 = _mm_load_ps((float*)(xx0 + i)); - else - xii0 = _mm_setzero_ps(); - if constexpr (std::is_same::value) - xii1 = _mm_load_ps((float*)(xx1 + i)); - else - xii1 = _mm_setzero_ps(); - - if (q + width > qe) - { - refill(prng); - ptr = (u8*)prng.mBuffer.data(); - q = 0; - - } - q += width; - - for (u64 k = 0; k < width; ++k, j += 8) - { - assert(ptr < (u8*)(prng.mBuffer.data() + prng.mBuffer.size())); - block rnd = block::allSame(*(u8*)ptr++); - - block b[8]; - b[0] = rnd; - b[1] = rnd.slli_epi32<1>(); - b[2] = rnd.slli_epi32<2>(); - b[3] = rnd.slli_epi32<3>(); - b[4] = rnd.slli_epi32<4>(); - b[5] = rnd.slli_epi32<5>(); - b[6] = rnd.slli_epi32<6>(); - b[7] = rnd.slli_epi32<7>(); - - accOneHelper(xx0, xii0, j, i, size, b); - accOneHelper(xx1, xii1, j, i, size, b); - } - } - - if (!rangeCheck || j < size) - { - auto xj0 = xx0[j] ^ xx0[i]; - auto xj1 = xx1[j] ^ xx1[i]; - xx0[j] = xj0; - xx1[j] = xj1; - } - } - - - - template - void ExConvCode::accumulate(span x) - { - PRNG prng(mSeed ^ OneBlock); - - u64 i = 0; - auto size = x.size(); - auto main = (u64)std::max(0, size - 1 - mAccumulatorSize); - u8* ptr = (u8*)prng.mBuffer.data(); - auto qe = prng.mBuffer.size() * 128 / 8; - u64 q = 0; - T* __restrict xx = x.data(); - - { - -#define CASE(I) case I:\ - for (; i < main; ++i)\ - accOne(xx, i, ptr, prng, q, qe, size);\ - for (; i < size; ++i)\ - accOne(xx, i, ptr, prng, q, qe, size);\ - break - - switch (mAccumulatorSize / 8) - { - CASE(0); - CASE(1); - CASE(2); - CASE(3); - CASE(4); - default: - throw RTE_LOC; - break; - } -#undef CASE - } - } - - - template - void ExConvCode::accumulate(span x0, span x1) - { - PRNG prng(mSeed ^ OneBlock); - - u64 i = 0; - auto size = x0.size(); - auto main = (u64)std::max(0, size - 1 - mAccumulatorSize); - u8* ptr = (u8*)prng.mBuffer.data(); - auto qe = prng.mBuffer.size() * 128 / 8; - u64 q = 0; - T0* __restrict xx0 = x0.data(); - T1* __restrict xx1 = x1.data(); - - { - -#define CASE(I) case I:\ - for (; i < main; ++i)\ - accOne(xx0,xx1, i, ptr, prng, q, qe, size);\ - for (; i < size; ++i)\ - accOne(xx0, xx1, i, ptr, prng, q, qe, size);\ - break - - switch (mAccumulatorSize / 8) - { - CASE(0); - CASE(1); - CASE(2); - CASE(3); - CASE(4); - default: - throw RTE_LOC; - break; - } -#undef CASE - } - } - - -#ifndef EXCONVCODE_INSTANTIATIONS - - SparseMtx ExConvCode::getB() const - { - if (mSystematic) - { - PointList R(mMessageSize, mCodeSize); - auto B = mExpander.getB().points(); - - for (auto p : B) - { - R.push_back(p.mRow, mMessageSize + p.mCol); - } - for (u64 i = 0; i < mMessageSize; ++i) - R.push_back(i, i); - - return R; - } - else - { - return mExpander.getB(); - } - - } // Get the parity check version of the accumulator - SparseMtx ExConvCode::getAPar() const - { - PRNG prng(mSeed ^ OneBlock); - - auto n = mCodeSize - mSystematic * mMessageSize; - - PointList AP(n, n);; - DenseMtx A = DenseMtx::Identity(n); - - block rnd; - u8* __restrict ptr = (u8*)prng.mBuffer.data(); - auto qe = prng.mBuffer.size() * 128; - u64 q = 0; - - for (u64 i = 0; i < n; ++i) - { - accOne(AP, i, ptr, prng, rnd, q, qe, n); - } - return AP; - } - - SparseMtx ExConvCode::getA() const - { - auto APar = getAPar(); - - auto A = DenseMtx::Identity(mCodeSize); - - u64 offset = mSystematic ? mMessageSize : 0ull; - - for (u64 i = 0; i < APar.rows(); ++i) - { - for (auto y : APar.col(i)) - { - //std::cout << y << " "; - if (y != i) - { - auto ay = A.row(y + offset); - auto ai = A.row(i + offset); - ay ^= ai; - } - } + //SparseMtx ExConvCode::getAPar() const + //{ + // throw RTE_LOC; + // //PRNG prng(mSeed ^ OneBlock); + + // //auto n = mCodeSize - mSystematic * mMessageSize; + + // //PointList AP(n, n);; + // //DenseMtx A = DenseMtx::Identity(n); + + // //block rnd; + // //u8* __restrict ptr = (u8*)prng.mBuffer.data(); + // //auto qe = prng.mBuffer.size() * 128; + // //u64 q = 0; + + // //for (u64 i = 0; i < n; ++i) + // //{ + // // accOne(AP, i, ptr, prng, rnd, q, qe, n); + // //} + // //return AP; + //} + + //// get the accumulator matrix + //SparseMtx ExConvCode::getA() const + //{ + // auto APar = getAPar(); + + // auto A = DenseMtx::Identity(mCodeSize); + + // u64 offset = mSystematic ? mMessageSize : 0ull; + + // for (u64 i = 0; i < APar.rows(); ++i) + // { + // for (auto y : APar.col(i)) + // { + // if (y != i) + // { + // auto ay = A.row(y + offset); + // auto ai = A.row(i + offset); + // ay ^= ai; + // } + // } + // } + + // return A.sparse(); + //} - //std::cout << "\n" << A << std::endl; - } - return A.sparse(); - } -#endif } \ No newline at end of file diff --git a/libOTe/Tools/ExConvCode/ExConvCode.h b/libOTe/Tools/ExConvCode/ExConvCode.h index 13a25cad..277702e3 100644 --- a/libOTe/Tools/ExConvCode/ExConvCode.h +++ b/libOTe/Tools/ExConvCode/ExConvCode.h @@ -1,4 +1,4 @@ -// © 2023 Visa. +// © 2023 Visa. // Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. @@ -8,12 +8,39 @@ #include "cryptoTools/Common/Defines.h" #include "cryptoTools/Common/Timer.h" -#include "libOTe/Tools/EACode/Expander.h" +#include "libOTe/Tools/ExConvCode/Expander.h" #include "libOTe/Tools/EACode/Util.h" namespace osuCrypto { + template + struct has_operator_star : std::false_type + {}; + + template + struct has_operator_star < T, std::void_t< + // must have a operator*() member fn + decltype(std::declval().operator*()) + >> + : std::true_type{}; + + + template + struct is_iterator : std::false_type + {}; + + template + struct is_iterator < T, std::void_t< + // must have a operator*() member fn + // or be a pointer + std::enable_if_t< + has_operator_star::value || + std::is_pointer_v> + > + >> + : std::true_type{}; + // The encoder for the generator matrix G = B * A. dualEncode(...) is the main function // config(...) should be called first. // @@ -39,25 +66,11 @@ namespace osuCrypto // For additional parameter choices see the paper. void config( u64 messageSize, - u64 codeSize = 0 /*2 * messageSize is default */, + u64 codeSize, u64 expanderWeight = 7, - u64 accumulatorSize = 16, + u64 accumulatorWeight = 16, bool systematic = true, - block seed = block(99999, 88888)) - { - if (codeSize == 0) - codeSize = 2 * messageSize; - - if (accumulatorSize % 8) - throw std::runtime_error("ExConvCode accumulator size must be a multiple of 8." LOCATION); - - mSeed = seed; - mMessageSize = messageSize; - mCodeSize = codeSize; - mAccumulatorSize = accumulatorSize; - mSystematic = systematic; - mExpander.config(messageSize, codeSize - messageSize * systematic, expanderWeight, seed ^ CCBlock); - } + block seed = block(9996754675674599, 56756745976768754)); // the seed that generates the code. block mSeed = ZeroBlock; @@ -86,73 +99,448 @@ namespace osuCrypto // return code size n. u64 generatorCols() const { return mCodeSize; } - // Compute w = G * e. e will be modified in the computation. - template - void dualEncode(span e, span w); - // Compute e[0,...,k-1] = G * e. - template - void dualEncode(span e); - + // the computation will be done over F using ctx.plus + template< + typename F, + typename CoeffCtx, + typename Iter + > + void dualEncode(Iter&& e, CoeffCtx ctx); // Compute e[0,...,k-1] = G * e. - template - void dualEncode2(span e0, span e1); + template< + typename F, + typename G, + typename CoeffCtx, + typename IterF, + typename IterG + > + void dualEncode2(IterF&& e0, IterG&& e1, CoeffCtx ctx) + { + dualEncode(e0, ctx); + dualEncode(e1, ctx); + } + + // Private functions ------------------------------------ - // get the expander matrix - SparseMtx getB() const; + static void refill(PRNG& prng) + { + assert(prng.mBuffer.size() == 256); + //block b[8]; + for (u64 i = 0; i < 256; i += 8) + { + block* __restrict b = prng.mBuffer.data() + i; + block* __restrict k = prng.mBuffer.data() + (u8)(i - 8); - // Get the parity check version of the accumulator - SparseMtx getAPar() const; + b[0] = AES::roundEnc(b[0], k[0]); + b[1] = AES::roundEnc(b[1], k[1]); + b[2] = AES::roundEnc(b[2], k[2]); + b[3] = AES::roundEnc(b[3], k[3]); + b[4] = AES::roundEnc(b[4], k[4]); + b[5] = AES::roundEnc(b[5], k[5]); + b[6] = AES::roundEnc(b[6], k[6]); + b[7] = AES::roundEnc(b[7], k[7]); + } + } - // get the accumulator matrix - SparseMtx getA() const; + // take x[i] and add it to the next 8 positions if the flag b is 1. + // + // xx[j] += b[j] * x[i] + // + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + typename Iter + > + OC_FORCEINLINE void accOne8( + Iter&& xi, + Iter&& xj, + Iter&& end, + u8 b, + CoeffCtx& ctx); - // Private functions ------------------------------------ + // accumulating row i. + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + int AccumulatorSize, + typename Iter + > + OC_FORCEINLINE void accOne( + Iter&& xi, + Iter&& end, + u8* matrixCoeff, + CoeffCtx& ctx, + std::integral_constant); - // generate the point list for accumulating row i. - void accOne( - PointList& pl, - u64 i, - u8* __restrict& ptr, - PRNG& prng, - block& rnd, - u64& q, - u64 qe, - u64 size) const; - - // accumulating row i. - template - void accOne( - T* __restrict xx, - u64 i, - u8*& ptr, - PRNG& prng, - u64& q, - u64 qe, - u64 size); - - - // accumulating row i. - template - void accOne( - T0* __restrict xx0, - T1* __restrict xx1, - u64 i, - u8*& ptr, - PRNG& prng, - u64& q, - u64 qe, - u64 size); + // accumulating row i. generic version + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + typename Iter + > + OC_FORCEINLINE void accOne( + Iter&& xi, + Iter&& end, + u8* matrixCoeff, + CoeffCtx& ctx, + std::integral_constant); // accumulate x onto itself. - template - void accumulate(span x); - + template< + typename F, + typename CoeffCtx, + typename Iter + > + void accumulate( + Iter x, + CoeffCtx& ctx) + { + switch (mAccumulatorSize) + { + case 16: + accumulateFixed(std::forward(x), ctx); + break; + case 24: + accumulateFixed(std::forward(x), ctx); + break; + default: + // generic case + accumulateFixed(std::forward(x), ctx); + } + } // accumulate x onto itself. - template - void accumulate(span x0, span x1); + template< + typename F, + typename CoeffCtx, + u64 AccumulatorSize, + typename Iter + > + void accumulateFixed(Iter x, + CoeffCtx& ctx); + }; -} + + + inline void ExConvCode::config( + u64 messageSize, + u64 codeSize, + u64 expanderWeight, + u64 accumulatorSize, + bool systematic, + block seed) + { + if (codeSize == 0) + codeSize = 2 * messageSize; + + mSeed = seed; + mMessageSize = messageSize; + mCodeSize = codeSize; + mAccumulatorSize = accumulatorSize; + mSystematic = systematic; + mExpander.config(messageSize, codeSize - messageSize * systematic, expanderWeight, seed ^ CCBlock); + } + + // Compute e[0,...,k-1] = G * e. + template + void ExConvCode::dualEncode( + Iter&& e, + CoeffCtx ctx) + { + static_assert(is_iterator::value, "must pass in an iterator to the data"); + + (void)*(e + mCodeSize - 1); + + if (mSystematic) + { + auto d = e + mMessageSize; + setTimePoint("ExConv.encode.begin"); + accumulate(d, ctx); + setTimePoint("ExConv.encode.accumulate"); + mExpander.expand(d, e, ctx); + setTimePoint("ExConv.encode.expand"); + } + else + { + + setTimePoint("ExConv.encode.begin"); + accumulate(e, ctx); + setTimePoint("ExConv.encode.accumulate"); + + CoeffCtx::template Vec w; + ctx.resize(w, mMessageSize); + mExpander.expand(e, w.begin(), ctx); + setTimePoint("ExConv.encode.expand"); + + ctx.copy(w.begin(), w.end(), e); + setTimePoint("ExConv.encode.memcpy"); + + } + } + + // take x[i] and add it to the next 8 positions if the flag b is 1. + // + // xx[j] += b[j] * x[i] + // + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + typename Iter + > + OC_FORCEINLINE void ExConvCode::accOne8( + Iter&& xi, + Iter&& xj, + Iter&& end, + u8 b, + CoeffCtx& ctx) + { + +#ifdef ENABLE_SSE + if constexpr (std::is_same::value) + { + block rnd = block::allSame(b); + + block bshift[8]; + bshift[0] = rnd.slli_epi32<7>(); + bshift[1] = rnd.slli_epi32<6>(); + bshift[2] = rnd.slli_epi32<5>(); + bshift[3] = rnd.slli_epi32<4>(); + bshift[4] = rnd.slli_epi32<3>(); + bshift[5] = rnd.slli_epi32<2>(); + bshift[6] = rnd.slli_epi32<1>(); + bshift[7] = rnd; + + __m128 bb[8]; + auto xii = _mm_load_ps((float*)(&*xi)); + __m128 Zero = _mm_setzero_ps(); + + // bbj = bj + bb[0] = _mm_load_ps((float*)&bshift[0]); + bb[1] = _mm_load_ps((float*)&bshift[1]); + bb[2] = _mm_load_ps((float*)&bshift[2]); + bb[3] = _mm_load_ps((float*)&bshift[3]); + bb[4] = _mm_load_ps((float*)&bshift[4]); + bb[5] = _mm_load_ps((float*)&bshift[5]); + bb[6] = _mm_load_ps((float*)&bshift[6]); + bb[7] = _mm_load_ps((float*)&bshift[7]); + + // bbj = bj * xi + bb[0] = _mm_blendv_ps(Zero, xii, bb[0]); + bb[1] = _mm_blendv_ps(Zero, xii, bb[1]); + bb[2] = _mm_blendv_ps(Zero, xii, bb[2]); + bb[3] = _mm_blendv_ps(Zero, xii, bb[3]); + bb[4] = _mm_blendv_ps(Zero, xii, bb[4]); + bb[5] = _mm_blendv_ps(Zero, xii, bb[5]); + bb[6] = _mm_blendv_ps(Zero, xii, bb[6]); + bb[7] = _mm_blendv_ps(Zero, xii, bb[7]); + + block tt[8]; + memcpy(tt, bb, 8 * 16); + + assert((((b >> 0) & 1) ? *xi : ZeroBlock) == tt[0]); + assert((((b >> 1) & 1) ? *xi : ZeroBlock) == tt[1]); + assert((((b >> 2) & 1) ? *xi : ZeroBlock) == tt[2]); + assert((((b >> 3) & 1) ? *xi : ZeroBlock) == tt[3]); + assert((((b >> 4) & 1) ? *xi : ZeroBlock) == tt[4]); + assert((((b >> 5) & 1) ? *xi : ZeroBlock) == tt[5]); + assert((((b >> 6) & 1) ? *xi : ZeroBlock) == tt[6]); + assert((((b >> 7) & 1) ? *xi : ZeroBlock) == tt[7]); + + // xj += bj * xi + if (rangeCheck && xj + 0 == end) return; ctx.plus(*(xj + 0), *(xj + 0), tt[0]); + if (rangeCheck && xj + 1 == end) return; ctx.plus(*(xj + 1), *(xj + 1), tt[1]); + if (rangeCheck && xj + 2 == end) return; ctx.plus(*(xj + 2), *(xj + 2), tt[2]); + if (rangeCheck && xj + 3 == end) return; ctx.plus(*(xj + 3), *(xj + 3), tt[3]); + if (rangeCheck && xj + 4 == end) return; ctx.plus(*(xj + 4), *(xj + 4), tt[4]); + if (rangeCheck && xj + 5 == end) return; ctx.plus(*(xj + 5), *(xj + 5), tt[5]); + if (rangeCheck && xj + 6 == end) return; ctx.plus(*(xj + 6), *(xj + 6), tt[6]); + if (rangeCheck && xj + 7 == end) return; ctx.plus(*(xj + 7), *(xj + 7), tt[7]); + } + else +#endif + { + auto b0 = b & 1; + auto b1 = b & 2; + auto b2 = b & 4; + auto b3 = b & 8; + auto b4 = b & 16; + auto b5 = b & 32; + auto b6 = b & 64; + auto b7 = b & 128; + + if (rangeCheck && xj + 0 == end) return; if (b0) ctx.plus(*(xj + 0), *(xj + 0), *xi); + if (rangeCheck && xj + 1 == end) return; if (b1) ctx.plus(*(xj + 1), *(xj + 1), *xi); + if (rangeCheck && xj + 2 == end) return; if (b2) ctx.plus(*(xj + 2), *(xj + 2), *xi); + if (rangeCheck && xj + 3 == end) return; if (b3) ctx.plus(*(xj + 3), *(xj + 3), *xi); + if (rangeCheck && xj + 4 == end) return; if (b4) ctx.plus(*(xj + 4), *(xj + 4), *xi); + if (rangeCheck && xj + 5 == end) return; if (b5) ctx.plus(*(xj + 5), *(xj + 5), *xi); + if (rangeCheck && xj + 6 == end) return; if (b6) ctx.plus(*(xj + 6), *(xj + 6), *xi); + if (rangeCheck && xj + 7 == end) return; if (b7) ctx.plus(*(xj + 7), *(xj + 7), *xi); + } + } + + + + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + typename Iter + > + OC_FORCEINLINE void ExConvCode::accOne( + Iter&& xi, + Iter&& end, + u8* matrixCoeff, + CoeffCtx& ctx, + std::integral_constant _) + { + + // xj += xi + std::remove_reference_t xj = xi + 1; + if (!rangeCheck || xj < end) + { + ctx.plus(*xj, *xj, *xi); + ++xj; + } + + // xj += bj * xi + u64 k = 0; + for (; k < mAccumulatorSize - 7; k += 8) + { + accOne8(xi, xj, end, *matrixCoeff++, ctx); + + if constexpr (rangeCheck) + { + auto r = end - xj; + xj += std::min(r, 8); + } + else + { + xj += 8; + } + } + for (; k < mAccumulatorSize; ) + { + auto b = *matrixCoeff++; + + for (u64 j = 0; j < 8 && k < mAccumulatorSize; ++j, ++k) + { + + if (rangeCheck == false || (xj != end)) + { + if (b & 1) + ctx.plus(*xj, *xj, *xi); + + ++xj; + b >>= 1; + } + + } + } + } + + + // add xi to all of the future locations + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + int AccumulatorSize, + typename Iter + > + OC_FORCEINLINE void ExConvCode::accOne( + Iter&& xi, + Iter&& end, + u8* matrixCoeff, + CoeffCtx& ctx, + std::integral_constant) + { + static_assert(AccumulatorSize, "should have called the other overload"); + static_assert(AccumulatorSize % 8 == 0, "must be a multiple of 8"); + + // xj += xi + std::remove_reference_t xj = xi + 1; + if (!rangeCheck || xj < end) + { + ctx.plus(*xj, *xj, *xi); + ++xj; + } + + // xj += bj * xi + for (u64 k = 0; k < AccumulatorSize; k += 8) + { + accOne8(xi, xj, end, *matrixCoeff++, ctx); + + if constexpr (rangeCheck) + { + auto r = end - xj; + xj += std::min(r, 8); + } + else + { + xj += 8; + } + } + } + + // accumulate x onto itself. + template< + typename F, + typename CoeffCtx, + u64 AccumulatorSize, + typename Iter + > + void ExConvCode::accumulateFixed( + Iter xi, + CoeffCtx& ctx) + { + auto end = xi + (mCodeSize - mSystematic * mMessageSize); + auto main = end - 1 - mAccumulatorSize; + + PRNG prng(mSeed ^ OneBlock); + u8* mtxCoeffIter = (u8*)prng.mBuffer.data(); + auto mtxCoeffEnd = mtxCoeffIter + prng.mBuffer.size() * sizeof(block) - divCeil(mAccumulatorSize, 8); + + // AccumulatorSize == 0 is the generic case, otherwise + // AccumulatorSize should be equal to mAccumulatorSize. + static_assert(AccumulatorSize % 8 == 0); + if (AccumulatorSize && mAccumulatorSize != AccumulatorSize) + throw RTE_LOC; + + while (xi < main) + { + if (mtxCoeffIter > mtxCoeffEnd) + { + // generate more mtx coefficients + refill(prng); + mtxCoeffIter = (u8*)prng.mBuffer.data(); + } + + // add xi to the next positions + accOne(xi, end, mtxCoeffIter++, ctx, std::integral_constant{}); + ++xi; + } + + while (xi < end) + { + if (mtxCoeffIter > mtxCoeffEnd) + { + // generate more mtx coefficients + refill(prng); + mtxCoeffIter = (u8*)prng.mBuffer.data(); + } + + // add xi to the next positions + accOne(xi, end, mtxCoeffIter++, ctx, std::integral_constant{}); + ++xi; + } + } + +} \ No newline at end of file diff --git a/libOTe/Tools/ExConvCode/ExConvCode2.cpp b/libOTe/Tools/ExConvCode/ExConvCode2.cpp deleted file mode 100644 index aa4521c8..00000000 --- a/libOTe/Tools/ExConvCode/ExConvCode2.cpp +++ /dev/null @@ -1,117 +0,0 @@ -#include "ExConvCode2.h" -//#include "ExConvCode2Impl.h" -#include "libOTe/Tools/Subfield/Subfield.h" - -namespace osuCrypto -{ - - - //template void ExConvCode2::dualEncode(span e); - //template void ExConvCode2::dualEncode(span e, span w); - // - //template void ExConvCode2::dualEncode(span e); - //template void ExConvCode2::dualEncode(span e, span w); - - //template void ExConvCode2::dualEncode2(span, span e); - //template void ExConvCode2::accumulate(span, span e); - - //template void ExConvCode2::dualEncode2(span, span e); - //template void ExConvCode2::accumulate(span, span e); - - - // configure the code. The default parameters are choses to balance security and performance. - // For additional parameter choices see the paper. - void ExConvCode2::config( - u64 messageSize, - u64 codeSize, - u64 expanderWeight, - u64 accumulatorSize, - bool systematic, - block seed) - { - if (codeSize == 0) - codeSize = 2 * messageSize; - - mSeed = seed; - mMessageSize = messageSize; - mCodeSize = codeSize; - mAccumulatorSize = accumulatorSize; - mSystematic = systematic; - mExpander.config(messageSize, codeSize - messageSize * systematic, expanderWeight, seed ^ CCBlock); - } - - //// get the expander matrix - //SparseMtx ExConvCode2::getB() const - //{ - // throw RTE_LOC; - // //if (mSystematic) - // //{ - // // PointList R(mMessageSize, mCodeSize); - // // auto B = mExpander.getB().points(); - - // // for (auto p : B) - // // { - // // R.push_back(p.mRow, mMessageSize + p.mCol); - // // } - // // for (u64 i = 0; i < mMessageSize; ++i) - // // R.push_back(i, i); - - // // return R; - // //} - // //else - // //{ - // // return mExpander.getB(); - // //} - //} - - - // Get the parity check version of the accumulator - //SparseMtx ExConvCode2::getAPar() const - //{ - // throw RTE_LOC; - // //PRNG prng(mSeed ^ OneBlock); - - // //auto n = mCodeSize - mSystematic * mMessageSize; - - // //PointList AP(n, n);; - // //DenseMtx A = DenseMtx::Identity(n); - - // //block rnd; - // //u8* __restrict ptr = (u8*)prng.mBuffer.data(); - // //auto qe = prng.mBuffer.size() * 128; - // //u64 q = 0; - - // //for (u64 i = 0; i < n; ++i) - // //{ - // // accOne(AP, i, ptr, prng, rnd, q, qe, n); - // //} - // //return AP; - //} - - //// get the accumulator matrix - //SparseMtx ExConvCode2::getA() const - //{ - // auto APar = getAPar(); - - // auto A = DenseMtx::Identity(mCodeSize); - - // u64 offset = mSystematic ? mMessageSize : 0ull; - - // for (u64 i = 0; i < APar.rows(); ++i) - // { - // for (auto y : APar.col(i)) - // { - // if (y != i) - // { - // auto ay = A.row(y + offset); - // auto ai = A.row(i + offset); - // ay ^= ai; - // } - // } - // } - - // return A.sparse(); - //} - - -} \ No newline at end of file diff --git a/libOTe/Tools/ExConvCode/ExConvCode2.h b/libOTe/Tools/ExConvCode/ExConvCode2.h deleted file mode 100644 index fb7e1bea..00000000 --- a/libOTe/Tools/ExConvCode/ExConvCode2.h +++ /dev/null @@ -1,626 +0,0 @@ -// © 2023 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -#pragma once - -#include "cryptoTools/Common/Defines.h" -#include "cryptoTools/Common/Timer.h" -#include "libOTe/Tools/ExConvCode/Expander2.h" -#include "libOTe/Tools/EACode/Util.h" -#include "libOTe/Tools/Subfield/Subfield.h" - -namespace osuCrypto -{ - - template - struct has_operator_star : std::false_type - {}; - - template - struct has_operator_star < T, std::void_t< - // must have a operator*() member fn - decltype(std::declval().operator*()) - >> - : std::true_type{}; - - - template - struct is_iterator : std::false_type - {}; - - template - struct is_iterator < T, std::void_t< - // must have a operator*() member fn - // or be a pointer - std::enable_if_t< - has_operator_star::value || - std::is_pointer_v> - > - >> - : std::true_type{}; - - // The encoder for the generator matrix G = B * A. dualEncode(...) is the main function - // config(...) should be called first. - // - // B is the expander while A is the convolution. - // - // B has mMessageSize rows and mCodeSize columns. It is sampled uniformly - // with fixed row weight mExpanderWeight. - // - // A is a lower triangular n by n matrix with ones on the diagonal. The - // mAccumulatorSize diagonals left of the main diagonal are uniformly random. - // If mStickyAccumulator, then the first diagonal left of the main is always ones. - // - // See ExConvCode2Instantiations.cpp for how to instantiate new types that - // dualEncode can be called on. - // - // https://eprint.iacr.org/2023/882 - class ExConvCode2 : public TimerAdapter - { - public: - ExpanderCode2 mExpander; - - // configure the code. The default parameters are choses to balance security and performance. - // For additional parameter choices see the paper. - void config( - u64 messageSize, - u64 codeSize, - u64 expanderWeight = 7, - u64 accumulatorWeight = 16, - bool systematic = true, - block seed = block(9996754675674599, 56756745976768754)); - - // the seed that generates the code. - block mSeed = ZeroBlock; - - // The message size of the code. K. - u64 mMessageSize = 0; - - // The codeword size of the code. n. - u64 mCodeSize = 0; - - // The size of the accumulator. - u64 mAccumulatorSize = 0; - - // is the code systematic (true=faster) - bool mSystematic = true; - - // return n-k. code size n, message size k. - u64 parityRows() const { return mCodeSize - mMessageSize; } - - // return code size n. - u64 parityCols() const { return mCodeSize; } - - // return message size k. - u64 generatorRows() const { return mMessageSize; } - - // return code size n. - u64 generatorCols() const { return mCodeSize; } - - //// Compute w = G * e. e will be modified in the computation. - //// the computation will be done over F using CoeffCtx::plus - //template< - // typename F, - // typename CoeffCtx, - // typename SrcIter, - // typename DstIter - //> - //void dualEncode(SrcIter&& e, DstIter&& w); - - // Compute e[0,...,k-1] = G * e. - // the computation will be done over F using CoeffCtx::plus - template< - typename F, - typename CoeffCtx, - typename Iter - > - void dualEncode(Iter&& e); - - // Compute e[0,...,k-1] = G * e. - template< - typename F, - typename G, - typename CoeffCtx, - typename IterF, - typename IterG - > - void dualEncode2(IterF&& e0, IterG&& e1) - { - dualEncode(e0); - dualEncode(e1); - } - - // Private functions ------------------------------------ - - static void refill(PRNG& prng) - { - assert(prng.mBuffer.size() == 256); - //block b[8]; - for (u64 i = 0; i < 256; i += 8) - { - block* __restrict b = prng.mBuffer.data() + i; - block* __restrict k = prng.mBuffer.data() + (u8)(i - 8); - - b[0] = AES::roundEnc(b[0], k[0]); - b[1] = AES::roundEnc(b[1], k[1]); - b[2] = AES::roundEnc(b[2], k[2]); - b[3] = AES::roundEnc(b[3], k[3]); - b[4] = AES::roundEnc(b[4], k[4]); - b[5] = AES::roundEnc(b[5], k[5]); - b[6] = AES::roundEnc(b[6], k[6]); - b[7] = AES::roundEnc(b[7], k[7]); - } - } - - // take x[i] and add it to the next 8 positions if the flag b is 1. - // - // xx[j] += b[j] * x[i] - // - template< - typename F, - typename CoeffCtx, - bool rangeCheck, - typename Iter - > - OC_FORCEINLINE void accOne8( - Iter&& xi, - Iter&& xj, - Iter&& end, - u8 b); - - // accumulating row i. - template< - typename F, - typename CoeffCtx, - bool rangeCheck, - int AccumulatorSize, - typename Iter - > - OC_FORCEINLINE void accOne( - Iter&& xi, - Iter&& end, - u8* matrixCoeff, - std::integral_constant); - - // accumulating row i. generic version - template< - typename F, - typename CoeffCtx, - bool rangeCheck, - typename Iter - > - OC_FORCEINLINE void accOne( - Iter&& xi, - Iter&& end, - u8* matrixCoeff, - std::integral_constant); - - - // accumulate x onto itself. - template< - typename F, - typename CoeffCtx, - typename Iter - > - void accumulate(Iter x) - { - switch (mAccumulatorSize) - { - case 16: - accumulateFixed(std::forward(x)); - break; - case 24: - accumulateFixed(std::forward(x)); - break; - default: - // generic case - accumulateFixed(std::forward(x)); - } - } - - // accumulate x onto itself. - template< - typename F, - typename CoeffCtx, - u64 AccumulatorSize, - typename Iter - > - void accumulateFixed(Iter x); - - }; -} - - -//#include "ExConvCode2Impl.h" - -namespace osuCrypto -{ - - // Compute w = G * e. e will be modified in the computation. - //template< - // typename F, - // typename CoeffCtx, - // typename SrcIter, - // typename DstIter - //> - //void ExConvCode2::dualEncode(SrcIter&& e, DstIter&& w) - //{ - - // static_assert(is_iterator::value, "must pass in an iterator to the data, " __FUNCTION__); - // static_assert(is_iterator::value, "must pass in an iterator to the data"); - - // // try to deref the back. might bounds check. - // (void)*(e + mCodeSize - 1); - // (void)*(w + mMessageSize - 1); - - // if (mSystematic) - // { - // dualEncode(e); - // CoeffCtx::copy(w, w + mMessageSize, e); - // setTimePoint("ExConv.encode.memcpy"); - // } - // else - // { - - // setTimePoint("ExConv.encode.begin"); - - // accumulate(e); - - // setTimePoint("ExConv.encode.accumulate"); - - // mExpander.expand(e, w); - // setTimePoint("ExConv.encode.expand"); - // } - //} - - // Compute e[0,...,k-1] = G * e. - template - void ExConvCode2::dualEncode(Iter&& e) - { - static_assert(is_iterator::value, "must pass in an iterator to the data"); - - (void)*(e + mCodeSize - 1); - - if (mSystematic) - { - auto d = e + mMessageSize; - setTimePoint("ExConv.encode.begin"); - accumulate(d); - setTimePoint("ExConv.encode.accumulate"); - mExpander.expand(d, e); - setTimePoint("ExConv.encode.expand"); - } - else - { - - setTimePoint("ExConv.encode.begin"); - accumulate(e); - setTimePoint("ExConv.encode.accumulate"); - - CoeffCtx::template Vec w; - CoeffCtx::resize(w, mMessageSize); - mExpander.expand(e, w.begin()); - setTimePoint("ExConv.encode.expand"); - - CoeffCtx::copy(w.begin(), w.end(), e); - setTimePoint("ExConv.encode.memcpy"); - - } - } - - - //// Compute e[0,...,k-1] = G * e. - //template - //void ExConvCode2::dualEncode2(span e0, span e1) - //{ - // if (e0.size() != mCodeSize) - // throw RTE_LOC; - // if (e1.size() != mCodeSize) - // throw RTE_LOC; - - // if (mSystematic) - // { - // auto d0 = e0.subspan(mMessageSize); - // auto d1 = e1.subspan(mMessageSize); - // setTimePoint("ExConv.encode.begin"); - // accumulate(d0, d1); - // setTimePoint("ExConv.encode.accumulate"); - // mExpander.expand( - // d0, d1, - // e0.subspan(0, mMessageSize), - // e1.subspan(0, mMessageSize)); - // setTimePoint("ExConv.encode.expand"); - // } - // else - // { - // //oc::AlignedUnVector w0(mMessageSize); - // //dualEncode(e, w); - // //memcpy(e.data(), w.data(), w.size() * sizeof(T)); - // //setTimePoint("ExConv.encode.memcpy"); - - // // not impl. - // throw RTE_LOC; - // } - //} - - -#ifdef ENABLE_SSE - using My__m128 = __m128; -#else - using My__m128 = block; - - inline My__m128 _mm_load_ps(float* b) { return *(block*)b; } - - // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_blendv_ps&ig_expand=557 - inline My__m128 _mm_blendv_ps(My__m128 a, My__m128 b, My__m128 mask) - { - My__m128 dst; - for (u64 j = 0; j < 4; ++j) - { - if (mask.get(j) < 0) - dst.set(j, b.get(j)); - else - dst.set(j, a.get(j)); - } - return dst; - } - - - inline My__m128 _mm_setzero_ps() { return ZeroBlock; } -#endif - - // take x[i] and add it to the next 8 positions if the flag b is 1. - // - // xx[j] += b[j] * x[i] - // - template< - typename F, - typename CoeffCtx, - bool rangeCheck, - typename Iter - > - OC_FORCEINLINE void ExConvCode2::accOne8( - Iter&& xi, - Iter&& xj, - Iter&& end, - u8 b) - { - if constexpr (std::is_same::value) - { - block rnd = block::allSame(b); - - block bshift[8]; - bshift[0] = rnd.slli_epi32<7>(); - bshift[1] = rnd.slli_epi32<6>(); - bshift[2] = rnd.slli_epi32<5>(); - bshift[3] = rnd.slli_epi32<4>(); - bshift[4] = rnd.slli_epi32<3>(); - bshift[5] = rnd.slli_epi32<2>(); - bshift[6] = rnd.slli_epi32<1>(); - bshift[7] = rnd; - - My__m128 bb[8]; - auto xii = _mm_load_ps((float*)(&*xi)); - My__m128 Zero = _mm_setzero_ps(); - - // bbj = bj - bb[0] = _mm_load_ps((float*)&bshift[0]); - bb[1] = _mm_load_ps((float*)&bshift[1]); - bb[2] = _mm_load_ps((float*)&bshift[2]); - bb[3] = _mm_load_ps((float*)&bshift[3]); - bb[4] = _mm_load_ps((float*)&bshift[4]); - bb[5] = _mm_load_ps((float*)&bshift[5]); - bb[6] = _mm_load_ps((float*)&bshift[6]); - bb[7] = _mm_load_ps((float*)&bshift[7]); - - // bbj = bj * xi - bb[0] = _mm_blendv_ps(Zero, xii, bb[0]); - bb[1] = _mm_blendv_ps(Zero, xii, bb[1]); - bb[2] = _mm_blendv_ps(Zero, xii, bb[2]); - bb[3] = _mm_blendv_ps(Zero, xii, bb[3]); - bb[4] = _mm_blendv_ps(Zero, xii, bb[4]); - bb[5] = _mm_blendv_ps(Zero, xii, bb[5]); - bb[6] = _mm_blendv_ps(Zero, xii, bb[6]); - bb[7] = _mm_blendv_ps(Zero, xii, bb[7]); - - block tt[8]; - memcpy(tt, bb, 8 * 16); - - assert((((b >> 0) & 1) ? *xi : ZeroBlock) == tt[0]); - assert((((b >> 1) & 1) ? *xi : ZeroBlock) == tt[1]); - assert((((b >> 2) & 1) ? *xi : ZeroBlock) == tt[2]); - assert((((b >> 3) & 1) ? *xi : ZeroBlock) == tt[3]); - assert((((b >> 4) & 1) ? *xi : ZeroBlock) == tt[4]); - assert((((b >> 5) & 1) ? *xi : ZeroBlock) == tt[5]); - assert((((b >> 6) & 1) ? *xi : ZeroBlock) == tt[6]); - assert((((b >> 7) & 1) ? *xi : ZeroBlock) == tt[7]); - - // xj += bj * xi - if (rangeCheck && xj + 0 == end) return; CoeffCtx::plus(*(xj + 0), *(xj + 0), tt[0]); - if (rangeCheck && xj + 1 == end) return; CoeffCtx::plus(*(xj + 1), *(xj + 1), tt[1]); - if (rangeCheck && xj + 2 == end) return; CoeffCtx::plus(*(xj + 2), *(xj + 2), tt[2]); - if (rangeCheck && xj + 3 == end) return; CoeffCtx::plus(*(xj + 3), *(xj + 3), tt[3]); - if (rangeCheck && xj + 4 == end) return; CoeffCtx::plus(*(xj + 4), *(xj + 4), tt[4]); - if (rangeCheck && xj + 5 == end) return; CoeffCtx::plus(*(xj + 5), *(xj + 5), tt[5]); - if (rangeCheck && xj + 6 == end) return; CoeffCtx::plus(*(xj + 6), *(xj + 6), tt[6]); - if (rangeCheck && xj + 7 == end) return; CoeffCtx::plus(*(xj + 7), *(xj + 7), tt[7]); - } - else - { - auto b0 = b & 1; - auto b1 = b & 2; - auto b2 = b & 4; - auto b3 = b & 8; - auto b4 = b & 16; - auto b5 = b & 32; - auto b6 = b & 64; - auto b7 = b & 128; - - if (rangeCheck && xj + 0 == end) return; if (b0) CoeffCtx::plus(*(xj + 0), *(xj + 0), *xi); - if (rangeCheck && xj + 1 == end) return; if (b1) CoeffCtx::plus(*(xj + 1), *(xj + 1), *xi); - if (rangeCheck && xj + 2 == end) return; if (b2) CoeffCtx::plus(*(xj + 2), *(xj + 2), *xi); - if (rangeCheck && xj + 3 == end) return; if (b3) CoeffCtx::plus(*(xj + 3), *(xj + 3), *xi); - if (rangeCheck && xj + 4 == end) return; if (b4) CoeffCtx::plus(*(xj + 4), *(xj + 4), *xi); - if (rangeCheck && xj + 5 == end) return; if (b5) CoeffCtx::plus(*(xj + 5), *(xj + 5), *xi); - if (rangeCheck && xj + 6 == end) return; if (b6) CoeffCtx::plus(*(xj + 6), *(xj + 6), *xi); - if (rangeCheck && xj + 7 == end) return; if (b7) CoeffCtx::plus(*(xj + 7), *(xj + 7), *xi); - } - } - - - - template< - typename F, - typename CoeffCtx, - bool rangeCheck, - typename Iter - > - OC_FORCEINLINE void ExConvCode2::accOne( - Iter&& xi, - Iter&& end, - u8* matrixCoeff, - std::integral_constant _) - { - - // xj += xi - std::remove_reference_t xj = xi + 1; - if (!rangeCheck || xj < end) - { - CoeffCtx::plus(*xj, *xj, *xi); - ++xj; - } - - // xj += bj * xi - u64 k = 0; - for (; k < mAccumulatorSize - 7; k += 8) - { - accOne8(xi, xj, end, *matrixCoeff++); - - if constexpr (rangeCheck) - { - auto r = end - xj; - xj += std::min(r, 8); - } - else - { - xj += 8; - } - } - for (; k < mAccumulatorSize; ) - { - auto b = *matrixCoeff++; - - for (u64 j = 0; j < 8 && k < mAccumulatorSize; ++j, ++k) - { - - if (rangeCheck == false || (xj != end)) - { - if (b & 1) - CoeffCtx::plus(*xj, *xj, *xi); - - ++xj; - b >>= 1; - } - - } - } - } - - - // add xi to all of the future locations - template< - typename F, - typename CoeffCtx, - bool rangeCheck, - int AccumulatorSize, - typename Iter - > - OC_FORCEINLINE void ExConvCode2::accOne( - Iter&& xi, - Iter&& end, - u8* matrixCoeff, - std::integral_constant) - { - static_assert(AccumulatorSize, "should have called the other overload"); - static_assert(AccumulatorSize % 8 == 0, "must be a multiple of 8"); - - // xj += xi - std::remove_reference_t xj = xi + 1; - if (!rangeCheck || xj < end) - { - CoeffCtx::plus(*xj, *xj, *xi); - ++xj; - } - - // xj += bj * xi - for (u64 k = 0; k < AccumulatorSize; k += 8) - { - accOne8(xi, xj, end, *matrixCoeff++); - - if constexpr (rangeCheck) - { - auto r = end - xj; - xj += std::min(r, 8); - } - else - { - xj += 8; - } - } - } - - // accumulate x onto itself. - template< - typename F, - typename CoeffCtx, - u64 AccumulatorSize, - typename Iter - > - void ExConvCode2::accumulateFixed(Iter xi) - { - auto end = xi + (mCodeSize - mSystematic * mMessageSize); - auto main = end - 1 - mAccumulatorSize; - - PRNG prng(mSeed ^ OneBlock); - u8* mtxCoeffIter = (u8*)prng.mBuffer.data(); - auto mtxCoeffEnd = mtxCoeffIter + prng.mBuffer.size() * sizeof(block) - divCeil(mAccumulatorSize, 8); - - // AccumulatorSize == 0 is the generic case, otherwise - // AccumulatorSize should be equal to mAccumulatorSize. - static_assert(AccumulatorSize % 8 == 0); - if (AccumulatorSize && mAccumulatorSize != AccumulatorSize) - throw RTE_LOC; - - while (xi < main) - { - if (mtxCoeffIter > mtxCoeffEnd) - { - // generate more mtx coefficients - refill(prng); - mtxCoeffIter = (u8*)prng.mBuffer.data(); - } - - // add xi to the next positions - accOne(xi, end, mtxCoeffIter++, std::integral_constant{}); - ++xi; - } - - while (xi < end) - { - if (mtxCoeffIter > mtxCoeffEnd) - { - // generate more mtx coefficients - refill(prng); - mtxCoeffIter = (u8*)prng.mBuffer.data(); - } - - // add xi to the next positions - accOne(xi, end, mtxCoeffIter++, std::integral_constant{}); - ++xi; - } - } - -} \ No newline at end of file diff --git a/libOTe/Tools/ExConvCode/ExConvCode2Impl.h b/libOTe/Tools/ExConvCode/ExConvCode2Impl.h deleted file mode 100644 index da722947..00000000 --- a/libOTe/Tools/ExConvCode/ExConvCode2Impl.h +++ /dev/null @@ -1,3 +0,0 @@ -#pragma once -#include "ExConvCode2.h" - diff --git a/libOTe/Tools/ExConvCode/ExConvCodeInstantiations.cpp b/libOTe/Tools/ExConvCode/ExConvCodeInstantiations.cpp deleted file mode 100644 index 1b91615e..00000000 --- a/libOTe/Tools/ExConvCode/ExConvCodeInstantiations.cpp +++ /dev/null @@ -1,17 +0,0 @@ - -#define EXCONVCODE_INSTANTIATIONS -#include "ExConvCode.cpp" - -namespace osuCrypto -{ - - template void ExConvCode::dualEncode(span e); - template void ExConvCode::dualEncode(span e); - template void ExConvCode::dualEncode(span e, span w); - template void ExConvCode::dualEncode(span e, span w); - template void ExConvCode::dualEncode2(span, span e); - template void ExConvCode::dualEncode2(span, span e); - - template void ExConvCode::accumulate(span, span e); - template void ExConvCode::accumulate(span, span e); -} diff --git a/libOTe/Tools/ExConvCode/Expander2.h b/libOTe/Tools/ExConvCode/Expander.h similarity index 97% rename from libOTe/Tools/ExConvCode/Expander2.h rename to libOTe/Tools/ExConvCode/Expander.h index b2a35ba2..c952b044 100644 --- a/libOTe/Tools/ExConvCode/Expander2.h +++ b/libOTe/Tools/ExConvCode/Expander.h @@ -17,7 +17,7 @@ namespace osuCrypto // The encoder for the expander matrix B. // B has mMessageSize rows and mCodeSize columns. It is sampled uniformly // with fixed row weight mExpanderWeight. - class ExpanderCode2 + class ExpanderCode { public: @@ -66,15 +66,6 @@ namespace osuCrypto ) const; - template< - typename F, - typename CoeffCtx - > - typename CoeffCtx::template Vec getB(CoeffCtx ctx = {}) const; - - - - //template< // bool Add, // typename CoeffCtx, @@ -96,7 +87,7 @@ namespace osuCrypto typename SrcIter, typename DstIter > - void ExpanderCode2::expand( + void ExpanderCode::expand( SrcIter&& input, DstIter&& output, CoeffCtx ctx) const @@ -162,7 +153,7 @@ namespace osuCrypto // typename... F, // typename... SrcDstIterPair //> - //void ExpanderCode2::expandMany( + //void ExpanderCode::expandMany( // std::tuple inOuts, // CoeffCtx ctx)const //{ diff --git a/libOTe/Tools/Pprf/PprfUtil.h b/libOTe/Tools/Pprf/PprfUtil.h new file mode 100644 index 00000000..03b70f21 --- /dev/null +++ b/libOTe/Tools/Pprf/PprfUtil.h @@ -0,0 +1,266 @@ +#pragma once +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/Aligned.h" +#include +#include + +namespace osuCrypto +{ + + + // the various formats that the output of the + // Pprf can be generated. + enum class PprfOutputFormat + { + // The i'th row holds the i'th leaf for all trees. + // The j'th tree is in the j'th column. + ByLeafIndex, + + // The i'th row holds the i'th tree. + // The j'th leaf is in the j'th column. + ByTreeIndex, + + // The native output mode. The output will be + // a single row with all leaf values. + // Every 8 trees are mixed together where the + // i'th leaf for each of the 8 tree will be next + // to each other. For example, let tij be the j'th + // leaf of the i'th tree. If we have m leaves, then + // + // t00 t10 ... t70 t01 t11 ... t71 ... t0m t1m ... t7m + // t80 t90 ... t_{15,0} t81 t91 ... t_{15,1} ... t8m t9m ... t_{15,m} + // ... + // + // These are all flattened into a single row. + Interleaved, + + // call the user's callback. The leaves will be in + // Interleaved format. + Callback + }; + + + + namespace pprf + { + + template + void copyOut( + VecF& leaf, + VecF& output, + u64 totalTrees, + u64 treeIndex, + PprfOutputFormat oFormat, + std::function& callback) + { + auto curSize = std::min(totalTrees - treeIndex, 8); + auto domain = leaf.size() / 8; + if (oFormat == PprfOutputFormat::ByLeafIndex) + { + if (curSize == 8) + { + for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) + { + auto oIdx = totalTrees * leafIndex + treeIndex; + auto iIdx = leafIndex * 8; + output[oIdx + 0] = leaf[iIdx + 0]; + output[oIdx + 1] = leaf[iIdx + 1]; + output[oIdx + 2] = leaf[iIdx + 2]; + output[oIdx + 3] = leaf[iIdx + 3]; + output[oIdx + 4] = leaf[iIdx + 4]; + output[oIdx + 5] = leaf[iIdx + 5]; + output[oIdx + 6] = leaf[iIdx + 6]; + output[oIdx + 7] = leaf[iIdx + 7]; + } + } + else + { + for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) + { + //auto oi = output[leafIndex].subspan(treeIndex, curSize); + //auto& ii = leaf[leafIndex]; + auto oIdx = totalTrees * leafIndex + treeIndex; + auto iIdx = leafIndex * 8; + for (u64 j = 0; j < curSize; ++j) + output[oIdx + j] = leaf[iIdx + j]; + } + } + + } + else if (oFormat == PprfOutputFormat::ByTreeIndex) + { + + if (curSize == 8) + { + for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) + { + auto iIdx = leafIndex * 8; + + output[(treeIndex + 0) * domain + leafIndex] = leaf[iIdx + 0]; + output[(treeIndex + 1) * domain + leafIndex] = leaf[iIdx + 1]; + output[(treeIndex + 2) * domain + leafIndex] = leaf[iIdx + 2]; + output[(treeIndex + 3) * domain + leafIndex] = leaf[iIdx + 3]; + output[(treeIndex + 4) * domain + leafIndex] = leaf[iIdx + 4]; + output[(treeIndex + 5) * domain + leafIndex] = leaf[iIdx + 5]; + output[(treeIndex + 6) * domain + leafIndex] = leaf[iIdx + 6]; + output[(treeIndex + 7) * domain + leafIndex] = leaf[iIdx + 7]; + } + } + else + { + for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) + { + auto iIdx = leafIndex * 8; + for (u64 j = 0; j < curSize; ++j) + output[(treeIndex + j) * domain + leafIndex] = leaf[iIdx + j]; + } + } + + } + else if (oFormat == PprfOutputFormat::Callback) + callback(treeIndex, leaf); + else + throw RTE_LOC; + } + + template + void allocateExpandBuffer( + u64 depth, + u64 numTrees, + bool programPuncturedPoint, + std::vector& buff, + span>& sums, + span& leaf, + CoeffCtx& ctx) + { + + u64 elementSize = ctx.byteSize(); + + // num of bytes they will take up. + u64 numBytes = + depth * numTrees * sizeof(std::array) + // each internal level of the tree has two sums + elementSize * numTrees * 2 + // we must program numTrees inactive F leaves + elementSize * numTrees * 2 * programPuncturedPoint; // if we are programing the active lead, then we have numTrees more. + + // allocate the buffer and partition them. + buff.resize(numBytes); + sums = span>((std::array*)buff.data(), depth * numTrees); + leaf = span((u8*)(sums.data() + sums.size()), + elementSize * numTrees * 2 + + elementSize * numTrees * 2 * programPuncturedPoint + ); + + void* sEnd = sums.data() + sums.size(); + void* lEnd = leaf.data() + leaf.size(); + void* end = buff.data() + buff.size(); + if (sEnd > end || lEnd != end) + throw RTE_LOC; + } + + template + void validateExpandFormat( + PprfOutputFormat oFormat, + VecF& output, + u64 domain, + u64 pntCount) + { + if (oFormat == PprfOutputFormat::Interleaved && pntCount % 8) + throw std::runtime_error("For Interleaved output format, pointCount must be a multiple of 8 (general case not impl). " LOCATION); + + + switch (oFormat) + { + case osuCrypto::PprfOutputFormat::ByLeafIndex: + case osuCrypto::PprfOutputFormat::ByTreeIndex: + case osuCrypto::PprfOutputFormat::Interleaved: + if (output.size() != domain * pntCount) + throw RTE_LOC; + break; + case osuCrypto::PprfOutputFormat::Callback: + if (output.size()) + throw RTE_LOC; + break; + default: + throw RTE_LOC; + break; + } + + } + + + struct TreeAllocator + { + TreeAllocator() = default; + TreeAllocator(const TreeAllocator&) = delete; + TreeAllocator(TreeAllocator&&) = default; + + using ValueType = AlignedArray; + std::list> mTrees; + std::vector> mFreeTrees; + //std::mutex mMutex; + u64 mTreeSize = 0, mNumTrees = 0; + + void reserve(u64 num, u64 size) + { + //std::lock_guard lock(mMutex); + mTreeSize = size; + mNumTrees += num; + mTrees.clear(); + mFreeTrees.clear(); + mTrees.emplace_back(num * size); + auto iter = mTrees.back().data(); + for (u64 i = 0; i < num; ++i) + { + mFreeTrees.push_back(span(iter, size)); + assert((u64)mFreeTrees.back().data() % 32 == 0); + iter += size; + } + } + + span get() + { + //std::lock_guard lock(mMutex); + if (mFreeTrees.size() == 0) + { + assert(mTreeSize); + mTrees.emplace_back(mTreeSize); + mFreeTrees.push_back(span(mTrees.back().data(), mTreeSize)); + assert((u64)mFreeTrees.back().data() % 32 == 0); + ++mNumTrees; + } + + auto ret = mFreeTrees.back(); + mFreeTrees.pop_back(); + return ret; + } + + void clear() + { + mTrees = {}; + mFreeTrees = {}; + mTreeSize = 0; + mNumTrees = 0; + } + }; + + + inline void allocateExpandTree( + TreeAllocator& alloc, + std::vector>>& levels) + { + span> tree = alloc.get(); + assert((u64)tree.data() % 32 == 0); + levels[0] = tree.subspan(0, 1); + auto rem = tree.subspan(2); + for (auto i = 1ull; i < levels.size(); ++i) + { + levels[i] = rem.subspan(0, levels[i - 1].size() * 2); + assert((u64)levels[i].data() % 32 == 0); + rem = rem.subspan(levels[i].size()); + } + } + + + } + +} \ No newline at end of file diff --git a/libOTe/Tools/Subfield/SubfieldPprf.cpp b/libOTe/Tools/Pprf/RegularPprf.cpp similarity index 100% rename from libOTe/Tools/Subfield/SubfieldPprf.cpp rename to libOTe/Tools/Pprf/RegularPprf.cpp diff --git a/libOTe/Tools/Subfield/SubfieldPprf.h b/libOTe/Tools/Pprf/RegularPprf.h similarity index 60% rename from libOTe/Tools/Subfield/SubfieldPprf.h rename to libOTe/Tools/Pprf/RegularPprf.h index 91a2b0ae..fc37f009 100644 --- a/libOTe/Tools/Subfield/SubfieldPprf.h +++ b/libOTe/Tools/Pprf/RegularPprf.h @@ -7,182 +7,25 @@ #include "cryptoTools/Common/Range.h" #include "cryptoTools/Crypto/PRNG.h" #include "libOTe/Tools/Coproto.h" -#include "libOTe/Tools/SilentPprf.h" -#include "SubfieldPprf.h" #include -#include "libOTe/Tools/Subfield/Subfield.h" +#include "libOTe/Tools/CoeffCtx.h" +#include "PprfUtil.h" namespace osuCrypto { extern const std::array gGgmAes; - inline void allocateExpandTree( - TreeAllocator& alloc, - span>& tree, - std::vector>>& levels) - { - tree = alloc.get(); - assert((u64)tree.data() % 32 == 0); - levels[0] = tree.subspan(0, 1); - auto rem = tree.subspan(2); - for (auto i : rng(1ull, levels.size())) - { - levels[i] = rem.subspan(0, levels[i - 1].size() * 2); - assert((u64)levels[i].data() % 32 == 0); - rem = rem.subspan(levels[i].size()); - } - } - - template - void copyOut( - VecF& leaf, - VecF& output, - u64 totalTrees, - u64 treeIndex, - PprfOutputFormat oFormat, - std::function& callback) - { - auto curSize = std::min(totalTrees - treeIndex, 8); - auto domain = leaf.size() / 8; - if (oFormat == PprfOutputFormat::ByLeafIndex) - { - if (curSize == 8) - { - for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) - { - auto oIdx = totalTrees * leafIndex + treeIndex; - auto iIdx = leafIndex * 8; - output[oIdx + 0] = leaf[iIdx + 0]; - output[oIdx + 1] = leaf[iIdx + 1]; - output[oIdx + 2] = leaf[iIdx + 2]; - output[oIdx + 3] = leaf[iIdx + 3]; - output[oIdx + 4] = leaf[iIdx + 4]; - output[oIdx + 5] = leaf[iIdx + 5]; - output[oIdx + 6] = leaf[iIdx + 6]; - output[oIdx + 7] = leaf[iIdx + 7]; - } - } - else - { - for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) - { - //auto oi = output[leafIndex].subspan(treeIndex, curSize); - //auto& ii = leaf[leafIndex]; - auto oIdx = totalTrees * leafIndex + treeIndex; - auto iIdx = leafIndex * 8; - for (u64 j = 0; j < curSize; ++j) - output[oIdx + j] = leaf[iIdx + j]; - } - } - - } - else if (oFormat == PprfOutputFormat::ByTreeIndex) - { - - if (curSize == 8) - { - for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) - { - auto iIdx = leafIndex * 8; - - output[(treeIndex + 0) * domain + leafIndex] = leaf[iIdx + 0]; - output[(treeIndex + 1) * domain + leafIndex] = leaf[iIdx + 1]; - output[(treeIndex + 2) * domain + leafIndex] = leaf[iIdx + 2]; - output[(treeIndex + 3) * domain + leafIndex] = leaf[iIdx + 3]; - output[(treeIndex + 4) * domain + leafIndex] = leaf[iIdx + 4]; - output[(treeIndex + 5) * domain + leafIndex] = leaf[iIdx + 5]; - output[(treeIndex + 6) * domain + leafIndex] = leaf[iIdx + 6]; - output[(treeIndex + 7) * domain + leafIndex] = leaf[iIdx + 7]; - } - } - else - { - for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) - { - auto iIdx = leafIndex * 8; - for (u64 j = 0; j < curSize; ++j) - output[(treeIndex + j) * domain + leafIndex] = leaf[iIdx + j]; - } - } - - } - else if (oFormat == PprfOutputFormat::Callback) - callback(treeIndex, leaf); - else - throw RTE_LOC; - } - - template - void allocateExpandBuffer( - u64 depth, - bool programPuncturedPoint, - std::vector& buff, - span, 2>>& sums, - span& leaf, - CoeffCtx& ctx) - { - - u64 elementSize = ctx.byteSize(); - - using SumType = std::array, 2>; - // num of bytes they will take up. - u64 numBytes = - depth * sizeof(SumType) + // each internal level of the tree has a sum - elementSize * 8 * 2 + // we must program 8 inactive F leaves - elementSize * 8 * 2 * programPuncturedPoint; // if we are programing the active lead, then we have 8 more. - - // allocate the buffer and partition them. - buff.resize(numBytes); - sums = span((SumType*)buff.data(), depth); - leaf = span((u8*)(sums.data() + sums.size()), - elementSize * 8 * 2 + - elementSize * 8 * 2 * programPuncturedPoint - ); - - void* sEnd = sums.data() + sums.size(); - void* lEnd = leaf.data() + leaf.size(); - void* end = buff.data() + buff.size(); - if (sEnd > end || lEnd != end) - throw RTE_LOC; - } - - template - void validateExpandFormat( - PprfOutputFormat oFormat, - VecF& output, - u64 domain, - u64 pntCount) - { - switch (oFormat) - { - case osuCrypto::PprfOutputFormat::ByLeafIndex: - case osuCrypto::PprfOutputFormat::ByTreeIndex: - case osuCrypto::PprfOutputFormat::Interleaved: - if (output.size() != domain * pntCount) - throw RTE_LOC; - break; - case osuCrypto::PprfOutputFormat::Callback: - if (output.size()) - throw RTE_LOC; - break; - default: - throw RTE_LOC; - break; - } - - } template< typename F, typename G = F, typename CoeffCtx = DefaultCoeffCtx > - class SilentSubfieldPprfSender : public TimerAdapter { + class RegularPprfSender : public TimerAdapter { public: u64 mDomain = 0, mDepth = 0, mPntCount = 0; std::vector mValue; - TreeAllocator mTreeAlloc; Matrix> mBaseOTs; using VecF = typename CoeffCtx::template Vec; @@ -191,13 +34,13 @@ namespace osuCrypto std::function mOutputFn; - SilentSubfieldPprfSender() = default; + RegularPprfSender() = default; - SilentSubfieldPprfSender(const SilentSubfieldPprfSender&) = delete; + RegularPprfSender(const RegularPprfSender&) = delete; - SilentSubfieldPprfSender(SilentSubfieldPprfSender&&) = delete; + RegularPprfSender(RegularPprfSender&&) = delete; - SilentSubfieldPprfSender(u64 domainSize, u64 pointCount) { + RegularPprfSender(u64 domainSize, u64 pointCount) { configure(domainSize, pointCount); } @@ -205,15 +48,12 @@ namespace osuCrypto { if (domainSize & 1) throw std::runtime_error("Pprf domain must be even. " LOCATION); - if (domainSize < 4) - throw std::runtime_error("Pprf domain must must be at least 4. " LOCATION); - if (mPntCount % 8) - throw std::runtime_error("pointCount must be a multiple of 8 (general case not impl). " LOCATION); + if (domainSize < 2) + throw std::runtime_error("Pprf domain must must be at least 2. " LOCATION); mDomain = domainSize; mDepth = log2ceil(mDomain); mPntCount = pointCount; - //mPntCount8 = roundUpTo(pointCount, 8); mBaseOTs.resize(0, 0); } @@ -239,11 +79,6 @@ namespace osuCrypto mBaseOTs(i) = baseMessages[i]; } - //task<> expand(Socket& chls, span value, block seed, span output, PprfOutputFormat oFormat, - // bool programPuncturedPoint, u64 numThreads) { - // MatrixView o(output.data(), output.size(), 1); - // return expand(chls, value, seed, o, oFormat, programPuncturedPoint, numThreads); - //} task<> expand( Socket& chl, const VecF& value, @@ -259,7 +94,7 @@ namespace osuCrypto setTimePoint("SilentMultiPprfSender.start"); - validateExpandFormat(oFormat, output, mDomain, mPntCount); + pprf::validateExpandFormat(oFormat, output, mDomain, mPntCount); MC_BEGIN(task<>, this, numThreads, oFormat, &output, seed, &chl, programPuncturedPoint, ctx, treeIndex = u64{}, @@ -269,15 +104,16 @@ namespace osuCrypto leafLevelPtr = (VecF*)nullptr, leafLevel = VecF{}, buff = std::vector{}, - encSums = span, 2>>{}, - leafMsgs = span{} + encSums = span>{}, + leafMsgs = span{}, + mTreeAlloc = pprf::TreeAllocator{} ); mTreeAlloc.reserve(numThreads, (1ull << mDepth) + 2); setTimePoint("SilentMultiPprfSender.reserve"); levels.resize(mDepth); - allocateExpandTree(mTreeAlloc, tree, levels); + pprf::allocateExpandTree(mTreeAlloc, levels); for (treeIndex = 0; treeIndex < mPntCount; treeIndex += 8) { @@ -298,7 +134,8 @@ namespace osuCrypto } // allocate the send buffer and partition it. - allocateExpandBuffer(mDepth - 1, programPuncturedPoint, buff, encSums, leafMsgs, ctx); + pprf::allocateExpandBuffer( + mDepth - 1, std::min(8, mPntCount - treeIndex), programPuncturedPoint, buff, encSums, leafMsgs, ctx); // exapnd the tree expandOne(seed, treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs, ctx); @@ -308,12 +145,12 @@ namespace osuCrypto // if we aren't interleaved, we need to copy the // leaf layer to the output. if (oFormat != PprfOutputFormat::Interleaved) - copyOut(leafLevel, output, mPntCount, treeIndex, oFormat, mOutputFn); + pprf::copyOut(leafLevel, output, mPntCount, treeIndex, oFormat, mOutputFn); } mBaseOTs = {}; - mTreeAlloc.del(tree); + //mTreeAlloc.del(tree); mTreeAlloc.clear(); setTimePoint("SilentMultiPprfSender.de-alloc"); @@ -350,10 +187,11 @@ namespace osuCrypto span>> levels, VecF& leafLevel, const u64 leafOffset, - span, 2>> encSums, + span> encSums, span leafMsgs, CoeffCtx ctx) { + auto remTrees = std::min(8, mPntCount - treeIdx); // the first level should be size 1, the root of the tree. // we will populate it with random seeds using aesSeed in counter mode @@ -361,7 +199,8 @@ namespace osuCrypto assert(levels[0].size() == 1); mAesFixedKey.ecbEncCounterMode(aesSeed ^ block(treeIdx), levels[0][0]); - assert(encSums.size() == mDepth - 1); + assert(encSums.size() == (mDepth - 1) * remTrees); + auto encSumIter = encSums.begin(); // space for our sums of each level. Should always be less then // 24 levels... If not increase the limit or make it a vector. @@ -450,13 +289,14 @@ namespace osuCrypto } // encrypt the sums and write them to the output. - for (u64 j = 0; j < 8; ++j) + for (u64 j = 0; j < remTrees; ++j) { - encSums[d][0][j] = sums[0][j] ^ mBaseOTs[treeIdx + j][mDepth - 1 - d][1]; - encSums[d][1][j] = sums[1][j] ^ mBaseOTs[treeIdx + j][mDepth - 1 - d][0]; + (*encSumIter)[0] = sums[0][j] ^ mBaseOTs[treeIdx + j][mDepth - 1 - d][1]; + (*encSumIter)[1] = sums[1][j] ^ mBaseOTs[treeIdx + j][mDepth - 1 - d][0]; + ++encSumIter; } } - + assert(encSumIter == encSums.end()); auto d = mDepth - 1; @@ -532,7 +372,7 @@ namespace osuCrypto ctx.resize(leafOts, 2); PRNG otMasker; - for (u64 j = 0; j < 8; ++j) + for (u64 j = 0; j < remTrees; ++j) { // we will construct two OT strings. Let // s0, s1 be the left and right child sums. @@ -575,7 +415,7 @@ namespace osuCrypto ctx.resize(leafOts, 1); PRNG otMasker; - for (u64 j = 0; j < 8; ++j) + for (u64 j = 0; j < remTrees; ++j) { for (u64 k = 0; k < 2; ++k) { @@ -606,35 +446,31 @@ namespace osuCrypto typename G = F, typename CoeffCtx = DefaultCoeffCtx > - class SilentSubfieldPprfReceiver : public TimerAdapter + class RegularPprfReceiver : public TimerAdapter { public: u64 mDomain = 0, mDepth = 0, mPntCount = 0; using VecF = typename CoeffCtx::template Vec; using VecG = typename CoeffCtx::template Vec; - std::vector mPoints; + //std::vector mPoints; Matrix mBaseOTs; Matrix mBaseChoices; - TreeAllocator mTreeAlloc; - std::function mOutputFn; - SilentSubfieldPprfReceiver() = default; - SilentSubfieldPprfReceiver(const SilentSubfieldPprfReceiver&) = delete; - SilentSubfieldPprfReceiver(SilentSubfieldPprfReceiver&&) = delete; + RegularPprfReceiver() = default; + RegularPprfReceiver(const RegularPprfReceiver&) = delete; + RegularPprfReceiver(RegularPprfReceiver&&) = delete; void configure(u64 domainSize, u64 pointCount) { if (domainSize & 1) throw std::runtime_error("Pprf domain must be even. " LOCATION); - if (domainSize < 4) - throw std::runtime_error("Pprf domain must must be at least 4. " LOCATION); - if (mPntCount % 8) - throw std::runtime_error("pointCount must be a multiple of 8 (general case not impl). " LOCATION); + if (domainSize < 2) + throw std::runtime_error("Pprf domain must must be at least 2. " LOCATION); mDomain = domainSize; mDepth = log2ceil(mDomain); @@ -753,7 +589,7 @@ namespace osuCrypto { auto subTree = j % 8; auto batch = j / 8; - points[j] = (batch * mDomain + points[j]) * 8 + subTree; + points[j] = (batch * mDomain + points[j]) * 8 + subTree; } //interleavedPoints(points, mDomain, format); @@ -776,29 +612,30 @@ namespace osuCrypto u64 numThreads, CoeffCtx ctx = {}) { - validateExpandFormat(oFormat, output, mDomain, mPntCount); + pprf::validateExpandFormat(oFormat, output, mDomain, mPntCount); MC_BEGIN(task<>, this, oFormat, &output, &chl, programPuncturedPoint, ctx, treeIndex = u64{}, - tree = span>{}, levels = std::vector>>{}, leafIndex = u64{}, leafLevelPtr = (VecF*)nullptr, leafLevel = VecF{}, buff = std::vector{}, - encSums = span, 2>>{}, - leafMsgs = span{} + encSums = span>{}, + leafMsgs = span{}, + mTreeAlloc = pprf::TreeAllocator{}, + points = std::vector{} ); setTimePoint("SilentMultiPprfReceiver.start"); - mPoints.resize(roundUpTo(mPntCount, 8)); - getPoints(mPoints, PprfOutputFormat::ByLeafIndex); + points.resize(mPntCount); + getPoints(points, PprfOutputFormat::ByLeafIndex); mTreeAlloc.reserve(1, (1ull << mDepth) + 2); setTimePoint("SilentMultiPprfSender.reserve"); levels.resize(mDepth); - allocateExpandTree(mDepth, mTreeAlloc, tree, levels); + pprf::allocateExpandTree(mTreeAlloc, levels); for (treeIndex = 0; treeIndex < mPntCount; treeIndex += 8) { @@ -819,23 +656,24 @@ namespace osuCrypto } // allocate the send buffer and partition it. - allocateExpandBuffer(mDepth - 1, programPuncturedPoint, buff, encSums, leafMsgs, ctx); + pprf::allocateExpandBuffer(mDepth - 1, std::min(8, mPntCount - treeIndex), + programPuncturedPoint, buff, encSums, leafMsgs, ctx); MC_AWAIT(chl.recv(buff)); // exapnd the tree - expandOne(treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs, ctx); + expandOne(treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs, points, ctx); // if we aren't interleaved, we need to copy the // leaf layer to the output. if (oFormat != PprfOutputFormat::Interleaved) - copyOut(leafLevel, output, mPntCount, treeIndex, oFormat, mOutputFn); + pprf::copyOut(leafLevel, output, mPntCount, treeIndex, oFormat, mOutputFn); } setTimePoint("SilentMultiPprfReceiver.join"); mBaseOTs = {}; - mTreeAlloc.del(tree); + //mTreeAlloc.del(tree); mTreeAlloc.clear(); setTimePoint("SilentMultiPprfReceiver.de-alloc"); @@ -852,6 +690,14 @@ namespace osuCrypto mPntCount = 0; } + void expandOneInternal( + u64 treeIdx, + span>> levels, + span, 2>> theirSums, + CoeffCtx& ctx) + { + } + //treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs void expandOne( u64 treeIdx, @@ -859,231 +705,241 @@ namespace osuCrypto span>> levels, VecF& leafLevel, const u64 outputOffset, - span, 2>> theirSums, + span> theirSums, span leafMsg, - CoeffCtx ctx) + span points, + CoeffCtx& ctx) { - // We will process 8 trees at a time. + auto remTrees = std::min(8, mPntCount - treeIdx); + assert(theirSums.size() == remTrees * (mDepth - 1)); - // special case for the first level. - auto l1 = levels[1]; - for (u64 i = 0; i < 8; ++i) + // We change the hash function for the leaf so lets update + // inactiveChildValues to use the new hash and subtract + // these from the leafSums + std::array, 2> leafSums; + if (mDepth > 1) { - // For the non-active path, set the child of the root node - // as the OT message XOR'ed with the correction sum. - - int active = mBaseChoices[i + treeIdx].back(); - l1[active ^ 1][i] = mBaseOTs[i + treeIdx].back() ^ theirSums[0][active ^ 1][i]; - l1[active][i] = ZeroBlock; - //if (!i) - // std::cout << " unmask " - // << mBaseOTs[i + treeIdx].back() << " ^ " - // << theirSums[0][active ^ 1][i] << " = " - // << l1[active ^ 1][i] << std::endl; - - } - - // space for our sums of each level. - std::array, 2> mySums; + auto theirSumsIter = theirSums.begin(); - // this will be the value of both children of active an parent - // before the active child is updated. We will need to subtract - // this value as the main loop does not distinguish active parents. - std::array inactiveChildValues; - inactiveChildValues[0] = AES::roundEnc(mAesFixedKey.ecbEncBlock(ZeroBlock), ZeroBlock); - inactiveChildValues[1] = mAesFixedKey.ecbEncBlock(ZeroBlock); - - // For all other levels, expand the GGM tree and add in - // the correction along the active path. - for (u64 d = 1; d < mDepth - 1; ++d) - { - // initialized the sums with inactiveChildValue so that - // it will cancel when we expand the actual inactive child. - std::fill(mySums[0].begin(), mySums[0].end(), inactiveChildValues[0]); - std::fill(mySums[1].begin(), mySums[1].end(), inactiveChildValues[1]); + // special case for the first level. + auto l1 = levels[1]; + for (u64 i = 0; i < remTrees; ++i) + { + // For the non-active path, set the child of the root node + // as the OT message XOR'ed with the correction sum. + + int active = mBaseChoices[i + treeIdx].back(); + l1[active ^ 1][i] = mBaseOTs[i + treeIdx].back() ^ (*theirSumsIter)[active ^ 1]; + l1[active][i] = ZeroBlock; + ++theirSumsIter; + //if (!i) + // std::cout << " unmask " + // << mBaseOTs[i + treeIdx].back() << " ^ " + // << theirSums[0][active ^ 1][i] << " = " + // << l1[active ^ 1][i] << std::endl; - // We will iterate over each node on this level and - // expand it into it's two children. Note that the - // active node will also be expanded. Later we will just - // overwrite whatever the value was. This is an optimization. - auto width = divCeil(mDomain, 1ull << (mDepth - d)); + } - // The already constructed level. Only missing the - // GGM tree node value along the active path. - auto level0 = levels[d]; + // space for our sums of each level. + std::array, 2> mySums; - // The next level that we want to construct. - auto level1 = levels[d + 1]; + // this will be the value of both children of active an parent + // before the active child is updated. We will need to subtract + // this value as the main loop does not distinguish active parents. + std::array inactiveChildValues; + inactiveChildValues[0] = AES::roundEnc(mAesFixedKey.ecbEncBlock(ZeroBlock), ZeroBlock); + inactiveChildValues[1] = mAesFixedKey.ecbEncBlock(ZeroBlock); - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) + // For all other levels, expand the GGM tree and add in + // the correction along the active path. + for (u64 d = 1; d < mDepth - 1; ++d) { - // The value of the parent. - auto parent = level0[parentIdx]; + // initialized the sums with inactiveChildValue so that + // it will cancel when we expand the actual inactive child. + std::fill(mySums[0].begin(), mySums[0].end(), inactiveChildValues[0]); + std::fill(mySums[1].begin(), mySums[1].end(), inactiveChildValues[1]); - auto& child0 = level1.data()[childIdx]; - auto& child1 = level1.data()[childIdx + 1]; - mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); + // We will iterate over each node on this level and + // expand it into it's two children. Note that the + // active node will also be expanded. Later we will just + // overwrite whatever the value was. This is an optimization. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); - // inspired by the Expand Accumualte idea to - // use - // - // child0 = AES(parent) ^ parent - // child1 = AES(parent) + parent - // - // but instead we are a bit more conservative and - // compute - // - // child0 = AES:Round(AES(parent), parent) - // = AES:Round(AES(parent), 0) ^ parent - // child1 = AES(parent) + parent - // - // That is, we applies an additional AES round function - // to the first child before XORing it with parent. - child0[0] = AES::roundEnc(child1[0], parent[0]); - child0[1] = AES::roundEnc(child1[1], parent[1]); - child0[2] = AES::roundEnc(child1[2], parent[2]); - child0[3] = AES::roundEnc(child1[3], parent[3]); - child0[4] = AES::roundEnc(child1[4], parent[4]); - child0[5] = AES::roundEnc(child1[5], parent[5]); - child0[6] = AES::roundEnc(child1[6], parent[6]); - child0[7] = AES::roundEnc(child1[7], parent[7]); + // The already constructed level. Only missing the + // GGM tree node value along the active path. + auto level0 = levels[d]; - // Update the running sums for this level. We keep - // a left and right totals for each level. Note that - // we are actually XOR in the incorrect value of the - // children of the active parent but this will cancel - // with inactiveChildValue thats already there. - mySums[0][0] = mySums[0][0] ^ child0[0]; - mySums[0][1] = mySums[0][1] ^ child0[1]; - mySums[0][2] = mySums[0][2] ^ child0[2]; - mySums[0][3] = mySums[0][3] ^ child0[3]; - mySums[0][4] = mySums[0][4] ^ child0[4]; - mySums[0][5] = mySums[0][5] ^ child0[5]; - mySums[0][6] = mySums[0][6] ^ child0[6]; - mySums[0][7] = mySums[0][7] ^ child0[7]; + // The next level that we want to construct. + auto level1 = levels[d + 1]; - // child1 = AES(parent) + parent - child1[0] = child1[0] + parent[0]; - child1[1] = child1[1] + parent[1]; - child1[2] = child1[2] + parent[2]; - child1[3] = child1[3] + parent[3]; - child1[4] = child1[4] + parent[4]; - child1[5] = child1[5] + parent[5]; - child1[6] = child1[6] + parent[6]; - child1[7] = child1[7] + parent[7]; + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) + { + // The value of the parent. + auto parent = level0[parentIdx]; + + auto& child0 = level1.data()[childIdx]; + auto& child1 = level1.data()[childIdx + 1]; + mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); + + // inspired by the Expand Accumualte idea to + // use + // + // child0 = AES(parent) ^ parent + // child1 = AES(parent) + parent + // + // but instead we are a bit more conservative and + // compute + // + // child0 = AES:Round(AES(parent), parent) + // = AES:Round(AES(parent), 0) ^ parent + // child1 = AES(parent) + parent + // + // That is, we applies an additional AES round function + // to the first child before XORing it with parent. + child0[0] = AES::roundEnc(child1[0], parent[0]); + child0[1] = AES::roundEnc(child1[1], parent[1]); + child0[2] = AES::roundEnc(child1[2], parent[2]); + child0[3] = AES::roundEnc(child1[3], parent[3]); + child0[4] = AES::roundEnc(child1[4], parent[4]); + child0[5] = AES::roundEnc(child1[5], parent[5]); + child0[6] = AES::roundEnc(child1[6], parent[6]); + child0[7] = AES::roundEnc(child1[7], parent[7]); + + // Update the running sums for this level. We keep + // a left and right totals for each level. Note that + // we are actually XOR in the incorrect value of the + // children of the active parent but this will cancel + // with inactiveChildValue thats already there. + mySums[0][0] = mySums[0][0] ^ child0[0]; + mySums[0][1] = mySums[0][1] ^ child0[1]; + mySums[0][2] = mySums[0][2] ^ child0[2]; + mySums[0][3] = mySums[0][3] ^ child0[3]; + mySums[0][4] = mySums[0][4] ^ child0[4]; + mySums[0][5] = mySums[0][5] ^ child0[5]; + mySums[0][6] = mySums[0][6] ^ child0[6]; + mySums[0][7] = mySums[0][7] ^ child0[7]; + + // child1 = AES(parent) + parent + child1[0] = child1[0] + parent[0]; + child1[1] = child1[1] + parent[1]; + child1[2] = child1[2] + parent[2]; + child1[3] = child1[3] + parent[3]; + child1[4] = child1[4] + parent[4]; + child1[5] = child1[5] + parent[5]; + child1[6] = child1[6] + parent[6]; + child1[7] = child1[7] + parent[7]; + + mySums[1][0] = mySums[1][0] ^ child1[0]; + mySums[1][1] = mySums[1][1] ^ child1[1]; + mySums[1][2] = mySums[1][2] ^ child1[2]; + mySums[1][3] = mySums[1][3] ^ child1[3]; + mySums[1][4] = mySums[1][4] ^ child1[4]; + mySums[1][5] = mySums[1][5] ^ child1[5]; + mySums[1][6] = mySums[1][6] ^ child1[6]; + mySums[1][7] = mySums[1][7] ^ child1[7]; - mySums[1][0] = mySums[1][0] ^ child1[0]; - mySums[1][1] = mySums[1][1] ^ child1[1]; - mySums[1][2] = mySums[1][2] ^ child1[2]; - mySums[1][3] = mySums[1][3] ^ child1[3]; - mySums[1][4] = mySums[1][4] ^ child1[4]; - mySums[1][5] = mySums[1][5] ^ child1[5]; - mySums[1][6] = mySums[1][6] ^ child1[6]; - mySums[1][7] = mySums[1][7] ^ child1[7]; + } - } + // we have to update the non-active child of the active parent. + for (u64 i = 0; i < remTrees; ++i) + { + // the index of the leaf node that is active. + auto leafIdx = points[i + treeIdx]; - // we have to update the non-active child of the active parent. - for (u64 i = 0; i < 8; ++i) - { - // the index of the leaf node that is active. - auto leafIdx = mPoints[i + treeIdx]; + // The index of the active (missing) child node. + auto missingChildIdx = leafIdx >> (mDepth - 1 - d); - // The index of the active (missing) child node. - auto missingChildIdx = leafIdx >> (mDepth - 1 - d); + // The index of the active child node sibling. + auto siblingIdx = missingChildIdx ^ 1; - // The index of the active child node sibling. - auto siblingIdx = missingChildIdx ^ 1; + // The indicator as to the left or right child is inactive + auto notAi = siblingIdx & 1; - // The indicator as to the left or right child is inactive - auto notAi = siblingIdx & 1; - - // our sums & OTs cancel and we are leaf with the - // correct value for the inactive child. - level1[siblingIdx][i] = - theirSums[d][notAi][i] ^ - mySums[notAi][i] ^ - mBaseOTs[i + treeIdx][mDepth - 1 - d]; - - // we have to set the active child to zero so - // the next children are predictable. - level1[missingChildIdx][i] = ZeroBlock; - } - } + // our sums & OTs cancel and we are leaf with the + // correct value for the inactive child. + level1[siblingIdx][i] = + (*theirSumsIter)[notAi] ^ + mySums[notAi][i] ^ + mBaseOTs[i + treeIdx][mDepth - 1 - d]; + ++theirSumsIter; - auto d = mDepth - 1; - // The already constructed level. Only missing the - // GGM tree node value along the active path. - auto level0 = levels[d]; + // we have to set the active child to zero so + // the next children are predictable. + level1[missingChildIdx][i] = ZeroBlock; + } + } - // The next level of theGGM tree that we are populating. - std::array child; + auto d = mDepth - 1; + // The already constructed level. Only missing the + // GGM tree node value along the active path. + auto level0 = levels[d]; - // We will iterate over each node on this level and - // expand it into it's two children. Note that the - // active node will also be expanded. Later we will just - // overwrite whatever the value was. This is an optimization. - auto width = divCeil(mDomain, 1ull << (mDepth - d)); + // The next level of theGGM tree that we are populating. + std::array child; - // We change the hash function for the leaf so lets update - // inactiveChildValues to use the new hash and subtract - // these from the leafSums - CoeffCtx::template Vec temp; - ctx.resize(temp, 2); - std::array, 2> leafSums; - for (u64 k = 0; k < 2; ++k) - { - inactiveChildValues[k] = gGgmAes[k].hashBlock(ZeroBlock); - ctx.fromBlock(temp[k], inactiveChildValues[k]); + // We will iterate over each node on this level and + // expand it into it's two children. Note that the + // active node will also be expanded. Later we will just + // overwrite whatever the value was. This is an optimization. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); + CoeffCtx::template Vec temp; + ctx.resize(temp, 2); + for (u64 k = 0; k < 2; ++k) + { + ctx.resize(leafSums[k], 8); + ctx.zero(leafSums[k].begin(), leafSums[k].end()); + ctx.fromBlock(temp[k], gGgmAes[k].hashBlock(ZeroBlock)); + ctx.minus(leafSums[k][0], leafSums[k][0], temp[k]); + for (u64 i = 1; i < 8; ++i) + ctx.copy(leafSums[k][i], leafSums[k][0]); + } + // for leaf nodes both children should be hashed. + for (u64 parentIdx = 0, childIdx = 0, outputIdx = outputOffset; parentIdx < width; ++parentIdx) + { + // The value of the parent. + auto parent = level0[parentIdx]; - // leafSum = -inactiveChildValues - ctx.resize(leafSums[k], 8); - ctx.zero(leafSums[k].begin(), leafSums[k].end()); - ctx.minus(leafSums[k][0], leafSums[k][0], temp[k]); - for (u64 i = 1; i < 8; ++i) - ctx.copy(leafSums[k][i], leafSums[k][0]); + for (u64 keep = 0; keep < 2; ++keep, ++childIdx, outputIdx += 8) + { + // Each parent is expanded into the left and right children + // using a different AES fixed-key. Therefore our OWF is: + // + // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); + // + // where each half defines one of the children. + gGgmAes[keep].hashBlocks<8>(parent.data(), child.data()); + + ctx.fromBlock(leafLevel[outputIdx + 0], child[0]); + ctx.fromBlock(leafLevel[outputIdx + 1], child[1]); + ctx.fromBlock(leafLevel[outputIdx + 2], child[2]); + ctx.fromBlock(leafLevel[outputIdx + 3], child[3]); + ctx.fromBlock(leafLevel[outputIdx + 4], child[4]); + ctx.fromBlock(leafLevel[outputIdx + 5], child[5]); + ctx.fromBlock(leafLevel[outputIdx + 6], child[6]); + ctx.fromBlock(leafLevel[outputIdx + 7], child[7]); + + auto& leafSum = leafSums[keep]; + ctx.plus(leafSum[0], leafSum[0], leafLevel[outputIdx + 0]); + ctx.plus(leafSum[1], leafSum[1], leafLevel[outputIdx + 1]); + ctx.plus(leafSum[2], leafSum[2], leafLevel[outputIdx + 2]); + ctx.plus(leafSum[3], leafSum[3], leafLevel[outputIdx + 3]); + ctx.plus(leafSum[4], leafSum[4], leafLevel[outputIdx + 4]); + ctx.plus(leafSum[5], leafSum[5], leafLevel[outputIdx + 5]); + ctx.plus(leafSum[6], leafSum[6], leafLevel[outputIdx + 6]); + ctx.plus(leafSum[7], leafSum[7], leafLevel[outputIdx + 7]); + } + } } - - // for leaf nodes both children should be hashed. - for (u64 parentIdx = 0, childIdx = 0, outputIdx = outputOffset; parentIdx < width; ++parentIdx) + else { - // The value of the parent. - auto parent = level0[parentIdx]; - - for (u64 keep = 0; keep < 2; ++keep, ++childIdx, outputIdx += 8) + for (u64 k = 0; k < 2; ++k) { - // Each parent is expanded into the left and right children - // using a different AES fixed-key. Therefore our OWF is: - // - // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); - // - // where each half defines one of the children. - gGgmAes[keep].hashBlocks<8>(parent.data(), child.data()); - - ctx.fromBlock(leafLevel[outputIdx + 0], child[0]); - ctx.fromBlock(leafLevel[outputIdx + 1], child[1]); - ctx.fromBlock(leafLevel[outputIdx + 2], child[2]); - ctx.fromBlock(leafLevel[outputIdx + 3], child[3]); - ctx.fromBlock(leafLevel[outputIdx + 4], child[4]); - ctx.fromBlock(leafLevel[outputIdx + 5], child[5]); - ctx.fromBlock(leafLevel[outputIdx + 6], child[6]); - ctx.fromBlock(leafLevel[outputIdx + 7], child[7]); - - - - auto& leafSum = leafSums[keep]; - ctx.plus(leafSum[0], leafSum[0], leafLevel[outputIdx + 0]); - ctx.plus(leafSum[1], leafSum[1], leafLevel[outputIdx + 1]); - ctx.plus(leafSum[2], leafSum[2], leafLevel[outputIdx + 2]); - ctx.plus(leafSum[3], leafSum[3], leafLevel[outputIdx + 3]); - ctx.plus(leafSum[4], leafSum[4], leafLevel[outputIdx + 4]); - ctx.plus(leafSum[5], leafSum[5], leafLevel[outputIdx + 5]); - ctx.plus(leafSum[6], leafSum[6], leafLevel[outputIdx + 6]); - ctx.plus(leafSum[7], leafSum[7], leafLevel[outputIdx + 7]); + ctx.resize(leafSums[k], 8); + ctx.zero(leafSums[k].begin(), leafSums[k].end()); } } @@ -1100,11 +956,11 @@ namespace osuCrypto ctx.resize(leafOts, 2); PRNG otMasker; - for (u64 j = 0; j < 8; ++j) + for (u64 j = 0; j < remTrees; ++j) { // The index of the child on the active path. - auto activeChildIdx = mPoints[j + treeIdx]; + auto activeChildIdx = points[j + treeIdx]; // The index of the other (inactive) child. auto inactiveChildIdx = activeChildIdx ^ 1; @@ -1113,7 +969,7 @@ namespace osuCrypto auto notAi = inactiveChildIdx & 1; // offset to the first or second ot message, based on the one we want - auto offset = CoeffCtx::template byteSize() * 2 * notAi; + auto offset = ctx.byteSize() * 2 * notAi; // decrypt the ot string @@ -1138,10 +994,10 @@ namespace osuCrypto ctx.resize(leafOts, 1); PRNG otMasker; - for (u64 j = 0; j < 8; ++j) + for (u64 j = 0; j < remTrees; ++j) { // The index of the child on the active path. - auto activeChildIdx = mPoints[j + treeIdx]; + auto activeChildIdx = points[j + treeIdx]; // The index of the other (inactive) child. auto inactiveChildIdx = activeChildIdx ^ 1; @@ -1150,7 +1006,7 @@ namespace osuCrypto auto notAi = inactiveChildIdx & 1; // offset to the first or second ot message, based on the one we want - auto offset = CoeffCtx::template byteSize() * notAi; + auto offset = ctx.byteSize() * notAi; // decrypt the ot string span buff = leafMsg.subspan(offset, ctx.byteSize()); diff --git a/libOTe/Tools/SilentPprf.cpp b/libOTe/Tools/SilentPprf.cpp deleted file mode 100644 index d9d5e968..00000000 --- a/libOTe/Tools/SilentPprf.cpp +++ /dev/null @@ -1,995 +0,0 @@ -#include "SilentPprf.h" -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - -#include -#include -#include -#include -#include - -namespace osuCrypto -{ - - void SilentMultiPprfSender::setBase(span> baseMessages) - { - if (baseOtCount() != static_cast(baseMessages.size())) - throw RTE_LOC; - - mBaseOTs.resize(mPntCount, mDepth); - for (u64 i = 0; i < static_cast(mBaseOTs.size()); ++i) - mBaseOTs(i) = baseMessages[i]; - } - - void SilentMultiPprfReceiver::setBase(span baseMessages) - { - if (baseOtCount() != static_cast(baseMessages.size())) - throw RTE_LOC; - - // The OTs are used in blocks of 8, so make sure that there is a whole - // number of blocks. - mBaseOTs.resize(roundUpTo(mPntCount, 8), mDepth); - memcpy(mBaseOTs.data(), baseMessages.data(), baseMessages.size() * sizeof(block)); - } - - // This function copies the leaf values of the GGM tree - // to the output location. There are two modes for this - // function. If interleaved == false, then each tree is - // copied to a different contiguous regions of the output. - // If interleaved == true, then trees are interleaved such that .... - // @lvl - the GGM tree leafs. - // @output - the location that the GGM leafs should be written to. - // @numTrees - How many trees there are in total. - // @tIdx - the index of the first tree. - // @oFormat - do we interleave the output? - // @mal - ... - void copyOut( - span> lvl, - MatrixView output, - u64 totalTrees, - u64 tIdx, - PprfOutputFormat oFormat, - std::function> lvl)>& callback) - { - if (oFormat == PprfOutputFormat::ByLeafIndex) - { - - auto curSize = std::min(totalTrees - tIdx, 8); - if (curSize == 8) - { - - for (u64 i = 0; i < output.rows(); ++i) - { - auto oi = output[i].subspan(tIdx, 8); - auto& ii = lvl[i]; - oi[0] = ii[0]; - oi[1] = ii[1]; - oi[2] = ii[2]; - oi[3] = ii[3]; - oi[4] = ii[4]; - oi[5] = ii[5]; - oi[6] = ii[6]; - oi[7] = ii[7]; - } - } - else - { - for (u64 i = 0; i < output.rows(); ++i) - { - auto oi = output[i].subspan(tIdx, curSize); - auto& ii = lvl[i]; - for (u64 j = 0; j < curSize; ++j) - oi[j] = ii[j]; - } - } - - } - else if (oFormat == PprfOutputFormat::ByTreeIndex) - { - - auto curSize = std::min(totalTrees - tIdx, 8); - if (curSize == 8) - { - for (u64 i = 0; i < output.cols(); ++i) - { - auto& ii = lvl[i]; - output(tIdx + 0, i) = ii[0]; - output(tIdx + 1, i) = ii[1]; - output(tIdx + 2, i) = ii[2]; - output(tIdx + 3, i) = ii[3]; - output(tIdx + 4, i) = ii[4]; - output(tIdx + 5, i) = ii[5]; - output(tIdx + 6, i) = ii[6]; - output(tIdx + 7, i) = ii[7]; - } - } - else - { - for (u64 i = 0; i < output.cols(); ++i) - { - auto& ii = lvl[i]; - for (u64 j = 0; j < curSize; ++j) - output(tIdx + j, i) = ii[j]; - } - } - - } - else if (oFormat == PprfOutputFormat::Callback) - callback(tIdx, lvl); - else - throw RTE_LOC; - } - - u64 interleavedPoint(u64 point, u64 treeIdx, u64 totalTrees, u64 domain, PprfOutputFormat format) - { - switch (format) - { - case osuCrypto::PprfOutputFormat::Interleaved: - case osuCrypto::PprfOutputFormat::Callback: - { - - if (domain <= point) - return ~u64(0); - - auto subTree = treeIdx % 8; - auto forest = treeIdx / 8; - - return (forest * domain + point) * 8 + subTree; - } - break; - default: - throw RTE_LOC; - break; - } - //auto totalTrees = points.size(); - - } - - void interleavedPoints(span points, u64 domain, PprfOutputFormat format) - { - - for (u64 i = 0; i < points.size(); ++i) - { - points[i] = interleavedPoint(points[i], i, points.size(), domain, format); - } - } - - u64 getActivePath(const span& choiceBits) - { - u64 point = 0; - for (u64 i = 0; i < choiceBits.size(); ++i) - { - auto shift = choiceBits.size() - i - 1; - - point |= u64(1 ^ choiceBits[i]) << shift; - } - return point; - } - - void SilentMultiPprfReceiver::getPoints(span points, PprfOutputFormat format) - { - - switch (format) - { - case PprfOutputFormat::ByLeafIndex: - case PprfOutputFormat::ByTreeIndex: - - memset(points.data(), 0, points.size() * sizeof(u64)); - for (u64 j = 0; j < mPntCount; ++j) - { - points[j] = getActivePath(mBaseChoices[j]); - } - - break; - case PprfOutputFormat::Interleaved: - case PprfOutputFormat::Callback: - - if ((u64)points.size() != mPntCount) - throw RTE_LOC; - if (points.size() % 8) - throw RTE_LOC; - - getPoints(points, PprfOutputFormat::ByLeafIndex); - interleavedPoints(points, mDomain, format); - - break; - default: - throw RTE_LOC; - break; - } - } - - - BitVector SilentMultiPprfReceiver::sampleChoiceBits(u64 modulus, PprfOutputFormat format, PRNG& prng) - { - BitVector choices(mPntCount * mDepth); - - // The points are read in blocks of 8, so make sure that there is a - // whole number of blocks. - mBaseChoices.resize(roundUpTo(mPntCount, 8), mDepth); - for (u64 i = 0; i < mPntCount; ++i) - { - u64 idx; - switch (format) - { - case osuCrypto::PprfOutputFormat::ByLeafIndex: - case osuCrypto::PprfOutputFormat::ByTreeIndex: - do { - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = prng.getBit(); - idx = getActivePath(mBaseChoices[i]); - } while (idx >= modulus); - - break; - case osuCrypto::PprfOutputFormat::Interleaved: - case osuCrypto::PprfOutputFormat::Callback: - - if (modulus > mPntCount * mDomain) - throw std::runtime_error("modulus too big. " LOCATION); - if (modulus < mPntCount * mDomain / 2) - throw std::runtime_error("modulus too small. " LOCATION); - - // make sure that at least the first element of this tree - // is within the modulus. - idx = interleavedPoint(0, i, mPntCount, mDomain, format); - if (idx >= modulus) - throw RTE_LOC; - - - do { - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = prng.getBit(); - idx = getActivePath(mBaseChoices[i]); - - idx = interleavedPoint(idx, i, mPntCount, mDomain, format); - } while (idx >= modulus); - - - break; - default: - throw RTE_LOC; - break; - } - - } - - for (u64 i = 0; i < mBaseChoices.size(); ++i) - { - choices[i] = mBaseChoices(i); - } - - return choices; - } - - void SilentMultiPprfReceiver::setChoiceBits(PprfOutputFormat format, BitVector choices) - { - // Make sure we're given the right number of OTs. - if (choices.size() != baseOtCount()) - throw RTE_LOC; - - mBaseChoices.resize(roundUpTo(mPntCount, 8), mDepth); - for (u64 i = 0; i < mPntCount; ++i) - { - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = choices[mDepth * i + j]; - - switch (format) - { - case osuCrypto::PprfOutputFormat::ByLeafIndex: - case osuCrypto::PprfOutputFormat::ByTreeIndex: - if (getActivePath(mBaseChoices[i]) >= mDomain) - throw RTE_LOC; - - break; - case osuCrypto::PprfOutputFormat::Interleaved: - case osuCrypto::PprfOutputFormat::Callback: - { - auto idx = getActivePath(mBaseChoices[i]); - auto idx2 = interleavedPoint(idx, i, mPntCount, mDomain, format); - if(idx2 > mPntCount * mDomain) - throw std::runtime_error("the base ot choice bits index outside of the domain. see sampleChoiceBits(...). " LOCATION); - break; - } - default: - throw RTE_LOC; - break; - } - } - } - - namespace - { - // A public PRF/PRG that we will use for deriving the GGM tree. - const std::array gAes = []() { - std::array aes; - aes[0].setKey(toBlock(3242342)); - aes[1].setKey(toBlock(8993849)); - return aes; - }(); - } - - - void SilentMultiPprfSender::expandOne( - block aesSeed, - u64 treeIdx, - bool programActivePath, - span>> levels, - span, 2>> encSums, - span> lastOts) - { - // The number of real trees for this iteration. - auto min = std::min(8, mPntCount - treeIdx); - - // the first level should be size 1, the root of the tree. - // we will populate it with random seeds using aesSeed in counter mode - // based on the tree index. - assert(levels[0].size() == 1); - mAesFixedKey.ecbEncCounterMode(aesSeed ^ block(treeIdx), levels[0][0]); - - assert(encSums.size() == mDepth - programActivePath); - assert(encSums.size() < 24); - - // space for our sums of each level. Should always be less then - // 24 levels... If not increase the limit or make it a vector. - std::array, 2>, 24> sums; - memset(&sums, 0, sizeof(sums)); - - // For each level perform the following. - for (u64 d = 0; d < mDepth; ++d) - { - // The previous level of the GGM tree. - auto level0 = levels[d]; - - // The next level of theGGM tree that we are populating. - auto level1 = levels[d + 1]; - - // The total number of parents in this level. - auto width = divCeil(mDomain, 1ull << (mDepth - d)); - - // use the optimized approach for intern nodes of the tree - if (d + 1 < mDepth && 0) - { - // For each child, populate the child by expanding the parent. - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) - { - // The value of the parent. - auto& parent = level0.data()[parentIdx]; - - auto& child0 = level1.data()[childIdx]; - auto& child1 = level1.data()[childIdx + 1]; - mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); - - child0[0] = child1[0] ^ parent[0]; - child0[1] = child1[1] ^ parent[1]; - child0[2] = child1[2] ^ parent[2]; - child0[3] = child1[3] ^ parent[3]; - child0[4] = child1[4] ^ parent[4]; - child0[5] = child1[5] ^ parent[5]; - child0[6] = child1[6] ^ parent[6]; - child0[7] = child1[7] ^ parent[7]; - - // Update the running sums for this level. We keep - // a left and right totals for each level. - auto& sum = sums[d]; - sum[0][0] = sum[0][0] ^ child0[0]; - sum[0][1] = sum[0][1] ^ child0[1]; - sum[0][2] = sum[0][2] ^ child0[2]; - sum[0][3] = sum[0][3] ^ child0[3]; - sum[0][4] = sum[0][4] ^ child0[4]; - sum[0][5] = sum[0][5] ^ child0[5]; - sum[0][6] = sum[0][6] ^ child0[6]; - sum[0][7] = sum[0][7] ^ child0[7]; - - child1[0] = child1[0] + parent[0]; - child1[1] = child1[1] + parent[1]; - child1[2] = child1[2] + parent[2]; - child1[3] = child1[3] + parent[3]; - child1[4] = child1[4] + parent[4]; - child1[5] = child1[5] + parent[5]; - child1[6] = child1[6] + parent[6]; - child1[7] = child1[7] + parent[7]; - - sum[1][0] = sum[1][0] ^ child1[0]; - sum[1][1] = sum[1][1] ^ child1[1]; - sum[1][2] = sum[1][2] ^ child1[2]; - sum[1][3] = sum[1][3] ^ child1[3]; - sum[1][4] = sum[1][4] ^ child1[4]; - sum[1][5] = sum[1][5] ^ child1[5]; - sum[1][6] = sum[1][6] ^ child1[6]; - sum[1][7] = sum[1][7] ^ child1[7]; - - } - } - else - { - // for the leaf nodes we need to hash both children. - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) - { - // The value of the parent. - auto& parent = level0.data()[parentIdx]; - - // The bit that indicates if we are on the left child (0) - // or on the right child (1). - for (u64 keep = 0; keep < 2; ++keep, ++childIdx) - { - // The child that we will write in this iteration. - auto& child = level1[childIdx]; - - // The sum that this child node belongs to. - auto& sum = sums[d][keep]; - - // Each parent is expanded into the left and right children - // using a different AES fixed-key. Therefore our OWF is: - // - // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); - // - // where each half defines one of the children. - gAes[keep].hashBlocks<8>(parent.data(), child.data()); - - // Update the running sums for this level. We keep - // a left and right totals for each level. - sum[0] = sum[0] ^ child[0]; - sum[1] = sum[1] ^ child[1]; - sum[2] = sum[2] ^ child[2]; - sum[3] = sum[3] ^ child[3]; - sum[4] = sum[4] ^ child[4]; - sum[5] = sum[5] ^ child[5]; - sum[6] = sum[6] ^ child[6]; - sum[7] = sum[7] ^ child[7]; - } - } - } - } - - // For all but the last level, mask the sums with the - // OT strings and send them over. - for (u64 d = 0; d < mDepth - programActivePath; ++d) - { - for (u64 j = 0; j < min; ++j) - { - encSums[d][0][j] = sums[d][0][j] ^ mBaseOTs[treeIdx + j][d][0]; - encSums[d][1][j] = sums[d][1][j] ^ mBaseOTs[treeIdx + j][d][1]; - } - } - - if (programActivePath) - { - // For the last level, we are going to do something special. - // The other party is currently missing both leaf children of - // the active parent. Since this is the last level, we want - // the inactive child to just be the normal value but the - // active child should be the correct value XOR the delta. - // This will be done by sending the sums and the sums plus - // delta and ensure that they can only decrypt the correct ones. - auto d = mDepth - 1; - assert(lastOts.size() == min); - for (u64 j = 0; j < min; ++j) - { - // Construct the sums where we will allow the delta (mValue) - // to either be on the left child or right child depending - // on which has the active path. - lastOts[j][0] = sums[d][0][j]; - lastOts[j][1] = sums[d][1][j] ^ mValue[treeIdx + j]; - lastOts[j][2] = sums[d][1][j]; - lastOts[j][3] = sums[d][0][j] ^ mValue[treeIdx + j]; - - // We are going to expand the 128 bit OT string - // into a 256 bit OT string using AES. - std::array masks, maskIn; - maskIn[0] = mBaseOTs[treeIdx + j][d][0]; - maskIn[1] = mBaseOTs[treeIdx + j][d][0] ^ AllOneBlock; - maskIn[2] = mBaseOTs[treeIdx + j][d][1]; - maskIn[3] = mBaseOTs[treeIdx + j][d][1] ^ AllOneBlock; - mAesFixedKey.hashBlocks<4>(maskIn.data(), masks.data()); - - // Add the OT masks to the sums and send them over. - lastOts[j][0] = lastOts[j][0] ^ masks[0]; - lastOts[j][1] = lastOts[j][1] ^ masks[1]; - lastOts[j][2] = lastOts[j][2] ^ masks[2]; - lastOts[j][3] = lastOts[j][3] ^ masks[3]; - } - } - } - - void allocateExpandBuffer( - u64 depth, - u64 activeChildXorDelta, - std::vector& buff, - span< std::array, 2>>& sums, - span< std::array>& last) - { - - using SumType = std::array, 2>; - using LastType = std::array; - u64 numSums = depth - activeChildXorDelta; - u64 numLast = activeChildXorDelta * 8; - u64 numBlocks = numSums * 16 + numLast * 4; - buff.resize(numBlocks); - sums = span((SumType*)buff.data(), numSums); - last = span((LastType*)(sums.data() + sums.size()), numLast); - - void* sEnd = sums.data() + sums.size(); - void* lEnd = last.data() + last.size(); - void* end = buff.data() + buff.size(); - if (sEnd > end || lEnd > end) - throw RTE_LOC; - } - - void allocateExpandTree( - u64 depth, - TreeAllocator& alloc, - span>& tree, - std::vector>>& levels) - { - tree = alloc.get(); - assert((u64)tree.data() % 32 == 0); - levels[0] = tree.subspan(0, 1); - auto rem = tree.subspan(2); - for (auto i : rng(1ull, depth)) - { - levels[i] = rem.subspan(0, levels[i - 1].size() * 2); - assert((u64)levels[i].data() % 32 == 0); - rem = rem.subspan(levels[i].size()); - } - } - - void validateExpandFormat( - PprfOutputFormat oFormat, - MatrixView output, - u64 domain, - u64 pntCount - ) - { - - if (oFormat == PprfOutputFormat::ByLeafIndex) - { - if (output.rows() != domain) - throw RTE_LOC; - - if (output.cols() != pntCount) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::ByTreeIndex) - { - if (output.cols() != domain) - throw RTE_LOC; - - if (output.rows() != pntCount) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::Interleaved) - { - if (output.cols() != 1) - throw RTE_LOC; - if (domain & 1) - throw RTE_LOC; - - auto rows = output.rows(); - if (rows > (domain * pntCount) || - rows / 128 != (domain * pntCount) / 128) - throw RTE_LOC; - if (pntCount & 7) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::Callback) - { - if (domain & 1) - throw RTE_LOC; - if (pntCount & 7) - throw RTE_LOC; - } - else - { - throw RTE_LOC; - } - } - - - task<> SilentMultiPprfSender::expand( - Socket& chl, - span value, - block seed, - MatrixView output, - PprfOutputFormat oFormat, - bool activeChildXorDelta, - u64 numThreads) - { - if (activeChildXorDelta) - setValue(value); - - setTimePoint("SilentMultiPprfSender.start"); - - validateExpandFormat(oFormat, output, mDomain, mPntCount); - - MC_BEGIN(task<>, this, numThreads, oFormat, output, seed, &chl, activeChildXorDelta, - i = u64{}, - mTreeAllocDepth = u64{}, - tree = span>{}, - levels = std::vector>>{}, - buff = std::vector{}, - sums = span, 2>>{}, - last = span>{} - ); - - //if (oFormat == PprfOutputFormat::Callback && numThreads > 1) - // throw RTE_LOC; - - mTreeAllocDepth = mDepth + (oFormat != PprfOutputFormat::Interleaved); - mTreeAlloc.reserve(numThreads, (1ull << mTreeAllocDepth) + 2); - setTimePoint("SilentMultiPprfSender.reserve"); - - levels.resize(mDepth + 1); - allocateExpandTree(mTreeAllocDepth, mTreeAlloc, tree, levels); - - for (i = 0; i < mPntCount; i += 8) - { - // for interleaved format, the last level of the tree - // is simply the output. - if (oFormat == PprfOutputFormat::Interleaved) - { - auto b = (AlignedArray*)output.data(); - auto forest = i / 8; - b += forest * mDomain; - - levels.back() = span>(b, mDomain); - } - - // allocate the send buffer and partition it. - allocateExpandBuffer(mDepth, activeChildXorDelta, buff, sums, last); - - // exapnd the tree - expandOne(seed, i, activeChildXorDelta, levels, sums, last); - - MC_AWAIT(chl.send(std::move(buff))); - - // if we aren't interleaved, we need to copy the - // last layer to the output. - if (oFormat != PprfOutputFormat::Interleaved) - copyOut(levels.back(), output, mPntCount, i, oFormat, mOutputFn); - } - - mBaseOTs = {}; - mTreeAlloc.del(tree); - mTreeAlloc.clear(); - - setTimePoint("SilentMultiPprfSender.de-alloc"); - - MC_END(); - } - - void SilentMultiPprfSender::setValue(span value) - { - mValue.resize(mPntCount); - - if (value.size() == 1) - { - std::fill(mValue.begin(), mValue.end(), value[0]); - } - else - { - if ((u64)value.size() != mPntCount) - throw RTE_LOC; - - std::copy(value.begin(), value.end(), mValue.begin()); - } - } - - void SilentMultiPprfSender::clear() - { - mBaseOTs.resize(0, 0); - mDomain = 0; - mDepth = 0; - mPntCount = 0; - } - - void SilentMultiPprfReceiver::expandOne( - - u64 treeIdx, - bool programActivePath, - span>> levels, - span, 2>> theirSums, - span> lastOts) - { - // This thread will process 8 trees at a time. - - // special case for the first level. - auto l1 = levels[1]; - for (u64 i = 0; i < 8; ++i) - { - - // For the non-active path, set the child of the root node - // as the OT message XOR'ed with the correction sum. - int notAi = mBaseChoices[i + treeIdx][0]; - l1[notAi][i] = mBaseOTs[i + treeIdx][0] ^ theirSums[0][notAi][i]; - l1[notAi ^ 1][i] = ZeroBlock; - } - - // space for our sums of each level. - std::array, 2> mySums; - - // For all other levels, expand the GGM tree and add in - // the correction along the active path. - for (u64 d = 1; d < mDepth; ++d) - { - // The already constructed level. Only missing the - // GGM tree node value along the active path. - auto level0 = levels[d]; - - // The next level that we want to construct. - auto level1 = levels[d + 1]; - - // Zero out the previous sums. - memset(mySums.data(), 0, sizeof(mySums)); - - // We will iterate over each node on this level and - // expand it into it's two children. Note that the - // active node will also be expanded. Later we will just - // overwrite whatever the value was. This is an optimization. - auto width = divCeil(mDomain, 1ull << (mDepth - d)); - - // for internal nodes we the optimized approach. - if (d + 1 < mDepth && 0) - { - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) - { - // The value of the parent. - auto parent = level0[parentIdx]; - - auto& child0 = level1.data()[childIdx]; - auto& child1 = level1.data()[childIdx + 1]; - mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); - - child0[0] = child1[0] ^ parent[0]; - child0[1] = child1[1] ^ parent[1]; - child0[2] = child1[2] ^ parent[2]; - child0[3] = child1[3] ^ parent[3]; - child0[4] = child1[4] ^ parent[4]; - child0[5] = child1[5] ^ parent[5]; - child0[6] = child1[6] ^ parent[6]; - child0[7] = child1[7] ^ parent[7]; - - // Update the running sums for this level. We keep - // a left and right totals for each level. Note that - // we are actually XOR in the incorrect value of the - // children of the active parent (assuming !DEBUG_PRINT_PPRF). - // This is ok since we will later XOR off these incorrect values. - mySums[0][0] = mySums[0][0] ^ child0[0]; - mySums[0][1] = mySums[0][1] ^ child0[1]; - mySums[0][2] = mySums[0][2] ^ child0[2]; - mySums[0][3] = mySums[0][3] ^ child0[3]; - mySums[0][4] = mySums[0][4] ^ child0[4]; - mySums[0][5] = mySums[0][5] ^ child0[5]; - mySums[0][6] = mySums[0][6] ^ child0[6]; - mySums[0][7] = mySums[0][7] ^ child0[7]; - - child1[0] = child1[0] + parent[0]; - child1[1] = child1[1] + parent[1]; - child1[2] = child1[2] + parent[2]; - child1[3] = child1[3] + parent[3]; - child1[4] = child1[4] + parent[4]; - child1[5] = child1[5] + parent[5]; - child1[6] = child1[6] + parent[6]; - child1[7] = child1[7] + parent[7]; - - mySums[1][0] = mySums[1][0] ^ child1[0]; - mySums[1][1] = mySums[1][1] ^ child1[1]; - mySums[1][2] = mySums[1][2] ^ child1[2]; - mySums[1][3] = mySums[1][3] ^ child1[3]; - mySums[1][4] = mySums[1][4] ^ child1[4]; - mySums[1][5] = mySums[1][5] ^ child1[5]; - mySums[1][6] = mySums[1][6] ^ child1[6]; - mySums[1][7] = mySums[1][7] ^ child1[7]; - } - } - else - { - // for leaf nodes both children should be hashed. - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) - { - // The value of the parent. - auto parent = level0[parentIdx]; - - for (u64 keep = 0; keep < 2; ++keep, ++childIdx) - { - // The child that we will write in this iteration. - auto& child = level1[childIdx]; - - // Each parent is expanded into the left and right children - // using a different AES fixed-key. Therefore our OWF is: - // - // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); - // - // where each half defines one of the children. - gAes[keep].hashBlocks<8>(parent.data(), child.data()); - - // Update the running sums for this level. We keep - // a left and right totals for each level. Note that - // we are actually XOR in the incorrect value of the - // children of the active parent (assuming !DEBUG_PRINT_PPRF). - // This is ok since we will later XOR off these incorrect values. - auto& sum = mySums[keep]; - sum[0] = sum[0] ^ child[0]; - sum[1] = sum[1] ^ child[1]; - sum[2] = sum[2] ^ child[2]; - sum[3] = sum[3] ^ child[3]; - sum[4] = sum[4] ^ child[4]; - sum[5] = sum[5] ^ child[5]; - sum[6] = sum[6] ^ child[6]; - sum[7] = sum[7] ^ child[7]; - } - } - } - - // For everything but the last level we have to - // 1) fix our sums so they dont include the incorrect - // values that are the children of the active parent - // 2) Update the non-active child of the active parent. - if (!programActivePath || d != mDepth - 1) - { - for (u64 i = 0; i < 8; ++i) - { - // the index of the leaf node that is active. - auto leafIdx = mPoints[i + treeIdx]; - - // The index of the active child node. - auto activeChildIdx = leafIdx >> (mDepth - 1 - d); - - // The index of the active child node sibling. - auto inactiveChildIdx = activeChildIdx ^ 1; - - // The indicator as to the left or right child is inactive - auto notAi = inactiveChildIdx & 1; - - auto& inactiveChild = level1[inactiveChildIdx][i]; - - // correct the sum value by XORing off the incorrect - auto correctSum = - inactiveChild ^ - theirSums[d][notAi][i]; - - inactiveChild = - correctSum ^ - mySums[notAi][i] ^ - mBaseOTs[i + treeIdx][d]; - - } - } - } - - // last level. - auto level = levels[mDepth]; - if (programActivePath) - { - // Now processes the last level. This one is special - // because we must XOR in the correction value as - // before but we must also fixed the child value for - // the active child. To do this, we will receive 4 - // values. Two for each case (left active or right active). - //timer.setTimePoint("recv.recvLast"); - - auto d = mDepth - 1; - for (u64 j = 0; j < 8; ++j) - { - // The index of the child on the active path. - auto activeChildIdx = mPoints[j + treeIdx]; - - // The index of the other (inactive) child. - auto inactiveChildIdx = activeChildIdx ^ 1; - - // The indicator as to the left or right child is inactive - auto notAi = inactiveChildIdx & 1; - - std::array masks, maskIn; - - // We are going to expand the 128 bit OT string - // into a 256 bit OT string using AES. - maskIn[0] = mBaseOTs[j + treeIdx][d]; - maskIn[1] = mBaseOTs[j + treeIdx][d] ^ AllOneBlock; - mAesFixedKey.hashBlocks<2>(maskIn.data(), masks.data()); - - // now get the chosen message OT strings by XORing - // the expended (random) OT strings with the lastOts values. - auto& ot0 = lastOts[j][2 * notAi + 0]; - auto& ot1 = lastOts[j][2 * notAi + 1]; - ot0 = ot0 ^ masks[0]; - ot1 = ot1 ^ masks[1]; - - auto& inactiveChild = level[inactiveChildIdx][j]; - auto& activeChild = level[activeChildIdx][j]; - - // Fix the sums we computed previously to not include the - // incorrect child values. - auto inactiveSum = mySums[notAi][j] ^ inactiveChild; - auto activeSum = mySums[notAi ^ 1][j] ^ activeChild; - - // Update the inactive and active child to have to correct - // value by XORing their full sum with out partial sum, which - // gives us exactly the value we are missing. - inactiveChild = ot0 ^ inactiveSum; - activeChild = ot1 ^ activeSum; - } - // pprf.setTimePoint("SilentMultiPprfReceiver.last " + std::to_string(treeIdx)); - - //timer.setTimePoint("recv.expandLast"); - } - else - { - for (auto j : rng(std::min(8, mPntCount - treeIdx))) - { - // The index of the child on the active path. - auto activeChildIdx = mPoints[j + treeIdx]; - level[activeChildIdx][j] = ZeroBlock; - } - } - } - - task<> SilentMultiPprfReceiver::expand( - Socket& chl, - MatrixView output, - PprfOutputFormat oFormat, - bool activeChildXorDelta, - u64 _) - { - validateExpandFormat(oFormat, output, mDomain, mPntCount); - - MC_BEGIN(task<>, this, oFormat, output, &chl, activeChildXorDelta, - i = u64{}, - mTreeAllocDepth = u64{}, - tree = span>{}, - levels = std::vector>>{}, - buff = std::vector{}, - sums = span, 2>>{}, - last = span>{} - ); - - setTimePoint("SilentMultiPprfReceiver.start"); - mPoints.resize(roundUpTo(mPntCount, 8)); - getPoints(mPoints, PprfOutputFormat::ByLeafIndex); - - mTreeAllocDepth = mDepth + (oFormat != PprfOutputFormat::Interleaved); - mTreeAlloc.reserve(1, (1ull << mTreeAllocDepth) + 2); - setTimePoint("SilentMultiPprfSender.reserve"); - - levels.resize(mDepth + 1); - allocateExpandTree(mTreeAllocDepth, mTreeAlloc, tree, levels); - - for (i = 0; i < mPntCount; i += 8) - { - // for interleaved format, the last level of the tree - // is simply the output. - if (oFormat == PprfOutputFormat::Interleaved) - { - auto b = (AlignedArray*)output.data(); - auto forest = i / 8; - b += forest * mDomain; - - levels.back() = span>(b, mDomain); - } - - // allocate the send buffer and partition it. - allocateExpandBuffer(mDepth, activeChildXorDelta, buff, sums, last); - - MC_AWAIT(chl.recv(buff)); - - // exapnd the tree - expandOne(i, activeChildXorDelta, levels, sums, last); - - // if we aren't interleaved, we need to copy the - // last layer to the output. - if (oFormat != PprfOutputFormat::Interleaved) - copyOut(levels.back(), output, mPntCount, i, oFormat, mOutputFn); - } - - setTimePoint("SilentMultiPprfReceiver.join"); - - mBaseOTs = {}; - mTreeAlloc.del(tree); - mTreeAlloc.clear(); - - setTimePoint("SilentMultiPprfReceiver.de-alloc"); - - MC_END(); - } - -} - -#endif diff --git a/libOTe/Tools/SilentPprf.h b/libOTe/Tools/SilentPprf.h deleted file mode 100644 index 44600021..00000000 --- a/libOTe/Tools/SilentPprf.h +++ /dev/null @@ -1,318 +0,0 @@ -#pragma once -// © 2020 Peter Rindal. -// © 2022 Visa. -// © 2022 Lawrence Roy. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -#include -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - -#include -#include -#include -#include -#include -#include -#include "libOTe/Tools/Coproto.h" -#include -//#define DEBUG_PRINT_PPRF - -namespace osuCrypto -{ - - // the various formats that the output of the - // Pprf can be generated. - enum class PprfOutputFormat - { - // The i'th row holds the i'th leaf for all trees. - // The j'th tree is in the j'th column. - ByLeafIndex, - - // The i'th row holds the i'th tree. - // The j'th leaf is in the j'th column. - ByTreeIndex, - - // The native output mode. The output will be - // a single row with all leaf values. - // Every 8 trees are mixed together where the - // i'th leaf for each of the 8 tree will be next - // to each other. For example, let tij be the j'th - // leaf of the i'th tree. If we have m leaves, then - // - // t00 t10 ... t70 t01 t11 ... t71 ... t0m t1m ... t7m - // t80 t90 ... t_{15,0} t81 t91 ... t_{15,1} ... t8m t9m ... t_{15,m} - // ... - // - // These are all flattened into a single row. - Interleaved, - - // call the user's callback. The leaves will be in - // Interleaved format. - Callback - }; - - enum class OTType - { - Random, Correlated - }; - - enum class ChoiceBitPacking - { - False, True - }; - - enum class SilentSecType - { - SemiHonest, - Malicious, - //MaliciousFS - }; - - u64 interleavedPoint(u64 point, u64 treeIdx, u64 totalTrees, u64 domain, PprfOutputFormat format); - void interleavedPoints(span points, u64 domain, PprfOutputFormat format); - u64 getActivePath(const span& choiceBits); - - struct TreeAllocator - { - TreeAllocator() = default; - TreeAllocator(const TreeAllocator&) = delete; - TreeAllocator(TreeAllocator&&) = delete; - - using ValueType = AlignedArray; - std::list> mTrees; - std::vector> mFreeTrees; - std::mutex mMutex; - u64 mTreeSize = 0, mNumTrees = 0; - - void reserve(u64 num, u64 size) - { - std::lock_guard lock(mMutex); - mTreeSize = size; - mNumTrees += num; - mTrees.clear(); - mFreeTrees.clear(); - mTrees.emplace_back(num * size); - auto iter = mTrees.back().data(); - for (u64 i = 0; i < num; ++i) - { - mFreeTrees.push_back(span(iter, size)); - assert((u64)mFreeTrees.back().data() % 32 == 0); - iter += size; - } - } - - span get() - { - std::lock_guard lock(mMutex); - if (mFreeTrees.size() == 0) - { - assert(mTreeSize); - mTrees.emplace_back(mTreeSize); - mFreeTrees.push_back(span(mTrees.back().data(), mTreeSize)); - assert((u64)mFreeTrees.back().data() % 32 == 0); - ++mNumTrees; - } - - auto ret = mFreeTrees.back(); - mFreeTrees.pop_back(); - return ret; - } - - void del(span uPtr) - { - std::lock_guard lock(mMutex); - mFreeTrees.push_back(uPtr); - } - - void clear() - { - assert(mNumTrees == mFreeTrees.size()); - mTrees = {}; - mFreeTrees = {}; - mTreeSize = 0; - mNumTrees = 0; - } - }; - - - void allocateExpandBuffer( - u64 depth, - u64 activeChildXorDelta, - std::vector& buff, - span< std::array, 2>>& sums, - span< std::array>& last); - - void allocateExpandTree( - u64 depth, - TreeAllocator& alloc, - span>& tree, - std::vector>>& levels); - - class SilentMultiPprfSender : public TimerAdapter - { - public: - u64 mDomain = 0, mDepth = 0, mPntCount = 0; - std::vector mValue; - bool mPrint = false; - TreeAllocator mTreeAlloc; - Matrix> mBaseOTs; - - std::function>)> mOutputFn; - - SilentMultiPprfSender() = default; - SilentMultiPprfSender(const SilentMultiPprfSender&) = delete; - SilentMultiPprfSender(SilentMultiPprfSender&&) = delete; - - SilentMultiPprfSender(u64 domainSize, u64 pointCount) - { - configure(domainSize, pointCount); - } - - void configure(u64 domainSize, u64 pointCount) - { - mDomain = domainSize; - mDepth = log2ceil(mDomain); - mPntCount = pointCount; - //mPntCount8 = roundUpTo(pointCount, 8); - - mBaseOTs.resize(0, 0); - } - - - // the number of base OTs that should be set. - u64 baseOtCount() const - { - return mDepth * mPntCount; - } - - // returns true if the base OTs are currently set. - bool hasBaseOts() const - { - return mBaseOTs.size(); - } - - - void setBase(span> baseMessages); - - task<> expand(Socket& chls, span value, block seed, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) - { - MatrixView o(output.data(), output.size(), 1); - return expand(chls, value, seed, o, oFormat, activeChildXorDelta, numThreads); - } - - task<> expand( - Socket& chl, - span value, - block seed, - MatrixView output, - PprfOutputFormat oFormat, - bool activeChildXorDelta, - u64 numThreads); - - void setValue(span value); - - void clear(); - - void expandOne( - block aesSeed, - u64 treeIdx, - bool activeChildXorDelta, - span>> levels, - span, 2>> sums, - span> lastOts - ); - }; - - - class SilentMultiPprfReceiver : public TimerAdapter - { - public: - u64 mDomain = 0, mDepth = 0, mPntCount = 0; - - std::vector mPoints; - - Matrix mBaseOTs; - Matrix mBaseChoices; - bool mPrint = false; - TreeAllocator mTreeAlloc; - block mDebugValue; - std::function>)> mOutputFn; - - SilentMultiPprfReceiver() = default; - SilentMultiPprfReceiver(const SilentMultiPprfReceiver&) = delete; - SilentMultiPprfReceiver(SilentMultiPprfReceiver&&) = delete; - - void configure(u64 domainSize, u64 pointCount) - { - mDomain = domainSize; - mDepth = log2ceil(mDomain); - mPntCount = pointCount; - - mBaseOTs.resize(0, 0); - } - - // For output format ByLeafIndex or ByTreeIndex, the choice bits it - // samples are in blocks of mDepth, with mPntCount blocks total (one for - // each punctured point). For ByLeafIndex these blocks encode the punctured - // leaf index in big endian, while for ByTreeIndex they are in - // little endian. - BitVector sampleChoiceBits(u64 modulus, PprfOutputFormat format, PRNG& prng); - - // choices is in the same format as the output from sampleChoiceBits. - void setChoiceBits(PprfOutputFormat format, BitVector choices); - - // the number of base OTs that should be set. - u64 baseOtCount() const - { - return mDepth * mPntCount; - } - - // returns true if the base OTs are currently set. - bool hasBaseOts() const - { - return mBaseOTs.size(); - } - - void setBase(span baseMessages); - - std::vector getPoints(PprfOutputFormat format) - { - std::vector pnts(mPntCount); - getPoints(pnts, format); - return pnts; - } - void getPoints(span points, PprfOutputFormat format); - - task<> expand(Socket& chl, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) - { - MatrixView o(output.data(), output.size(), 1); - return expand(chl, o, oFormat, activeChildXorDelta, numThreads); - } - - // activeChildXorDelta says whether the sender is trying to program the - // active child to be its correct value XOR delta. If it is not, the - // active child will just take a random value. - task<> expand(Socket& chl, MatrixView output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads); - - void clear() - { - mBaseOTs.resize(0, 0); - mBaseChoices.resize(0, 0); - mDomain = 0; - mDepth = 0; - mPntCount = 0; - } - - void expandOne( - u64 treeIdx, - bool programActivePath, - span>> levels, - span, 2>> encSums, - span> lastOts); - }; -} -#endif diff --git a/libOTe/TwoChooseOne/ConfigureCode.cpp b/libOTe/TwoChooseOne/ConfigureCode.cpp index 2519de01..110227c5 100644 --- a/libOTe/TwoChooseOne/ConfigureCode.cpp +++ b/libOTe/TwoChooseOne/ConfigureCode.cpp @@ -7,7 +7,7 @@ #include "libOTe/Tools/QuasiCyclicCode.h" #include "libOTe/Tools/EACode/EACode.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" -#include "libOTe/Tools/ExConvCode/ExConvCode2.h" +#include "libOTe/Tools/ExConvCode/ExConvCode.h" #include namespace osuCrypto { diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp index 321d886e..d9d4fba6 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp @@ -12,7 +12,8 @@ #include #include #include "libOTe/Tools/QuasiCyclicCode.h" -#include "libOTe/Tools/Subfield/Subfield.h" +#include "libOTe/Tools/CoeffCtx.h" + namespace osuCrypto { @@ -67,11 +68,9 @@ namespace osuCrypto throw std::runtime_error("wrong number of silent base OTs"); auto genOts = recvBaseOts.subspan(0, mGen.baseOtCount()); - auto gapOts = recvBaseOts.subspan(genOts.size(), mGapOts.size()); - auto malOts = recvBaseOts.subspan(genOts.size() + mGapOts.size()); + auto malOts = recvBaseOts.subspan(genOts.size()); mGen.setBase(genOts); - std::copy(gapOts.begin(), gapOts.end(), mGapOts.begin()); std::copy(malOts.begin(), malOts.end(), mMalCheckOts.begin()); } @@ -114,14 +113,7 @@ namespace osuCrypto if (isConfigured() == false) throw std::runtime_error("configure(...) must be called first"); - auto choice = mGen.sampleChoiceBits(mN2, getPprfFormat(), prng); - - if (mGapOts.size()) - { - mGapBaseChoice.resize(mGapOts.size()); - mGapBaseChoice.randomize(prng); - choice.append(mGapBaseChoice); - } + auto choice = mGen.sampleChoiceBits(prng); mS.resize(mNumPartitions); mGen.getPoints(mS, getPprfFormat()); @@ -209,7 +201,6 @@ namespace osuCrypto throw std::runtime_error("configure must be called first"); return mGen.baseOtCount() + - mGapOts.size() + (mMalType == SilentSecType::Malicious) * 128; } @@ -222,7 +213,6 @@ namespace osuCrypto { mMalType = malType; mNumThreads = numThreads; - mGapOts.resize(0); switch (mMultType) { @@ -239,28 +229,6 @@ namespace osuCrypto mScaler); break; -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - { - if (scaler != 2) - throw std::runtime_error("only scaler = 2 is supported for slv. " LOCATION); - - u64 gap; - SilverConfigure(numOTs, 128, - mMultType, - mRequestedNumOts, - mNumPartitions, - mSizePer, - mN2, - mN, - gap, - mEncoder); - - mGapOts.resize(gap); - break; - } -#endif case osuCrypto::MultType::ExAcc7: case osuCrypto::MultType::ExAcc11: case osuCrypto::MultType::ExAcc21: @@ -469,61 +437,24 @@ namespace osuCrypto mA.resize(mN2); mC.resize(0); - //// do the compression to get the final OTs. - //if (mMultType == MultType::QuasiCyclic) - //{ - // rT = MatrixView(mA.data(), 128, mN2 / 128); - // // locally expand the seeds. - // MC_AWAIT(mGen.expand(chl, prng, rT, PprfOutputFormat::InterleavedTransposed, mNumThreads)); - // setTimePoint("recver.expand.pprf_transpose"); + MC_AWAIT(mGen.expand(chl, mA, PprfOutputFormat::Interleaved, true, mNumThreads)); + setTimePoint("recver.expand.pprf_transpose"); + gTimer.setTimePoint("recver.expand.pprf_transpose"); - // if (mDebug) - // { - // MC_AWAIT(checkRT(chl, rT)); - // } - - // randMulQuasiCyclic(type); - - //} - //else - { - - main = mNumPartitions * mSizePer; - if (mGapOts.size()) - { - // derandomize the random OTs for the gap - // to have the desired correlation. - gapVals.resize(mGapOts.size()); - MC_AWAIT(chl.recv(gapVals)); - for (i = main, j = 0; i < mN2; ++i, ++j) - { - if (mGapBaseChoice[j]) - mA[i] = AES(mGapOts[j]).ecbEncBlock(ZeroBlock) ^ gapVals[j]; - else - mA[i] = mGapOts[j]; - } - } - - - MC_AWAIT(mGen.expand(chl, mA.subspan(0, main), PprfOutputFormat::Interleaved, true, mNumThreads)); - setTimePoint("recver.expand.pprf_transpose"); - gTimer.setTimePoint("recver.expand.pprf_transpose"); - - - if (mMalType == SilentSecType::Malicious) - MC_AWAIT(ferretMalCheck(chl, prng)); + if (mMalType == SilentSecType::Malicious) + MC_AWAIT(ferretMalCheck(chl, prng)); - if (mDebug) - { - rT = MatrixView(mA.data(), mN2, 1); - MC_AWAIT(checkRT(chl, rT)); - } - compress(type); + if (mDebug) + { + rT = MatrixView(mA.data(), mN2, 1); + MC_AWAIT(checkRT(chl, rT)); } + compress(type); + mA.resize(mRequestedNumOts); if (mC.size()) @@ -542,9 +473,10 @@ namespace osuCrypto sum0 = block{}, sum1 = block{}, mySum = block{}, - deltaShare = block{}, + b = AlignedUnVector(1), + //deltaShare = block{}, i = u64{}, - sender = NoisyVoleSender{}, + sender = NoisyVoleSender{}, theirHash = std::array{}, myHash = std::array{}, ro = RandomOracle(32) @@ -569,9 +501,9 @@ namespace osuCrypto mySum = sum0.gf128Reduce(sum1); - MC_AWAIT(sender.send(mMalCheckX, { &deltaShare,1 }, prng, mMalCheckOts, chl)); + MC_AWAIT(sender.send(mMalCheckX, b, prng, mMalCheckOts, chl, {})); ; - ro.Update(mySum ^ deltaShare); + ro.Update(mySum ^ b[0]); ro.Final(myHash); MC_AWAIT(chl.recv(theirHash)); @@ -726,7 +658,7 @@ namespace osuCrypto case osuCrypto::MultType::ExAcc40: { AlignedUnVector A2(mEAEncoder.mMessageSize); - mEAEncoder.dualEncode(mA.subspan(0, mEAEncoder.mCodeSize), A2); + mEAEncoder.dualEncode(mA.subspan(0, mEAEncoder.mCodeSize), A2, {}); std::swap(mA, A2); break; } @@ -734,7 +666,7 @@ namespace osuCrypto case osuCrypto::MultType::ExConv21x24: if (mTimer) mExConvEncoder.setTimer(getTimer()); - mExConvEncoder.dualEncode(mA.subspan(0, mExConvEncoder.generatorCols())); + mExConvEncoder.dualEncode(mA.begin(), {}); break; default: throw RTE_LOC; @@ -780,9 +712,10 @@ namespace osuCrypto { AlignedUnVector A2(mEAEncoder.mMessageSize); AlignedUnVector C2(mEAEncoder.mMessageSize); - mEAEncoder.dualEncode2( + mEAEncoder.dualEncode2( mA.subspan(0, mEAEncoder.mCodeSize), A2, - mC.subspan(0, mEAEncoder.mCodeSize), C2); + mC.subspan(0, mEAEncoder.mCodeSize), C2, + {}); std::swap(mA, A2); std::swap(mC, C2); @@ -793,10 +726,10 @@ namespace osuCrypto case osuCrypto::MultType::ExConv21x24: if (mTimer) mExConvEncoder.setTimer(getTimer()); - mExConvEncoder.dualEncode2( - mA.subspan(0, mExConvEncoder.mCodeSize), - mC.subspan(0, mExConvEncoder.mCodeSize) - ); + mExConvEncoder.dualEncode2( + mA.begin(), + mC.begin(), + {}); break; default: throw RTE_LOC; @@ -819,8 +752,6 @@ namespace osuCrypto mGen.clear(); - mGapOts = {}; - mS = {}; } diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h index 0d1334ce..4a5e2683 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h @@ -16,13 +16,14 @@ #include #include #include -#include +#include #include #include #include #include #include "libOTe/Tools/EACode/EACode.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" +#include "SilentOtExtUtil.h" namespace osuCrypto { @@ -69,13 +70,9 @@ namespace osuCrypto macoro::optional mOtExtRecver; #endif - // The OTs recv msgs which will be used to flood the - // last gap bits of the noisy vector for the slv code. - std::vector mGapOts; - // The OTs recv msgs which will be used to create the // secret share of xa * delta as described in ferret. - std::vector mMalCheckOts; + AlignedUnVector mMalCheckOts; // The OTs choice bits which will be used to flood the // last gap bits of the noisy vector for the slv code. @@ -94,7 +91,7 @@ namespace osuCrypto block mMalCheckX = ZeroBlock; // The ggm tree thats used to generate the sparse vectors. - SilentMultiPprfReceiver mGen; + RegularPprfReceiver mGen; // The type of compress we will use to generate the // dense vectors from the sparse vectors. diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp index 7fb8e32d..148bc1c0 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp @@ -135,7 +135,7 @@ namespace osuCrypto if (isConfigured() == false) throw std::runtime_error("configure must be called first"); - auto n = mGen.baseOtCount() + mGapOts.size(); + auto n = mGen.baseOtCount(); if (mMalType == SilentSecType::Malicious) n += 128; @@ -151,12 +151,10 @@ namespace osuCrypto throw RTE_LOC; auto genOt = sendBaseOts.subspan(0, mGen.baseOtCount()); - auto gapOt = sendBaseOts.subspan(genOt.size(), mGapOts.size()); - auto malOt = sendBaseOts.subspan(genOt.size() + gapOt.size()); + auto malOt = sendBaseOts.subspan(genOt.size()); mMalCheckOts.resize((mMalType == SilentSecType::Malicious) * 128); mGen.setBase(genOt); - std::copy(gapOt.begin(), gapOt.end(), mGapOts.begin()); std::copy(malOt.begin(), malOt.end(), mMalCheckOts.begin()); } @@ -166,7 +164,6 @@ namespace osuCrypto mMalType = malType; mNumThreads = numThreads; - mGapOts.resize(0); switch (mMultType) { @@ -183,28 +180,6 @@ namespace osuCrypto mScaler); break; -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - { - if (scaler != 2) - throw std::runtime_error("only scaler = 2 is supported for slv. " LOCATION); - - u64 gap; - SilverConfigure(numOTs, 128, - mMultType, - mRequestNumOts, - mNumPartitions, - mSizePer, - mN2, - mN, - gap, - mEncoder); - - mGapOts.resize(gap); - break; - } -#endif case osuCrypto::MultType::ExAcc7: case osuCrypto::MultType::ExAcc11: case osuCrypto::MultType::ExAcc21: @@ -249,8 +224,6 @@ namespace osuCrypto mDelta = block(0,0); - mGapOts = {}; - mGen.clear(); } @@ -392,9 +365,8 @@ namespace osuCrypto Socket& chl) { MC_BEGIN(task<>,this, d, n, &prng, &chl, - rT = MatrixView{}, - gapVals = std::vector {}, - i = u64{}, j = u64{}, main = u64{} + i = u64{}, j = u64{}, + delta = AlignedUnVector{} ); gTimer.setTimePoint("sender.ot.enter"); @@ -420,54 +392,23 @@ namespace osuCrypto // allocate b mB.resize(mN2); + + delta.resize(1); + delta[0] = mDelta; - //if (mMultType == MultType::QuasiCyclic) - //{ - // rT = MatrixView(mB.data(), 128, mN2 / 128); - - // MC_AWAIT(mGen.expand(chl, mDelta, prng, rT, PprfOutputFormat::InterleavedTransposed, mNumThreads)); - // setTimePoint("sender.expand.pprf_transpose"); - // gTimer.setTimePoint("sender.expand.pprf_transpose"); - - // if (mDebug) - // MC_AWAIT(checkRT(chl)); - - // randMulQuasiCyclic(); - //} - //else - { - - main = mNumPartitions * mSizePer; - if (mGapOts.size()) - { - // derandomize the random OTs for the gap - // to have the desired correlation. - gapVals.resize(mGapOts.size()); - for (i = main, j = 0; i < mN2; ++i, ++j) - { - auto v = mGapOts[j][0] ^ mDelta; - gapVals[j] = AES(mGapOts[j][1]).ecbEncBlock(ZeroBlock) ^ v; - mB[i] = mGapOts[j][0]; - //std::cout << "jj " << j << " " <(1), + c = AlignedUnVector(1), + //deltaShare = block{}, i = u64{}, - recver = NoisyVoleReceiver{}, + recver = NoisyVoleReceiver{}, myHash = std::array{}, ro = RandomOracle(32) ); @@ -506,11 +449,11 @@ namespace osuCrypto mySum = sum0.gf128Reduce(sum1); + c[0] = mDelta; + //a[0] = deltaShare; + MC_AWAIT(recver.receive(c, a, prng, mMalCheckOts, chl, {})); - - MC_AWAIT(recver.receive({ &mDelta,1 }, { &deltaShare,1 }, prng, mMalCheckOts, chl)); - - ro.Update(mySum ^ deltaShare); + ro.Update(mySum ^ a[0]); ro.Final(myHash); MC_AWAIT(chl.send(std::move(myHash))); @@ -534,17 +477,6 @@ namespace osuCrypto #endif } break; -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - - if (mTimer) - mEncoder.setTimer(getTimer()); - mEncoder.dualEncode(mB); - setTimePoint("sender.expand.ldpc.dualEncode"); - - break; -#endif case osuCrypto::MultType::ExAcc7: case osuCrypto::MultType::ExAcc11: case osuCrypto::MultType::ExAcc21: @@ -553,7 +485,7 @@ namespace osuCrypto if (mTimer) mEAEncoder.setTimer(getTimer()); AlignedUnVector B2(mEAEncoder.mMessageSize); - mEAEncoder.dualEncode(mB.subspan(0, mEAEncoder.mCodeSize), B2); + mEAEncoder.dualEncode(mB.subspan(0, mEAEncoder.mCodeSize), B2, {}); std::swap(mB, B2); break; } @@ -561,14 +493,14 @@ namespace osuCrypto case osuCrypto::MultType::ExConv21x24: if (mTimer) mExConvEncoder.setTimer(getTimer()); - mExConvEncoder.dualEncode(mB.subspan(0, mExConvEncoder.mCodeSize)); + mExConvEncoder.dualEncode(mB.begin(), {}); break; default: throw RTE_LOC; break; } - + } // // diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h index 9ad0c7b4..1ab10edb 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h @@ -15,13 +15,14 @@ #include #include #include -#include +#include #include #include #include #include #include "libOTe/Tools/EACode/EACode.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" +#include "SilentOtExtUtil.h" namespace osuCrypto { @@ -109,7 +110,7 @@ namespace osuCrypto #endif // The ggm tree thats used to generate the sparse vectors. - SilentMultiPprfSender mGen; + RegularPprfSender mGen; // The type of compress we will use to generate the // dense vectors from the sparse vectors. @@ -121,10 +122,6 @@ namespace osuCrypto ExConvCode mExConvEncoder; EACode mEAEncoder; - // The OTs send msgs which will be used to flood the - // last gap bits of the noisy vector for the slv code. - std::vector> mGapOts; - // The OTs send msgs which will be used to create the // secret share of xa * delta as described in ferret. std::vector> mMalCheckOts; diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtUtil.h b/libOTe/TwoChooseOne/Silent/SilentOtExtUtil.h new file mode 100644 index 00000000..9476bbc6 --- /dev/null +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtUtil.h @@ -0,0 +1,23 @@ +#pragma once + + +namespace osuCrypto +{ + + enum class OTType + { + Random, Correlated + }; + + enum class ChoiceBitPacking + { + False, True + }; + + enum class SilentSecType + { + SemiHonest, + Malicious, + //MaliciousFS + }; +} \ No newline at end of file diff --git a/libOTe/Vole/Noisy/NoisyVoleReceiver.cpp b/libOTe/Vole/Noisy/NoisyVoleReceiver.cpp deleted file mode 100644 index 9a9a20cc..00000000 --- a/libOTe/Vole/Noisy/NoisyVoleReceiver.cpp +++ /dev/null @@ -1,88 +0,0 @@ -#include "NoisyVoleReceiver.h" - -#if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) -#include "cryptoTools/Common/BitIterator.h" -#include "cryptoTools/Common/Matrix.h" - - -namespace osuCrypto -{ - - task<> NoisyVoleReceiver::receive(span y, span z, PRNG& prng, - OtSender& ot, Socket& chl) - { - MC_BEGIN(task<>,this, y,z, &prng, &ot, &chl, - otMsg = AlignedUnVector>{ 128 } - ); - - setTimePoint("NoisyVoleReceiver.ot.begin"); - - MC_AWAIT(ot.send(otMsg, prng, chl)); - - setTimePoint("NoisyVoleReceiver.ot.end"); - - MC_AWAIT(receive(y, z, prng, otMsg, chl)); - - MC_END(); - } - - task<> NoisyVoleReceiver::receive( - span y, span z, - PRNG& _, span> otMsg, - Socket& chl) - { - MC_BEGIN(task<>,this, y,z, otMsg, &chl, - msg = Matrix{}, - prng = std::move(PRNG{}) - //buffer = std::vector{} - ); - - setTimePoint("NoisyVoleReceiver.begin"); - if (otMsg.size() != 128) - throw RTE_LOC; - if (y.size() != z.size()) - throw RTE_LOC; - if (z.size() == 0) - throw RTE_LOC; - - memset(z.data(), 0, sizeof(block) * z.size()); - msg.resize(otMsg.size(), y.size()); - - //buffer.resize(z.size()); - - for (u64 ii = 0; ii < (u64)otMsg.size(); ++ii) - { - //PRNG p0(otMsg[ii][0]); - //PRNG p1(otMsg[ii][1]); - prng.SetSeed(otMsg[ii][0], z.size()); - auto& buffer = prng.mBuffer; - - for (u64 j = 0; j < (u64)y.size(); ++j) - { - z[j] = z[j] ^ buffer[j]; - - block twoPowI = ZeroBlock; - *BitIterator((u8*)&twoPowI, ii) = 1; - - auto yy = y[j].gf128Mul(twoPowI); - - msg(ii, j) = yy ^ buffer[j]; - } - - prng.SetSeed(otMsg[ii][1], z.size()); - - for (u64 j = 0; j < (u64)y.size(); ++j) - { - // enc one message under the OT msg. - msg(ii, j) = msg(ii, j) ^ buffer[j]; - } - } - - MC_AWAIT(chl.send(std::move(msg))); - //chl.asyncSend(std::move(msg)); - setTimePoint("NoisyVoleReceiver.done"); - - MC_END(); - } -} -#endif \ No newline at end of file diff --git a/libOTe/Vole/Noisy/NoisyVoleReceiver.h b/libOTe/Vole/Noisy/NoisyVoleReceiver.h index da0182b6..49d45007 100644 --- a/libOTe/Vole/Noisy/NoisyVoleReceiver.h +++ b/libOTe/Vole/Noisy/NoisyVoleReceiver.h @@ -1,12 +1,22 @@ #pragma once // © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. #include #if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) @@ -14,20 +24,127 @@ #include "cryptoTools/Common/Defines.h" #include "cryptoTools/Common/Timer.h" #include "cryptoTools/Crypto/PRNG.h" -#include "libOTe/TwoChooseOne/OTExtInterface.h" #include "libOTe/Tools/Coproto.h" +#include "libOTe/TwoChooseOne/OTExtInterface.h" +#include "libOTe/Tools/CoeffCtx.h" + +namespace osuCrypto { + + template < + typename F, + typename G = F, + typename CoeffCtx = DefaultCoeffCtx + > + class NoisyVoleReceiver : public TimerAdapter + { + public: + using VecF = typename CoeffCtx::template Vec; + + + // for chosen c, compute a such htat + // + // a = b + c * delta + // + template + task<> receive(VecG& c, VecF& a, PRNG& prng, + OtSender& ot, Socket& chl, CoeffCtx ctx) + { + MC_BEGIN(task<>, this, &c, &a, &prng, &ot, &chl, ctx, + otMsg = AlignedUnVector>{}); + + setTimePoint("NoisyVoleReceiver.ot.begin"); + otMsg.resize(ctx.bitSize()); + MC_AWAIT(ot.send(otMsg, prng, chl)); + + setTimePoint("NoisyVoleReceiver.ot.end"); + + MC_AWAIT(receive(c, a, prng, otMsg, chl,ctx)); + + MC_END(); + } + + // for chosen c, compute a such htat + // + // a = b + c * delta + // + template + task<> receive(VecG& c, VecF& a, PRNG& _, + span> otMsg, + Socket& chl, CoeffCtx ctx) + { + MC_BEGIN(task<>, this, &c, &a, otMsg, &chl, ctx, + buff = std::vector{}, + msg = VecF{}, + temp = VecF{}, + prng = std::move(PRNG{}) + ); + + if (c.size() != a.size()) + throw RTE_LOC; + if (a.size() == 0) + throw RTE_LOC; + + setTimePoint("NoisyVoleReceiver.begin"); + + ctx.zero(a.begin(), a.end()); + ctx.resize(msg, otMsg.size() * a.size()); + ctx.resize(temp, 2); + + for (size_t i = 0, k = 0; i < otMsg.size(); ++i) + { + prng.SetSeed(otMsg[i][0], a.size()); + + // t1 = 2^i + ctx.powerOfTwo(temp[1], i); + //std::cout << "2^i " << ctx.str(temp[1]) << "\n"; + + for (size_t j = 0; j < c.size(); ++j, ++k) + { + // msg[i,j] = otMsg[i,j,0] + ctx.fromBlock(msg[k], prng.get()); + //ctx.zero(msg.begin() + k, msg.begin() + k + 1); + //std::cout << "m" << i << ",0 = " << ctx.str(msg[k]) << std::endl; + + // a[j] += otMsg[i,j,0] + ctx.plus(a[j], a[j], msg[k]); + //std::cout << "z = " << ctx.str(a[j]) << std::endl; + + // temp = 2^i * c[j] + ctx.mul(temp[0], temp[1], c[j]); + //std::cout << "2^i y = " << ctx.str(temp[0]) << std::endl; + + // msg[i,j] = otMsg[i,j,0] + 2^i * c[j] + ctx.minus(msg[k], msg[k], temp[0]); + //std::cout << "m" << i << ",0 + 2^i y = " << ctx.str(msg[k]) << std::endl; + } + + k -= c.size(); + prng.SetSeed(otMsg[i][1], a.size()); + + for (size_t j = 0; j < c.size(); ++j, ++k) + { + // temp = otMsg[i,j,1] + ctx.fromBlock(temp[0], prng.get()); + //ctx.zero(temp.begin(), temp.begin() + 1); + //std::cout << "m" << i << ",1 = " << ctx.str(temp[0]) << std::endl; + + // enc one message under the OT msg. + // msg[i,j] = (otMsg[i,j,0] + 2^i * c[j]) - otMsg[i,j,1] + ctx.minus(msg[k], msg[k], temp[0]); + //std::cout << "m" << i << ",0 + 2^i y - m" << i << ",1 = " << ctx.str(msg[k]) << std::endl << std::endl; + } + } -namespace osuCrypto -{ - class NoisyVoleReceiver : public TimerAdapter - { - public: + buff.resize(msg.size() * ctx.byteSize()); + ctx.serialize(msg.begin(), msg.end(), buff.begin()); - task<> receive(span y, span z, PRNG& prng, OtSender& ot, Socket& chl); - task<> receive(span y, span z, PRNG& prng, span> otMsg, Socket& chl); + MC_AWAIT(chl.send(std::move(buff))); + setTimePoint("NoisyVoleReceiver.done"); - }; + MC_END(); + } + }; -} +} // namespace osuCrypto #endif diff --git a/libOTe/Vole/Noisy/NoisyVoleSender.cpp b/libOTe/Vole/Noisy/NoisyVoleSender.cpp deleted file mode 100644 index 83cc55fc..00000000 --- a/libOTe/Vole/Noisy/NoisyVoleSender.cpp +++ /dev/null @@ -1,82 +0,0 @@ -#include "NoisyVoleSender.h" - -#if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) -#include "cryptoTools/Common/BitVector.h" -#include "cryptoTools/Common/Matrix.h" - -namespace osuCrypto -{ - task<> NoisyVoleSender::send( - block x, span z, - PRNG& prng, - OtReceiver& ot, - Socket& chl) - { - MC_BEGIN(task<>,this, x, z, &prng, &ot, &chl, - bv = BitVector((u8*)&x, 128), - otMsg = AlignedUnVector{ 128 }); - - setTimePoint("NoisyVoleSender.ot.begin"); - - //BitVector bv((u8*)&x, 128); - //std::array otMsg; - MC_AWAIT(ot.receive(bv, otMsg, prng, chl)); - setTimePoint("NoisyVoleSender.ot.end"); - - MC_AWAIT(send(x, z, prng, otMsg, chl)); - - MC_END(); - } - - task<> NoisyVoleSender::send( - block x, - span z, - PRNG& prng, - span otMsg, - Socket& chl) - { - MC_BEGIN(task<>,this, x, z, &prng, otMsg, &chl, - msg = Matrix{}, - buffer = std::vector{}, - xIter = BitIterator{}); - - if (otMsg.size() != 128) - throw RTE_LOC; - setTimePoint("NoisyVoleSender.main"); - - msg.resize(otMsg.size(), z.size()); - memset(z.data(), 0, sizeof(block) * z.size()); - - - MC_AWAIT(chl.recv(msg)); - - setTimePoint("NoisyVoleSender.recvMsg"); - buffer.resize(z.size()); - - xIter = BitIterator((u8*)&x); - - for (u64 i = 0; i < otMsg.size(); ++i, ++xIter) - { - PRNG pi(otMsg[i]); - pi.get(buffer); - - if (*xIter) - { - for (u64 j = 0; j < z.size(); ++j) - { - buffer[j] = msg(i, j) ^ buffer[j]; - } - } - - for (u64 j = 0; j < (u64)z.size(); ++j) - { - z[j] = z[j] ^ buffer[j]; - } - } - setTimePoint("NoisyVoleSender.done"); - - MC_END(); - } - -} -#endif diff --git a/libOTe/Vole/Noisy/NoisyVoleSender.h b/libOTe/Vole/Noisy/NoisyVoleSender.h index 862d3002..6868b453 100644 --- a/libOTe/Vole/Noisy/NoisyVoleSender.h +++ b/libOTe/Vole/Noisy/NoisyVoleSender.h @@ -1,32 +1,142 @@ #pragma once // © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. -// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). +// This code implements features described in [Silver: Silent VOLE and Oblivious +// Transfer from Hardness of Decoding Structured LDPC Codes, +// https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative +// Commons Attribution 4.0 International Public License +// (https://creativecommons.org/licenses/by/4.0/legalcode). #include #if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) +#include "cryptoTools/Common/BitVector.h" #include "cryptoTools/Common/Defines.h" #include "cryptoTools/Common/Timer.h" #include "cryptoTools/Crypto/PRNG.h" -#include "libOTe/TwoChooseOne/OTExtInterface.h" #include "libOTe/Tools/Coproto.h" +#include "libOTe/TwoChooseOne/OTExtInterface.h" +#include "libOTe/Tools/CoeffCtx.h" - -namespace osuCrypto -{ - class NoisyVoleSender : public TimerAdapter +namespace osuCrypto { + template < + typename F, + typename G = F, + typename CoeffCtx = DefaultCoeffCtx + > + class NoisyVoleSender : public TimerAdapter { + public: + using VecF = typename CoeffCtx::template Vec; + + // for chosen delta, compute b such htat + // + // a = b + c * delta + // + template + task<> send(F delta, FVec& b, PRNG& prng, + OtReceiver& ot, Socket& chl, CoeffCtx ctx) { + MC_BEGIN(task<>, this, delta, &b, &prng, &ot, &chl, ctx, + bv = ctx.binaryDecomposition(delta), + otMsg = AlignedUnVector{ }); + otMsg.resize(bv.size()); + + setTimePoint("NoisyVoleSender.ot.begin"); + + MC_AWAIT(ot.receive(bv, otMsg, prng, chl)); + setTimePoint("NoisyVoleSender.ot.end"); + + MC_AWAIT(send(delta, b, prng, otMsg, chl, ctx)); + + MC_END(); + } + + // for chosen delta, compute b such htat + // + // a = b + c * delta + // + template + task<> send(F delta, FVec& b, PRNG& _, + span otMsg, Socket& chl, CoeffCtx ctx) { + MC_BEGIN(task<>, this, delta, &b, otMsg, &chl, ctx, + prng = std::move(PRNG{}), + buffer = std::vector{}, + msg = VecF{}, + temp = VecF{}, + xb = BitVector{}); + + xb = ctx.binaryDecomposition(delta); + + if (otMsg.size() != xb.size()) + throw RTE_LOC; + setTimePoint("NoisyVoleSender.main"); + + // b = 0; + ctx.zero(b.begin(), b.end()); + + // receive the the excrypted one shares. + buffer.resize(xb.size() * b.size() * ctx.byteSize()); + MC_AWAIT(chl.recv(buffer)); + ctx.resize(msg, xb.size() * b.size()); + ctx.deserialize(buffer.begin(), buffer.end(), msg.begin()); + + setTimePoint("NoisyVoleSender.recvMsg"); + + temp.resize(1); + for (size_t i = 0, k = 0; i < xb.size(); ++i) + { + // expand the zero shares or one share masks + prng.SetSeed(otMsg[i], b.size()); + + // otMsg[i,j, bc[i]] + //auto otMsgi = prng.getBufferSpan(b.size()); + + for (u64 j = 0; j < (u64)b.size(); ++j, ++k) + { + // temp = otMsg[i,j, xb[i]] + ctx.fromBlock(temp[0], prng.get()); + //ctx.zero(temp.begin(), temp.begin() + 1); + //std::cout << "m" << i << ","< send(block x, span z, PRNG& prng, OtReceiver& ot, Socket& chl); - task<> send(block x, span z, PRNG& prng, span otMsg, Socket& chl); }; -} +} // namespace osuCrypto #endif \ No newline at end of file diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.cpp b/libOTe/Vole/Silent/SilentVoleReceiver.cpp deleted file mode 100644 index ecd25b4e..00000000 --- a/libOTe/Vole/Silent/SilentVoleReceiver.cpp +++ /dev/null @@ -1,821 +0,0 @@ -#include "libOTe/Vole/Silent/SilentVoleReceiver.h" - -#ifdef ENABLE_SILENT_VOLE -#include "libOTe/Vole/Silent/SilentVoleSender.h" -#include "libOTe/Vole/Noisy/NoisyVoleReceiver.h" -#include - -#include -#include -#include -#include -#include - -namespace osuCrypto -{ - - - //u64 getPartitions(u64 scaler, u64 p, u64 secParam); - - - - // sets the Iknp base OTs that are then used to extend - void SilentVoleReceiver::setBaseOts( - span> baseSendOts) - { -#ifdef ENABLE_SOFTSPOKEN_OT - mOtExtRecver.setBaseOts(baseSendOts); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - - // return the number of base OTs soft spoken needs - u64 SilentVoleReceiver::baseOtCount() const { -#ifdef ENABLE_SOFTSPOKEN_OT - return mOtExtRecver.baseOtCount(); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - - // returns true if the soft spoken base OTs are currently set. - bool SilentVoleReceiver::hasBaseOts() const { -#ifdef ENABLE_SOFTSPOKEN_OT - return mOtExtRecver.hasBaseOts(); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - }; - - - BitVector SilentVoleReceiver::sampleBaseChoiceBits(PRNG& prng) { - - if (isConfigured() == false) - throw std::runtime_error("configure(...) must be called first"); - - auto choice = mGen.sampleChoiceBits(mN2, getPprfFormat(), prng); - - mGapBaseChoice.resize(mGapOts.size()); - mGapBaseChoice.randomize(prng); - choice.append(mGapBaseChoice); - - return choice; - } - - std::vector SilentVoleReceiver::sampleBaseVoleVals(PRNG& prng) - { - if (isConfigured() == false) - throw RTE_LOC; - if (mGapBaseChoice.size() != mGapOts.size()) - throw std::runtime_error("sampleBaseChoiceBits must be called before sampleBaseVoleVals. " LOCATION); - - // sample the values of the noisy coordinate of c - // and perform a noicy vole to get x+y = mD * c - auto w = mNumPartitions + mGapOts.size(); - //std::vector y(w); - mNoiseValues.resize(w); - prng.get(mNoiseValues); - - mS.resize(mNumPartitions); - mGen.getPoints(mS, getPprfFormat()); - - auto j = mNumPartitions * mSizePer; - - for (u64 i = 0; i < (u64)mGapBaseChoice.size(); ++i) - { - if (mGapBaseChoice[i]) - { - mS.push_back(j + i); - } - } - - if (mMalType == SilentSecType::Malicious) - { - - mMalCheckSeed = prng.get(); - mMalCheckX = ZeroBlock; - auto yIter = mNoiseValues.begin(); - - for (u64 i = 0; i < mNumPartitions; ++i) - { - auto s = mS[i]; - auto xs = mMalCheckSeed.gf128Pow(s + 1); - mMalCheckX = mMalCheckX ^ xs.gf128Mul(*yIter); - ++yIter; - } - - auto sIter = mS.begin() + mNumPartitions; - for (u64 i = 0; i < mGapBaseChoice.size(); ++i) - { - if (mGapBaseChoice[i]) - { - auto s = *sIter; - auto xs = mMalCheckSeed.gf128Pow(s + 1); - mMalCheckX = mMalCheckX ^ xs.gf128Mul(*yIter); - ++sIter; - } - ++yIter; - } - - - std::vector y(mNoiseValues.begin(), mNoiseValues.end()); - y.push_back(mMalCheckX); - return y; - } - - return std::vector(mNoiseValues.begin(), mNoiseValues.end()); - } - - task<> SilentVoleReceiver::genBaseOts( - PRNG& prng, - Socket& chl) - { - setTimePoint("SilentVoleReceiver.gen.start"); -#ifdef ENABLE_SOFTSPOKEN_OT - return mOtExtRecver.genBaseOts(prng, chl); - //mIknpSender.genBaseOts(mIknpRecver, prng, chl); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - - void SilentVoleReceiver::configure( - u64 numOTs, - SilentBaseType type, - u64 secParam) - { - mState = State::Configured; - u64 gap = 0; - mBaseType = type; - - switch (mMultType) - { - case osuCrypto::MultType::QuasiCyclic: - { - u64 p, s; - QuasiCyclicConfigure(numOTs, secParam, - 2, - mMultType, - mRequestedNumOTs, - mNumPartitions, - mSizePer, - mN2, - mN, - p, - s - ); -#ifdef ENABLE_BITPOLYMUL - mQuasiCyclicEncoder.init(p, s); -#else - throw std::runtime_error("ENABLE_BITPOLYMUL not defined."); -#endif - break; - } -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - - SilverConfigure(numOTs, secParam, - mMultType, - mRequestedNumOTs, - mNumPartitions, - mSizePer, - mN2, - mN, - gap, - mEncoder); - - break; -#endif - case osuCrypto::MultType::ExAcc7: - case osuCrypto::MultType::ExAcc11: - case osuCrypto::MultType::ExAcc21: - case osuCrypto::MultType::ExAcc40: - EAConfigure(numOTs, secParam, - mMultType, - mRequestedNumOTs, - mNumPartitions, - mSizePer, - mN2, - mN, - mEAEncoder); - - break; - case osuCrypto::MultType::ExConv7x24: - case osuCrypto::MultType::ExConv21x24: - - ExConvConfigure(numOTs, 128, mMultType, mRequestedNumOTs, mNumPartitions, mSizePer, mN2, mN, mExConvEncoder); - break; - default: - throw RTE_LOC; - break; - } - - mGapOts.resize(gap); - mGen.configure(mSizePer, mNumPartitions); - } - - - task<> SilentVoleReceiver::genSilentBaseOts( - PRNG& prng, - Socket& chl) - { -#if defined ENABLE_MRR_TWIST && defined ENABLE_SSE -using BaseOT = McRosRoyTwist; -#elif defined ENABLE_MR -using BaseOT = MasnyRindal; -#elif defined ENABLE_MRR -using BaseOT = McRosRoy; -#else -using BaseOT = DefaultBaseOT; -#endif - - MC_BEGIN(task<>, this, &prng, &chl, - choice = BitVector{}, - bb = BitVector{}, - msg = AlignedUnVector{}, - baseVole = std::vector{}, - baseOt = BaseOT{}, - chl2 = Socket{}, - prng2 = std::move(PRNG{}), - noiseVals = std::vector{}, - noiseDeltaShares = std::vector{}, - nv = NoisyVoleReceiver{} - - ); - - setTimePoint("SilentVoleReceiver.genSilent.begin"); - if (isConfigured() == false) - throw std::runtime_error("configure must be called first"); - - choice = sampleBaseChoiceBits(prng); - msg.resize(choice.size()); - - // sample the noise vector noiseVals such that we will compute - // - // C = (000 noiseVals[0] 0000 ... 000 noiseVals[p] 000) - // - // and then we want secret shares of C * delta. As a first step - // we will compute secret shares of - // - // delta * noiseVals - // - // and store our share in voleDeltaShares. This party will then - // compute their share of delta * C as what comes out of the PPRF - // plus voleDeltaShares[i] added to the appreciate spot. Similarly, the - // other party will program the PPRF to output their share of delta * noiseVals. - // - noiseVals = sampleBaseVoleVals(prng); - noiseDeltaShares.resize(noiseVals.size()); - if (mTimer) - nv.setTimer(*mTimer); - - if (mBaseType == SilentBaseType::BaseExtend) - { -#ifdef ENABLE_SOFTSPOKEN_OT - - if (mOtExtSender.hasBaseOts() == false) - { - msg.resize(msg.size() + mOtExtSender.baseOtCount()); - bb.resize(mOtExtSender.baseOtCount()); - bb.randomize(prng); - choice.append(bb); - - MC_AWAIT(mOtExtRecver.receive(choice, msg, prng, chl)); - - mOtExtSender.setBaseOts( - span(msg).subspan( - msg.size() - mOtExtSender.baseOtCount(), - mOtExtSender.baseOtCount()), - bb); - - msg.resize(msg.size() - mOtExtSender.baseOtCount()); - MC_AWAIT(nv.receive(noiseVals, noiseDeltaShares, prng, mOtExtSender, chl)); - } - else - { - chl2 = chl.fork(); - prng2.SetSeed(prng.get()); - - - MC_AWAIT( - macoro::when_all_ready( - nv.receive(noiseVals, noiseDeltaShares, prng2, mOtExtSender, chl2), - mOtExtRecver.receive(choice, msg, prng, chl) - )); - } -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - else - { - chl2 = chl.fork(); - prng2.SetSeed(prng.get()); - - MC_AWAIT( - macoro::when_all_ready( - nv.receive(noiseVals, noiseDeltaShares, prng2, baseOt, chl2), - baseOt.receive(choice, msg, prng, chl) - )); - } - - - - - setSilentBaseOts(msg, noiseDeltaShares); - - setTimePoint("SilentVoleReceiver.genSilent.done"); - - MC_END(); - }; - - void SilentVoleReceiver::setSilentBaseOts( - span recvBaseOts, - span noiseDeltaShare) - { - if (isConfigured() == false) - throw std::runtime_error("configure(...) must be called first."); - - if (static_cast(recvBaseOts.size()) != silentBaseOtCount()) - throw std::runtime_error("wrong number of silent base OTs"); - - auto genOts = recvBaseOts.subspan(0, mGen.baseOtCount()); - auto gapOts = recvBaseOts.subspan(mGen.baseOtCount(), mGapOts.size()); - - mGen.setBase(genOts); - std::copy(gapOts.begin(), gapOts.end(), mGapOts.begin()); - - if (mMalType == SilentSecType::Malicious) - { - mDeltaShare = noiseDeltaShare.back(); - noiseDeltaShare = noiseDeltaShare.subspan(0, noiseDeltaShare.size() - 1); - } - - mNoiseDeltaShare = AlignedVector(noiseDeltaShare.begin(), noiseDeltaShare.end()); - - mState = State::HasBase; - } - - - - //sigma = 0 Receiver - // - // u_i is the choice bit - // v_i = w_i + u_i * x - // - // ------------------------ - - // u' = 0000001000000000001000000000100000...00000, u_i = 1 iff i \in S - // - // v' = r + (x . u') = DPF(k0) - // = r + (000000x00000000000x000000000x00000...00000) - // - // u = u' * H bit-vector * H. Mapping n'->n bits - // v = v' * H block-vector * H. Mapping n'->n block - // - //sigma = 1 Sender - // - // x is the delta - // w_i is the zero message - // - // m_i0 = w_i - // m_i1 = w_i + x - // - // ------------------------ - // x - // r = DPF(k1) - // - // w = r * H - task<> SilentVoleReceiver::silentReceive( - span c, - span b, - PRNG& prng, - Socket& chl) - { - MC_BEGIN(task<>, this, c, b, &prng, &chl); - if (c.size() != b.size()) - throw RTE_LOC; - - MC_AWAIT(silentReceiveInplace(c.size(), prng, chl)); - - std::memcpy(c.data(), mC.data(), c.size() * sizeof(block)); - std::memcpy(b.data(), mA.data(), b.size() * sizeof(block)); - clear(); - MC_END(); - } - - task<> SilentVoleReceiver::silentReceiveInplace( - u64 n, - PRNG& prng, - Socket& chl) - { - MC_BEGIN(task<>, this, n, &prng, &chl, - gapVals = std::vector{}, - myHash = std::array{}, - theirHash = std::array{} - - ); - gTimer.setTimePoint("SilentVoleReceiver.ot.enter"); - - if (isConfigured() == false) - { - // first generate 128 normal base OTs - configure(n, SilentBaseType::BaseExtend); - } - - if (mRequestedNumOTs != n) - throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); - - if (hasSilentBaseOts() == false) - { - MC_AWAIT(genSilentBaseOts(prng, chl)); - } - - // allocate mA - mA.resize(0); - mA.resize(mN2); - - setTimePoint("SilentVoleReceiver.alloc"); - - // allocate the space for mC - mC.resize(0); - mC.resize(mN2, AllocType::Zeroed); - setTimePoint("SilentVoleReceiver.alloc.zero"); - - // derandomize the random OTs for the gap - // to have the desired correlation. - gapVals.resize(mGapOts.size()); - - if(gapVals.size()) - MC_AWAIT(chl.recv(gapVals)); - - for (auto g : rng(mGapOts.size())) - { - auto aa = mA.subspan(mNumPartitions * mSizePer); - auto cc = mC.subspan(mNumPartitions * mSizePer); - - auto noise = mNoiseValues.subspan(mNumPartitions); - auto noiseShares = mNoiseDeltaShare.subspan(mNumPartitions); - - if (mGapBaseChoice[g]) - { - cc[g] = noise[g]; - aa[g] = AES(mGapOts[g]).ecbEncBlock(ZeroBlock) ^ - gapVals[g] ^ - noiseShares[g]; - } - else - aa[g] = mGapOts[g]; - } - - setTimePoint("SilentVoleReceiver.recvGap"); - - - - if (mTimer) - mGen.setTimer(*mTimer); - // expand the seeds into mA - MC_AWAIT(mGen.expand(chl, mA.subspan(0, mNumPartitions * mSizePer), PprfOutputFormat::Interleaved, true, mNumThreads)); - - setTimePoint("SilentVoleReceiver.expand.pprf_transpose"); - - // populate the noisy coordinates of mC and - // update mA to be a secret share of mC * delta - for (u64 i = 0; i < mNumPartitions; ++i) - { - auto pnt = mS[i]; - mC[pnt] = mNoiseValues[i]; - mA[pnt] = mA[pnt] ^ mNoiseDeltaShare[i]; - } - - - if (mDebug) - { - MC_AWAIT(checkRT(chl)); - setTimePoint("SilentVoleReceiver.expand.checkRT"); - } - - - if (mMalType == SilentSecType::Malicious) - { - MC_AWAIT(chl.send(std::move(mMalCheckSeed))); - - myHash = ferretMalCheck(mDeltaShare, mNoiseValues); - - MC_AWAIT(chl.recv(theirHash)); - - if (theirHash != myHash) - throw RTE_LOC; - } - - switch (mMultType) - { - case osuCrypto::MultType::QuasiCyclic: - -#ifdef ENABLE_BITPOLYMUL - if (mTimer) - mQuasiCyclicEncoder.setTimer(getTimer()); - - // compress both mA and mC in place. - mQuasiCyclicEncoder.dualEncode(mA.subspan(0, mQuasiCyclicEncoder.size())); - mQuasiCyclicEncoder.dualEncode(mC.subspan(0, mQuasiCyclicEncoder.size())); -#else - throw std::runtime_error("ENABLE_BITPOLYMUL not defined."); -#endif - - setTimePoint("SilentVoleReceiver.expand.mQuasiCyclicEncoder.a"); - break; -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - if (mTimer) - mEncoder.setTimer(getTimer()); - - // compress both mA and mC in place. - mEncoder.dualEncode2(mA, mC); - setTimePoint("SilentVoleReceiver.expand.cirTransEncode.a"); - break; -#endif - case osuCrypto::MultType::ExAcc7: - case osuCrypto::MultType::ExAcc11: - case osuCrypto::MultType::ExAcc21: - case osuCrypto::MultType::ExAcc40: - { - if (mTimer) - mEAEncoder.setTimer(getTimer()); - - AlignedUnVector - A2(mEAEncoder.mMessageSize), - C2(mEAEncoder.mMessageSize); - - // compress both mA and mC in place. - mEAEncoder.dualEncode2( - mA.subspan(0, mEAEncoder.mCodeSize), A2, - mC.subspan(0, mEAEncoder.mCodeSize), C2); - - std::swap(mA, A2); - std::swap(mC, C2); - - setTimePoint("SilentVoleReceiver.expand.cirTransEncode.a"); - break; - } - case osuCrypto::MultType::ExConv7x24: - case osuCrypto::MultType::ExConv21x24: - if (mTimer) - mExConvEncoder.setTimer(getTimer()); - mExConvEncoder.dualEncode2( - mA.subspan(0, mExConvEncoder.mCodeSize), - mC.subspan(0, mExConvEncoder.mCodeSize) - ); - break; - default: - throw RTE_LOC; - break; - } - - // resize the buffers down to only contain the real elements. - mA.resize(mRequestedNumOTs); - mC.resize(mRequestedNumOTs); - - mNoiseValues = {}; - mNoiseDeltaShare = {}; - - // make the protocol as done and that - // mA,mC are ready to be consumed. - mState = State::Default; - - MC_END(); - } - - - std::array SilentVoleReceiver::ferretMalCheck( - block deltaShare, - span yy) - { - - block xx = mMalCheckSeed; - block sum0 = ZeroBlock; - block sum1 = ZeroBlock; - - - for (u64 i = 0; i < (u64)mA.size(); ++i) - { - block low, high; - xx.gf128Mul(mA[i], low, high); - sum0 = sum0 ^ low; - sum1 = sum1 ^ high; - //mySum = mySum ^ xx.gf128Mul(mA[i]); - - // xx = mMalCheckSeed^{i+1} - xx = xx.gf128Mul(mMalCheckSeed); - } - block mySum = sum0.gf128Reduce(sum1); - - std::array myHash; - RandomOracle ro(32); - ro.Update(mySum ^ deltaShare); - ro.Final(myHash); - return myHash; - } - - - u64 SilentVoleReceiver::silentBaseOtCount() const - { - if (isConfigured() == false) - throw std::runtime_error("configure must be called first"); - - return mGen.baseOtCount() + mGapOts.size(); - - } - - task<> SilentVoleReceiver::checkRT(Socket& chl) const - { - MC_BEGIN(task<>, this, &chl, - B = AlignedVector(mA.size()), - sparseNoiseDelta = std::vector(mA.size()), - noiseDeltaShare2 = std::vector(), - delta = block{} - ); - //std::vector mB(mA.size()); - MC_AWAIT(chl.recv(delta)); - MC_AWAIT(chl.recv(B)); - MC_AWAIT(chl.recvResize(noiseDeltaShare2)); - - //check that at locations mS[0],...,mS[..] - // that we hold a sharing mA, mB of - // - // delta * mC = delta * (00000 noiseDeltaShare2[0] 0000 .... 0000 noiseDeltaShare2[m] 0000) - // - // where noiseDeltaShare2[i] is at position mS[i] of mC - // - // That is, I hold mA, mC s.t. - // - // delta * mC = mA + mB - // - - if (noiseDeltaShare2.size() != mNoiseDeltaShare.size()) - throw RTE_LOC; - - for (auto i : rng(mNoiseDeltaShare.size())) - { - if ((mNoiseDeltaShare[i] ^ noiseDeltaShare2[i]) != mNoiseValues[i].gf128Mul(delta)) - throw RTE_LOC; - } - - { - - for (auto i : rng(mNumPartitions* mSizePer)) - { - auto iter = std::find(mS.begin(), mS.end(), i); - if (iter != mS.end()) - { - auto d = iter - mS.begin(); - - if (mC[i] != mNoiseValues[d]) - throw RTE_LOC; - - if (mNoiseValues[d].gf128Mul(delta) != (mA[i] ^ B[i])) - { - std::cout << "bad vole base correlation, mA[i] + mB[i] != mC[i] * delta" << std::endl; - std::cout << "i " << i << std::endl; - std::cout << "mA[i] " << mA[i] << std::endl; - std::cout << "mB[i] " << B[i] << std::endl; - std::cout << "mC[i] " << mC[i] << std::endl; - std::cout << "delta " << delta << std::endl; - std::cout << "mA[i] + mB[i] " << (mA[i] ^ B[i]) << std::endl; - std::cout << "mC[i] * delta " << (mC[i].gf128Mul(delta)) << std::endl; - - throw RTE_LOC; - } - } - else - { - if (mA[i] != B[i]) - { - std::cout << mA[i] << " " << B[i] << std::endl; - throw RTE_LOC; - } - - if (mC[i] != oc::ZeroBlock) - throw RTE_LOC; - } - } - - u64 d = mNumPartitions; - for (auto j : rng(mGapBaseChoice.size())) - { - auto idx = j + mNumPartitions * mSizePer; - auto aa = mA.subspan(mNumPartitions * mSizePer); - auto bb = B.subspan(mNumPartitions * mSizePer); - auto cc = mC.subspan(mNumPartitions * mSizePer); - auto noise = mNoiseValues.subspan(mNumPartitions); - //auto noiseShare = mNoiseValues.subspan(mNumPartitions); - if (mGapBaseChoice[j]) - { - if (mS[d++] != idx) - throw RTE_LOC; - - if (cc[j] != noise[j]) - { - std::cout << "sparse noise vector mC is not the expected value" << std::endl; - std::cout << "i j " << idx << " " << j << std::endl; - std::cout << "mC[i] " << cc[j] << std::endl; - std::cout << "noise[j] " << noise[j] << std::endl; - throw RTE_LOC; - } - - if (noise[j].gf128Mul(delta) != (aa[j] ^ bb[j])) - { - - std::cout << "bad vole base GAP correlation, mA[i] + mB[i] != mC[i] * delta" << std::endl; - std::cout << "i " << idx << std::endl; - std::cout << "mA[i] " << aa[j] << std::endl; - std::cout << "mB[i] " << bb[j] << std::endl; - std::cout << "mC[i] " << cc[j] << std::endl; - std::cout << "delta " << delta << std::endl; - std::cout << "mA[i] + mB[i] " << (aa[j] ^ bb[j]) << std::endl; - std::cout << "mC[i] * delta " << (cc[j].gf128Mul(delta)) << std::endl; - std::cout << "noise * delta " << (noise[j].gf128Mul(delta)) << std::endl; - throw RTE_LOC; - } - - } - else - { - if (aa[j] != bb[j]) - throw RTE_LOC; - - if (cc[j] != oc::ZeroBlock) - throw RTE_LOC; - } - } - - if (d != mS.size()) - throw RTE_LOC; - } - - - //{ - - // auto cDelta = B; - // for (u64 i = 0; i < cDelta.size(); ++i) - // cDelta[i] = cDelta[i] ^ mA[i]; - - // std::vector exp(mN2); - // for (u64 i = 0; i < mNumPartitions; ++i) - // { - // auto j = mS[i]; - // exp[j] = noiseDeltaShare2[i]; - // } - - // auto iter = mS.begin() + mNumPartitions; - // for (u64 i = 0, j = mNumPartitions * mSizePer; i < mGapOts.size(); ++i, ++j) - // { - // if (mGapBaseChoice[i]) - // { - // if (*iter != j) - // throw RTE_LOC; - // ++iter; - - // exp[j] = noiseDeltaShare2[mNumPartitions + i]; - // } - // } - - // if (iter != mS.end()) - // throw RTE_LOC; - - // bool failed = false; - // for (u64 i = 0; i < mN2; ++i) - // { - // if (neq(cDelta[i], exp[i])) - // { - // std::cout << i << " / " << mN2 << - // " cd = " << cDelta[i] << - // " exp= " << exp[i] << std::endl; - // failed = true; - // } - // } - - // if (failed) - // throw RTE_LOC; - - // std::cout << "debug check ok" << std::endl; - //} - - MC_END(); - } - - - void SilentVoleReceiver::clear() - { - mS = {}; - mA = {}; - mC = {}; - mGen.clear(); - mGapBaseChoice = {}; - } - - -} -#endif \ No newline at end of file diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.h b/libOTe/Vole/Silent/SilentVoleReceiver.h index 91e4c110..5f8ace0a 100644 --- a/libOTe/Vole/Silent/SilentVoleReceiver.h +++ b/libOTe/Vole/Silent/SilentVoleReceiver.h @@ -14,25 +14,35 @@ #include #include #include -#include +#include "libOTe/Tools/Pprf/RegularPprf.h" #include #include #include -#include #include -#include #include - +#include +#include +#include +#include +#include "libOTe/Tools/QuasiCyclicCode.h" namespace osuCrypto { - // For more documentation see SilentOtExtSender. - class SilentVoleReceiver : public TimerAdapter + template< + typename F, + typename G = F, + typename Ctx = DefaultCoeffCtx + > + class SilentSubfieldVoleReceiver : public TimerAdapter { public: static constexpr u64 mScaler = 2; + static constexpr bool MaliciousSupported = + std::is_same_v&& + std::is_same_v; + enum class State { Default, @@ -40,22 +50,29 @@ namespace osuCrypto HasBase }; + using VecF = typename Ctx::template Vec; + using VecG = typename Ctx::template Vec; + // The current state of the protocol State mState = State::Default; - // The number of OTs the user requested. - u64 mRequestedNumOTs = 0; + // the context used to perform F, G operations + Ctx mCtx; + + // The number of correlations the user requested. + u64 mRequestSize = 0; + + // the LPN security parameter + u64 mSecParam = 0; + + // The length of the noisy vectors (2 * mN for the most codes). + u64 mNoiseVecSize = 0; - // The number of OTs actually produced (at least the number requested). - u64 mN = 0; - - // The length of the noisy vectors (2 * mN for the silver codes). - u64 mN2 = 0; - // We perform regular LPN, so this is the // size of the each chunk. u64 mSizePer = 0; + // the number of noisy positions u64 mNumPartitions = 0; // The noisy coordinates. @@ -68,42 +85,30 @@ namespace osuCrypto // the sparse vector. MultType mMultType = DefaultMultType; -#ifdef ENABLE_INSECURE_SILVER - // The silver encoder. - SilverEncoder mEncoder; -#endif - - ExConvCode mExConvEncoder; - EACode mEAEncoder; - -#ifdef ENABLE_BITPOLYMUL - QuasiCyclicCode mQuasiCyclicEncoder; -#endif - // The multi-point punctured PRF for generating // the sparse vectors. - SilentMultiPprfReceiver mGen; + RegularPprfReceiver mGen; // The internal buffers for holding the expanded vectors. - // mA + mB = mC * delta - AlignedUnVector mA; - - // mA + mB = mC * delta - AlignedUnVector mC; + // mA = mB + mC * delta + VecF mA; - std::vector mGapOts; + // mA = mB + mC * delta + VecG mC; u64 mNumThreads = 1; bool mDebug = false; - BitVector mIknpSendBaseChoice, mGapBaseChoice; + BitVector mIknpSendBaseChoice; - SilentSecType mMalType = SilentSecType::SemiHonest; + SilentSecType mMalType = SilentSecType::SemiHonest; - block mMalCheckSeed, mMalCheckX, mDeltaShare; - - AlignedVector mNoiseDeltaShare, mNoiseValues; + block mMalCheckSeed, mMalCheckX, mMalBaseA; + + // we + VecF mBaseA; + VecG mBaseC; #ifdef ENABLE_SOFTSPOKEN_OT @@ -111,73 +116,306 @@ namespace osuCrypto SoftSpokenMalOtReceiver mOtExtRecver; #endif - // sets the Iknp base OTs that are then used to extend - void setBaseOts( - span> baseSendOts); - - // return the number of base OTs IKNP needs - u64 baseOtCount() const; + // // sets the Iknp base OTs that are then used to extend + // void setBaseOts( + // span> baseSendOts); + // + // // return the number of base OTs IKNP needs + // u64 baseOtCount() const; u64 baseVoleCount() const { - return mNumPartitions + mGapOts.size() + 1 * (mMalType == SilentSecType::Malicious); + return mNumPartitions + 1 * (mMalType == SilentSecType::Malicious); } - // returns true if the IKNP base OTs are currently set. - bool hasBaseOts() const; - - // returns true if the silent base OTs are set. + // // returns true if the IKNP base OTs are currently set. + // bool hasBaseOts() const; + // + // returns true if the silent base OTs are set. bool hasSilentBaseOts() const { return mGen.hasBaseOts(); }; - - // Generate the IKNP base OTs - task<> genBaseOts(PRNG& prng, Socket& chl) ; + // + // // Generate the IKNP base OTs + // task<> genBaseOts(PRNG& prng, Socket& chl) ; // Generate the silent base OTs. If the Iknp // base OTs are set then we do an IKNP extend, // otherwise we perform a base OT protocol to // generate the needed OTs. - task<> genSilentBaseOts(PRNG& prng, Socket& chl); - + task<> genSilentBaseOts(PRNG& prng, Socket& chl) + { + using BaseOT = DefaultBaseOT; + + + MC_BEGIN(task<>, this, &prng, &chl, + choice = BitVector{}, + bb = BitVector{}, + msg = AlignedUnVector{}, + baseVole = std::vector{}, + baseOt = BaseOT{}, + chl2 = Socket{}, + prng2 = std::move(PRNG{}), + noiseVals = VecG{}, + baseAs = VecF{}, + nv = NoisyVoleReceiver{} + + ); + + setTimePoint("SilentVoleReceiver.genSilent.begin"); + if (isConfigured() == false) + throw std::runtime_error("configure must be called first"); + + choice = sampleBaseChoiceBits(prng); + msg.resize(choice.size()); + + // sample the noise vector noiseVals such that we will compute + // + // C = (000 noiseVals[0] 0000 ... 000 noiseVals[p] 000) + // + // and then we want secret shares of C * delta. As a first step + // we will compute secret shares of + // + // delta * noiseVals + // + // and store our share in voleDeltaShares. This party will then + // compute their share of delta * C as what comes out of the PPRF + // plus voleDeltaShares[i] added to the appreciate spot. Similarly, the + // other party will program the PPRF to output their share of delta * noiseVals. + // + noiseVals = sampleBaseVoleVals(prng); + mCtx.resize(baseAs, noiseVals.size()); + + if (mTimer) + nv.setTimer(*mTimer); + + if (mBaseType == SilentBaseType::BaseExtend) + { +#ifdef ENABLE_SOFTSPOKEN_OT + + if (mOtExtSender.hasBaseOts() == false) + { + msg.resize(msg.size() + mOtExtSender.baseOtCount()); + bb.resize(mOtExtSender.baseOtCount()); + bb.randomize(prng); + choice.append(bb); + + MC_AWAIT(mOtExtRecver.receive(choice, msg, prng, chl)); + + mOtExtSender.setBaseOts( + span(msg).subspan( + msg.size() - mOtExtSender.baseOtCount(), + mOtExtSender.baseOtCount()), + bb); + + msg.resize(msg.size() - mOtExtSender.baseOtCount()); + MC_AWAIT(nv.receive(noiseVals, baseAs, prng, mOtExtSender, chl, mCtx)); + } + else + { + chl2 = chl.fork(); + prng2.SetSeed(prng.get()); + + + MC_AWAIT( + macoro::when_all_ready( + nv.receive(noiseVals, baseAs, prng2, mOtExtSender, chl2, mCtx), + mOtExtRecver.receive(choice, msg, prng, chl) + )); + } +#else + throw std::runtime_error("soft spoken must be enabled"); +#endif + } + else + { + chl2 = chl.fork(); + prng2.SetSeed(prng.get()); + + MC_AWAIT( + macoro::when_all_ready( + baseOt.receive(choice, msg, prng, chl), + nv.receive(noiseVals, baseAs, prng2, baseOt, chl2, mCtx)) + ); + } + + setSilentBaseOts(msg, baseAs); + setTimePoint("SilentVoleReceiver.genSilent.done"); + MC_END(); + }; + // configure the silent OT extension. This sets // the parameters and figures out how many base OT // will be needed. These can then be ganerated for // a different OT extension or using a base OT protocol. void configure( - u64 n, - SilentBaseType baseType = SilentBaseType::BaseExtend, - u64 secParam = 128); + u64 requestSize, + SilentBaseType type = SilentBaseType::BaseExtend, + u64 secParam = 128, + Ctx ctx = {}) + { + mCtx = std::move(ctx); + mSecParam = secParam; + mRequestSize = requestSize; + mState = State::Configured; + mBaseType = type; + double minDist = 0; + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + case osuCrypto::MultType::ExConv21x24: + { + u64 _1, _2; + ExConvConfigure(mScaler, mMultType, _1, _2, minDist); + break; + } + case MultType::QuasiCyclic: + QuasiCyclicConfigure(mScaler, minDist); + break; + default: + throw RTE_LOC; + break; + } + + mNumPartitions = getRegNoiseWeight(minDist, secParam); + mSizePer = std::max(4, roundUpTo(divCeil(mRequestSize * mScaler, mNumPartitions), 2)); + mNoiseVecSize = mSizePer * mNumPartitions; + + //std::cout << "n " << mRequestSize << " -> " << mNoiseVecSize << " = " << mSizePer << " * " << mNumPartitions << std::endl; + + mGen.configure(mSizePer, mNumPartitions); + } // return true if this instance has been configured. bool isConfigured() const { return mState != State::Default; } // Returns how many base OTs the silent OT extension // protocol will needs. - u64 silentBaseOtCount() const; + u64 silentBaseOtCount() const + { + if (isConfigured() == false) + throw std::runtime_error("configure must be called first"); + + return mGen.baseOtCount(); + + } // The silent base OTs must have specially set base OTs. // This returns the choice bits that should be used. // Call this is you want to use a specific base OT protocol // and then pass the OT messages back using setSilentBaseOts(...). - BitVector sampleBaseChoiceBits(PRNG& prng); + BitVector sampleBaseChoiceBits(PRNG& prng) { - std::vector sampleBaseVoleVals(PRNG& prng); + if (isConfigured() == false) + throw std::runtime_error("configure(...) must be called first"); + + auto choice = mGen.sampleChoiceBits(prng); + + return choice; + } + + VecG sampleBaseVoleVals(PRNG& prng) + { + if (isConfigured() == false) + throw RTE_LOC; + + // sample the values of the noisy coordinate of c + // and perform a noicy vole to get a = b + mD * c + + + VecG zero, one; + mCtx.resize(zero, 1); + mCtx.resize(one, 1); + mCtx.zero(zero.begin(), zero.end()); + mCtx.one(one.begin(), one.end()); + mCtx.resize(mBaseC, mNumPartitions + (mMalType == SilentSecType::Malicious)); + for (size_t i = 0; i < mNumPartitions; i++) + { + mCtx.fromBlock(mBaseC[i], prng.get()); + + // must not be zero. + while(mCtx.eq(zero[0], mBaseC[i])) + mCtx.fromBlock(mBaseC[i], prng.get()); + + // if we are not a field, then the noise should be odd. + if (mCtx.isField() == false) + { + u8 odd = mCtx.binaryDecomposition(mBaseC[i])[0]; + if (odd) + mCtx.plus(mBaseC[i], mBaseC[i], one[0]); + } + } + + + mS.resize(mNumPartitions); + mGen.getPoints(mS, PprfOutputFormat::Interleaved); + + if (mMalType == SilentSecType::Malicious) + { + if constexpr (MaliciousSupported) + { + mMalCheckSeed = prng.get(); + + auto yIter = mBaseC.begin(); + mCtx.zero(mBaseC.end() - 1, mBaseC.end()); + for (u64 i = 0; i < mNumPartitions; ++i) + { + auto s = mS[i]; + auto xs = mMalCheckSeed.gf128Pow(s + 1); + mBaseC[mNumPartitions] = mBaseC[mNumPartitions] ^ xs.gf128Mul(*yIter); + ++yIter; + } + } + else + { + throw std::runtime_error("malicious is currently only supported for GF128 block. " LOCATION); + } + } + + return mBaseC; + } // Set the externally generated base OTs. This choice // bits must be the one return by sampleBaseChoiceBits(...). - void setSilentBaseOts(span recvBaseOts, - span voleBase); + void setSilentBaseOts( + span recvBaseOts, + VecF& baseA) + { + if (isConfigured() == false) + throw std::runtime_error("configure(...) must be called first."); + + if (static_cast(recvBaseOts.size()) != silentBaseOtCount()) + throw std::runtime_error("wrong number of silent base OTs"); + + mGen.setBase(recvBaseOts); + + mCtx.resize(mBaseA, baseA.size()); + mCtx.copy(baseA.begin(), baseA.end(), mBaseA.begin()); + mState = State::HasBase; + } // Perform the actual OT extension. If silent // base OTs have been generated or set, then // this function is non-interactive. Otherwise // the silent base OTs will automatically be performed. task<> silentReceive( - span c, - span a, - PRNG & prng, - Socket & chl); + VecG& c, + VecF& a, + PRNG& prng, + Socket& chl) + { + MC_BEGIN(task<>, this, &c, &a, &prng, &chl); + if (c.size() != a.size()) + throw RTE_LOC; + + MC_AWAIT(silentReceiveInplace(c.size(), prng, chl)); + + mCtx.copy(mC.begin(), mC.begin() + c.size(), c.begin()); + mCtx.copy(mA.begin(), mA.begin() + a.size(), a.begin()); + + clear(); + MC_END(); + } // Perform the actual OT extension. If silent // base OTs have been generated or set, then @@ -186,23 +424,330 @@ namespace osuCrypto task<> silentReceiveInplace( u64 n, PRNG& prng, - Socket& chl); + Socket& chl) + { + MC_BEGIN(task<>, this, n, &prng, &chl, + myHash = std::array{}, + theirHash = std::array{} + ); + gTimer.setTimePoint("SilentVoleReceiver.ot.enter"); + + if (isConfigured() == false) + { + // first generate 128 normal base OTs + configure(n, SilentBaseType::BaseExtend); + } + + if (mRequestSize != n) + throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); + + if (hasSilentBaseOts() == false) + { + MC_AWAIT(genSilentBaseOts(prng, chl)); + } + + // allocate mA + mCtx.resize(mA, 0); + mCtx.resize(mA, mNoiseVecSize); + + setTimePoint("SilentVoleReceiver.alloc"); + + // allocate the space for mC + mCtx.resize(mC, 0); + mCtx.resize(mC, mNoiseVecSize); + mCtx.zero(mC.begin(), mC.end()); + setTimePoint("SilentVoleReceiver.alloc.zero"); + + if (mTimer) + mGen.setTimer(*mTimer); + + // As part of the setup, we have generated + // + // mBaseA + mBaseB = mBaseC * mDelta + // + // We have mBaseA, mBaseC, + // they have mBaseB, mDelta + // This was done with a small (noisy) vole. + // + // We use the Pprf to expand as + // + // mA' = mB + mS(mBaseB) + // = mB + mS(mBaseC * mDelta - mBaseA) + // = mB + mS(mBaseC * mDelta) - mS(mBaseA) + // + // Therefore if we add mS(mBaseA) to mA' we will get + // + // mA = mB + mS(mBaseC * mDelta) + // + MC_AWAIT(mGen.expand(chl, mA, PprfOutputFormat::Interleaved, true, mNumThreads)); + + setTimePoint("SilentVoleReceiver.expand.pprf_transpose"); + + // populate the noisy coordinates of mC and + // update mA to be a secret share of mC * delta + for (u64 i = 0; i < mNumPartitions; ++i) + { + auto pnt = mS[i]; + mCtx.copy(mC[pnt], mBaseC[i]); + mCtx.plus(mA[pnt], mA[pnt], mBaseA[i]); + } + + if (mDebug) + { + MC_AWAIT(checkRT(chl)); + setTimePoint("SilentVoleReceiver.expand.checkRT"); + } + + + if (mMalType == SilentSecType::Malicious) + { + MC_AWAIT(chl.send(std::move(mMalCheckSeed))); + + if constexpr (MaliciousSupported) + myHash = ferretMalCheck(); + else + throw std::runtime_error("malicious is currently only supported for GF128 block. " LOCATION); + + MC_AWAIT(chl.recv(theirHash)); + + if (theirHash != myHash) + throw RTE_LOC; + } + + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + case osuCrypto::MultType::ExConv21x24: + { + u64 expanderWeight, accumulatorWeight; + double _; + ExConvConfigure(mScaler, mMultType, expanderWeight, accumulatorWeight, _); + ExConvCode encoder; + if (mScaler * mRequestSize > mNoiseVecSize) + throw RTE_LOC; + encoder.config(mRequestSize, mScaler * mRequestSize, expanderWeight, accumulatorWeight); + + if (mTimer) + encoder.setTimer(getTimer()); + + encoder.dualEncode2( + mA.begin(), + mC.begin(), + {} + ); + break; + } + case osuCrypto::MultType::QuasiCyclic: + { +#ifdef ENABLE_BITPOLYMUL + if constexpr ( + std::is_same_v && + std::is_same_v && + std::is_same_v) + { + QuasiCyclicCode encoder; + encoder.init2(mRequestSize, mNoiseVecSize); + encoder.dualEncode(mA); + encoder.dualEncode(mC); + } + else + throw std::runtime_error("QuasiCyclic is only supported for GF128, i.e. block. " LOCATION); +#else + throw std::runtime_error("QuasiCyclic requires ENABLE_BITPOLYMUL = true. " LOCATION); +#endif + break; + } + default: + throw std::runtime_error("Code is not supported. " LOCATION); + break; + } + // resize the buffers down to only contain the real elements. + mCtx.resize(mA, mRequestSize); + mCtx.resize(mC, mRequestSize); + mBaseC = {}; + mBaseA = {}; + + // make the protocol as done and that + // mA,mC are ready to be consumed. + mState = State::Default; + + MC_END(); + } - // internal. - task<> checkRT(Socket& chls) const; - std::array ferretMalCheck( - block deltaShare, - span y); - PprfOutputFormat getPprfFormat() + // internal. + task<> checkRT(Socket& chl) + { + MC_BEGIN(task<>, this, &chl, + B = VecF{}, + sparseNoiseDelta = VecF{}, + baseB = VecF{}, + delta = VecF{}, + tempF = VecF{}, + tempG = VecG{}, + buffer = std::vector{} + ); + + // recv delta + buffer.resize(mCtx.byteSize()); + mCtx.resize(delta, 1); + MC_AWAIT(chl.recv(buffer)); + mCtx.deserialize(buffer.begin(), buffer.end(), delta.begin()); + + // recv B + buffer.resize(mCtx.byteSize() * mA.size()); + mCtx.resize(B, mA.size()); + MC_AWAIT(chl.recv(buffer)); + mCtx.deserialize(buffer.begin(), buffer.end(), B.begin()); + + // recv the noisy values. + buffer.resize(mCtx.byteSize() * mBaseA.size()); + mCtx.resize(baseB, mBaseA.size()); + MC_AWAIT(chl.recvResize(buffer)); + mCtx.deserialize(buffer.begin(), buffer.end(), baseB.begin()); + + // it shoudl hold that + // + // mBaseA = baseB + mBaseC * mDelta + // + // and + // + // mA = mB + mC * mDelta + // + { + bool verbose = false; + bool failed = false; + std::vector index(mS.size()); + std::iota(index.begin(), index.end(), 0); + std::sort(index.begin(), index.end(), + [&](std::size_t i, std::size_t j) { return mS[i] < mS[j]; }); + + mCtx.resize(tempF, 2); + mCtx.resize(tempG, 1); + mCtx.zero(tempG.begin(), tempG.end()); + + + // check the correlation that + // + // mBaseA + mBaseB = mBaseC * mDelta + for (auto i : rng(mBaseA.size())) + { + // temp[0] = baseB[i] + mBaseA[i] + mCtx.plus(tempF[0], baseB[i], mBaseA[i]); + + // temp[1] = mBaseC[i] * delta[0] + mCtx.mul(tempF[1], delta[0], mBaseC[i]); + + if (!mCtx.eq(tempF[0], tempF[1])) + throw RTE_LOC; + + if (i < mNumPartitions) + { + //auto idx = index[i]; + auto point = mS[i]; + if (!mCtx.eq(mBaseC[i], mC[point])) + throw RTE_LOC; + + if (i && mS[index[i - 1]] >= mS[index[i]]) + throw RTE_LOC; + } + } + + + auto iIter = index.begin(); + auto leafIdx = mS[*iIter]; + F act = tempF[0]; + G zero = tempG[0]; + mCtx.zero(tempG.begin(), tempG.end()); + + for (u64 j = 0; j < mA.size(); ++j) + { + mCtx.mul(act, delta[0], mC[j]); + mCtx.plus(act, act, B[j]); + + bool active = false; + if (j == leafIdx) + { + active = true; + } + else if (!mCtx.eq(zero, mC[j])) + throw RTE_LOC; + + if (mA[j] != act) + { + failed = true; + if (verbose) + std::cout << Color::Red; + } + + if (verbose) + { + std::cout << j << " act " << mCtx.str(act) + << " a " << mCtx.str(mA[j]) << " b " << mCtx.str(B[j]); + + if (active) + std::cout << " < " << mCtx.str(delta[0]); + + std::cout << std::endl << Color::Default; + } + + if (j == leafIdx) + { + ++iIter; + if (iIter != index.end()) + { + leafIdx = mS[*iIter]; + } + } + } + + if (failed) + throw RTE_LOC; + } + + MC_END(); + } + + std::array ferretMalCheck() { - return PprfOutputFormat::Interleaved; + + block xx = mMalCheckSeed; + block sum0 = ZeroBlock; + block sum1 = ZeroBlock; + + + for (u64 i = 0; i < (u64)mA.size(); ++i) + { + block low, high; + xx.gf128Mul(mA[i], low, high); + sum0 = sum0 ^ low; + sum1 = sum1 ^ high; + //mySum = mySum ^ xx.gf128Mul(mA[i]); + + // xx = mMalCheckSeed^{i+1} + xx = xx.gf128Mul(mMalCheckSeed); + } + + // = < + block mySum = sum0.gf128Reduce(sum1); + + std::array myHash; + RandomOracle ro(32); + ro.Update(mySum ^ mBaseA.back()); + ro.Final(myHash); + return myHash; } - void clear(); + void clear() + { + mS = {}; + mA = {}; + mC = {}; + mGen.clear(); + } }; } #endif \ No newline at end of file diff --git a/libOTe/Vole/Silent/SilentVoleSender.cpp b/libOTe/Vole/Silent/SilentVoleSender.cpp deleted file mode 100644 index fa79e962..00000000 --- a/libOTe/Vole/Silent/SilentVoleSender.cpp +++ /dev/null @@ -1,435 +0,0 @@ -#include "libOTe/Vole/Silent/SilentVoleSender.h" -#ifdef ENABLE_SILENT_VOLE - -#include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" -#include "libOTe/TwoChooseOne/Iknp/IknpOtExtSender.h" -#include "libOTe/Vole/Noisy/NoisyVoleSender.h" - -#include "libOTe/Base/BaseOT.h" -#include "libOTe/Tools/Tools.h" -#include "cryptoTools/Common/Log.h" -#include "cryptoTools/Crypto/RandomOracle.h" - - -namespace osuCrypto -{ - u64 SilentVoleSender::baseOtCount() const - { -#ifdef ENABLE_SOFTSPOKEN_OT - return mOtExtSender.baseOtCount(); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - - bool SilentVoleSender::hasBaseOts() const - { -#ifdef ENABLE_SOFTSPOKEN_OT - return mOtExtSender.hasBaseOts(); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - - // sets the soft spoken base OTs that are then used to extend - void SilentVoleSender::setBaseOts( - span baseRecvOts, - const BitVector& choices) - { -#ifdef ENABLE_SOFTSPOKEN_OT - mOtExtSender.setBaseOts(baseRecvOts, choices); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - - - task<> SilentVoleSender::genSilentBaseOts(PRNG& prng, Socket& chl, cp::optional delta) - { - -#if defined ENABLE_MRR_TWIST && defined ENABLE_SSE - using BaseOT = McRosRoyTwist; -#elif defined ENABLE_MR - using BaseOT = MasnyRindal; -#elif defined ENABLE_MRR - using BaseOT = McRosRoy; -#else - using BaseOT = DefaultBaseOT; -#endif - - MC_BEGIN(task<>,this, delta, &prng, &chl, - msg = AlignedUnVector>(silentBaseOtCount()), - baseOt = BaseOT{}, - prng2 = std::move(PRNG{}), - xx = BitVector{}, - chl2 = Socket{}, - nv = NoisyVoleSender{}, - noiseDeltaShares = std::vector{} - ); - setTimePoint("SilentVoleSender.genSilent.begin"); - - if (isConfigured() == false) - throw std::runtime_error("configure must be called first"); - - - delta = delta.value_or(prng.get()); - xx.append(delta->data(), 128); - - // compute the correlation for the noisy coordinates. - noiseDeltaShares.resize(baseVoleCount()); - - - if (mBaseType == SilentBaseType::BaseExtend) - { -#ifdef ENABLE_SOFTSPOKEN_OT - - if (mOtExtRecver.hasBaseOts() == false) - { - msg.resize(msg.size() + mOtExtRecver.baseOtCount()); - MC_AWAIT(mOtExtSender.send(msg, prng, chl)); - - mOtExtRecver.setBaseOts( - span>(msg).subspan( - msg.size() - mOtExtRecver.baseOtCount(), - mOtExtRecver.baseOtCount())); - msg.resize(msg.size() - mOtExtRecver.baseOtCount()); - - MC_AWAIT(nv.send(*delta, noiseDeltaShares, prng, mOtExtRecver, chl)); - } - else - { - chl2 = chl.fork(); - prng2.SetSeed(prng.get()); - - MC_AWAIT( - macoro::when_all_ready( - nv.send(*delta, noiseDeltaShares, prng2, mOtExtRecver, chl2), - mOtExtSender.send(msg, prng, chl))); - } -#else - -#endif - } - else - { - chl2 = chl.fork(); - prng2.SetSeed(prng.get()); - MC_AWAIT( - macoro::when_all_ready( - nv.send(*delta, noiseDeltaShares, prng2, baseOt, chl2), - baseOt.send(msg, prng, chl))); - } - - - setSilentBaseOts(msg, noiseDeltaShares); - setTimePoint("SilentVoleSender.genSilent.done"); - MC_END(); - } - - u64 SilentVoleSender::silentBaseOtCount() const - { - if (isConfigured() == false) - throw std::runtime_error("configure must be called first"); - - return mGen.baseOtCount(); - } - - void SilentVoleSender::setSilentBaseOts( - span> sendBaseOts, - span noiseDeltaShares) - { - if ((u64)sendBaseOts.size() != silentBaseOtCount()) - throw RTE_LOC; - - if (noiseDeltaShares.size() != baseVoleCount()) - throw RTE_LOC; - - mGen.setBase(sendBaseOts); - mNoiseDeltaShares.resize(noiseDeltaShares.size()); - std::copy(noiseDeltaShares.begin(), noiseDeltaShares.end(), mNoiseDeltaShares.begin()); - } - - void SilentVoleSender::configure( - u64 numOTs, - SilentBaseType type, - u64 secParam) - { - mBaseType = type; - - switch (mMultType) - { - case osuCrypto::MultType::QuasiCyclic: - { - u64 p, s; - - QuasiCyclicConfigure(numOTs, secParam, - 2, - mMultType, - mRequestedNumOTs, - mNumPartitions, - mSizePer, - mN2, - mN, - p, - s - ); -#ifdef ENABLE_BITPOLYMUL - mQuasiCyclicEncoder.init(p, s); -#else - throw std::runtime_error("ENABLE_BITPOLYMUL not defined."); -#endif - break; - } - case osuCrypto::MultType::ExAcc7: - case osuCrypto::MultType::ExAcc11: - case osuCrypto::MultType::ExAcc21: - case osuCrypto::MultType::ExAcc40: - - EAConfigure(numOTs, secParam, - mMultType, - mRequestedNumOTs, - mNumPartitions, - mSizePer, - mN2, - mN, - mEAEncoder); - break; - case osuCrypto::MultType::ExConv7x24: - case osuCrypto::MultType::ExConv21x24: - - ExConvConfigure(numOTs, 128, mMultType, mRequestedNumOTs, mNumPartitions, mSizePer, mN2, mN, mExConvEncoder); - break; - default: - throw RTE_LOC; - break; - } - - mGen.configure(mSizePer, mNumPartitions); - - mState = State::Configured; - } - - //sigma = 0 Receiver - // - // u_i is the choice bit - // v_i = w_i + u_i * x - // - // ------------------------ - - // u' = 0000001000000000001000000000100000...00000, u_i = 1 iff i \in S - // - // v' = r + (x . u') = DPF(k0) - // = r + (000000x00000000000x000000000x00000...00000) - // - // u = u' * H bit-vector * H. Mapping n'->n bits - // v = v' * H block-vector * H. Mapping n'->n block - // - //sigma = 1 Sender - // - // x is the delta - // w_i is the zero message - // - // m_i0 = w_i - // m_i1 = w_i + x - // - // ------------------------ - // x - // r = DPF(k1) - // - // w = r * H - - - task<> SilentVoleSender::checkRT(Socket& chl, block delta) const - { - MC_BEGIN(task<>,this, &chl, delta); - MC_AWAIT(chl.send(delta)); - MC_AWAIT(chl.send(mB)); - MC_AWAIT(chl.send(mNoiseDeltaShares)); - MC_END(); - } - - void SilentVoleSender::clear() - { - mB = {}; - mGen.clear(); - } - - task<> SilentVoleSender::silentSend( - block delta, - span b, - PRNG& prng, - Socket& chl) - { - MC_BEGIN(task<>,this, delta, b, &prng, &chl); - - MC_AWAIT(silentSendInplace(delta, b.size(), prng, chl)); - - std::memcpy(b.data(), mB.data(), b.size() * sizeof(block)); - clear(); - - setTimePoint("SilentVoleSender.expand.ldpc.msgCpy"); - MC_END(); - } - - task<> SilentVoleSender::silentSendInplace( - block delta, - u64 n, - PRNG& prng, - Socket& chl) - { - MC_BEGIN(task<>,this, delta, n, &prng, &chl, - deltaShare = block{}, - X = block{}, - hash = std::array{}, - noiseShares = span{}, - mbb = span{} - ); - setTimePoint("SilentVoleSender.ot.enter"); - - - if (isConfigured() == false) - { - // first generate 128 normal base OTs - configure(n, SilentBaseType::BaseExtend); - } - - if (mRequestedNumOTs != n) - throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); - - if (mGen.hasBaseOts() == false) - { - // recvs data - MC_AWAIT(genSilentBaseOts(prng, chl, delta)); - } - - //mDelta = delta; - - setTimePoint("SilentVoleSender.start"); - //gTimer.setTimePoint("SilentVoleSender.iknp.base2"); - - if (mMalType == SilentSecType::Malicious) - { - deltaShare = mNoiseDeltaShares.back(); - mNoiseDeltaShares.pop_back(); - } - - // allocate B - mB.resize(0); - mB.resize(mN2); - - - if (mTimer) - mGen.setTimer(*mTimer); - - // program the output the PPRF to be secret shares of - // our secret share of delta * noiseVals. The receiver - // can then manually add their shares of this to the - // output of the PPRF at the correct locations. - noiseShares = span(mNoiseDeltaShares.data(), mNumPartitions); - mbb = mB.subspan(0, mNumPartitions * mSizePer); - MC_AWAIT(mGen.expand(chl, noiseShares, prng.get(), mbb, - PprfOutputFormat::Interleaved, true, mNumThreads)); - - setTimePoint("SilentVoleSender.expand.pprf_transpose"); - if (mDebug) - { - MC_AWAIT(checkRT(chl, delta)); - setTimePoint("SilentVoleSender.expand.checkRT"); - } - - - if (mMalType == SilentSecType::Malicious) - { - MC_AWAIT(chl.recv(X)); - hash = ferretMalCheck(X, deltaShare); - MC_AWAIT(chl.send(std::move(hash))); - } - - switch (mMultType) - { - case osuCrypto::MultType::QuasiCyclic: - -#ifdef ENABLE_BITPOLYMUL - - if (mTimer) - mQuasiCyclicEncoder.setTimer(getTimer()); - - mQuasiCyclicEncoder.dualEncode(mB.subspan(0, mQuasiCyclicEncoder.size())); -#else - throw std::runtime_error("ENABLE_BITPOLYMUL not defined."); -#endif - setTimePoint("SilentVoleSender.expand.QuasiCyclic"); - break; -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - - if (mTimer) - mEncoder.setTimer(getTimer()); - - mEncoder.dualEncode(mB); - setTimePoint("SilentVoleSender.expand.Silver"); - break; -#endif - case osuCrypto::MultType::ExAcc7: - case osuCrypto::MultType::ExAcc11: - case osuCrypto::MultType::ExAcc21: - case osuCrypto::MultType::ExAcc40: - { - if (mTimer) - mEAEncoder.setTimer(getTimer()); - AlignedUnVector B2(mEAEncoder.mMessageSize); - mEAEncoder.dualEncode(mB.subspan(0,mEAEncoder.mCodeSize), B2); - std::swap(mB, B2); - - setTimePoint("SilentVoleSender.expand.Silver"); - break; - } - case osuCrypto::MultType::ExConv7x24: - case osuCrypto::MultType::ExConv21x24: - if (mTimer) - mExConvEncoder.setTimer(getTimer()); - mExConvEncoder.dualEncode(mB.subspan(0, mExConvEncoder.mCodeSize)); - break; - default: - throw RTE_LOC; - break; - } - - - mB.resize(mRequestedNumOTs); - - mState = State::Default; - mNoiseDeltaShares.clear(); - - MC_END(); - } - - std::array SilentVoleSender::ferretMalCheck(block X, block deltaShare) - { - - auto xx = X; - block sum0 = ZeroBlock; - block sum1 = ZeroBlock; - for (u64 i = 0; i < (u64)mB.size(); ++i) - { - block low, high; - xx.gf128Mul(mB[i], low, high); - sum0 = sum0 ^ low; - sum1 = sum1 ^ high; - - xx = xx.gf128Mul(X); - } - - block mySum = sum0.gf128Reduce(sum1); - - std::array myHash; - RandomOracle ro(32); - ro.Update(mySum ^ deltaShare); - ro.Final(myHash); - - return myHash; - //chl.send(myHash); - } -} - -#endif \ No newline at end of file diff --git a/libOTe/Vole/Silent/SilentVoleSender.h b/libOTe/Vole/Silent/SilentVoleSender.h index ebd38289..aa091c60 100644 --- a/libOTe/Vole/Silent/SilentVoleSender.h +++ b/libOTe/Vole/Silent/SilentVoleSender.h @@ -15,23 +15,33 @@ #include #include #include -#include +#include "libOTe/Tools/Pprf/RegularPprf.h" #include #include #include -#include -#include #include -//#define NO_HASH +#include +#include +#include +#include + namespace osuCrypto { - - class SilentVoleSender : public TimerAdapter + template< + typename F, + typename G = F, + typename Ctx = DefaultCoeffCtx + > + class SilentSubfieldVoleSender : public TimerAdapter { public: static constexpr u64 mScaler = 2; + static constexpr bool MaliciousSupported = + std::is_same_v&& + std::is_same_v; + enum class State { Default, @@ -39,99 +49,232 @@ namespace osuCrypto HasBase }; + using VecF = typename Ctx::template Vec; + using VecG = typename Ctx::template Vec; State mState = State::Default; - SilentMultiPprfSender mGen; + // the context used to perform F, G operations + Ctx mCtx; + + // the pprf used to generate the noise vector. + RegularPprfSender mGen; + + // the number of correlations requested. + u64 mRequestSize = 0; - u64 mRequestedNumOTs = 0; - u64 mN2 = 0; - u64 mN = 0; + // the length of the noisy vector. + u64 mNoiseVecSize = 0; + + // the weight of the nosy vector u64 mNumPartitions = 0; + + // the size of each regular, weight 1, subvector + // of the noisy vector. mNoiseVecSize = mNumPartions * mSizePer u64 mSizePer = 0; - u64 mNumThreads = 1; - std::vector> mGapOts; - SilentBaseType mBaseType; - //block mDelta; - std::vector mNoiseDeltaShares; - SilentSecType mMalType = SilentSecType::SemiHonest; + // the lpn security parameters + u64 mSecParam = 0; -#ifdef ENABLE_SOFTSPOKEN_OT - SoftSpokenMalOtSender mOtExtSender; - SoftSpokenMalOtReceiver mOtExtRecver; -#endif + // the type of base OT OT that should be performed. + // Base requires more work but less communication. + SilentBaseType mBaseType = SilentBaseType::BaseExtend; - MultType mMultType = DefaultMultType; - ExConvCode mExConvEncoder; - EACode mEAEncoder; + // the base Vole correlation. To generate the silent vole, + // we must first create a small vole + // mBaseA + mBaseB = mBaseC * mDelta. + // These will be used to initialize the non-zeros of the noisy + // vector. mBaseB is the b in this corrlations. + VecF mBaseB; -#ifdef ENABLE_BITPOLYMUL - QuasiCyclicCode mQuasiCyclicEncoder; -#endif + // the full sized noisy vector. This will initalially be + // sparse with the corrlations + // mA = mB + mC * mDelta + // before it is compressed. + VecF mB; - //span mB; - //u64 mBackingSize = 0; - //std::unique_ptr mBacking; - AlignedUnVector mB; + // determines if the malicious checks are performed. + SilentSecType mMalType = SilentSecType::SemiHonest; - ///////////////////////////////////////////////////// - // The standard OT extension interface - ///////////////////////////////////////////////////// + // A flag to specify the linear code to use + MultType mMultType = DefaultMultType; - // the number of IKNP base OTs that should be set. - u64 baseOtCount() const; - // returns true if the IKNP base OTs are currently set. - bool hasBaseOts() const; + block mDeltaShare; - // sets the IKNP base OTs that are then used to extend - void setBaseOts( - span baseRecvOts, - const BitVector& choices); +#ifdef ENABLE_SOFTSPOKEN_OT + SoftSpokenMalOtSender mOtExtSender; + SoftSpokenMalOtReceiver mOtExtRecver; +#endif - // use the default base OT class to generate the - // IKNP base OTs that are required. - task<> genBaseOts(PRNG& prng, Socket& chl) - { - return mOtExtSender.genBaseOts(prng, chl); - } - ///////////////////////////////////////////////////// - // The native silent OT extension interface - ///////////////////////////////////////////////////// - u64 baseVoleCount() const { - return mNumPartitions + mGapOts.size() + 1 * (mMalType == SilentSecType::Malicious); + + u64 baseVoleCount() const + { + return mNumPartitions + 1 * (mMalType == SilentSecType::Malicious); } // Generate the silent base OTs. If the Iknp // base OTs are set then we do an IKNP extend, // otherwise we perform a base OT protocol to // generate the needed OTs. - task<> genSilentBaseOts(PRNG& prng, Socket& chl, cp::optional delta = {}); + task<> genSilentBaseOts(PRNG& prng, Socket& chl, F delta) + { + using BaseOT = DefaultBaseOT; + + MC_BEGIN(task<>, this, delta, &prng, &chl, + msg = AlignedUnVector>(silentBaseOtCount()), + baseOt = BaseOT{}, + prng2 = std::move(PRNG{}), + xx = BitVector{}, + chl2 = Socket{}, + nv = NoisyVoleSender{}, + b = VecF{} + ); + setTimePoint("SilentVoleSender.genSilent.begin"); + + if (isConfigured() == false) + throw std::runtime_error("configure must be called first"); + + xx = mCtx.binaryDecomposition(delta); + + // compute the correlation for the noisy coordinates. + b.resize(baseVoleCount()); + + + if (mBaseType == SilentBaseType::BaseExtend) + { +#ifdef ENABLE_SOFTSPOKEN_OT + + if (mOtExtRecver.hasBaseOts() == false) + { + msg.resize(msg.size() + mOtExtRecver.baseOtCount()); + MC_AWAIT(mOtExtSender.send(msg, prng, chl)); + + mOtExtRecver.setBaseOts( + span>(msg).subspan( + msg.size() - mOtExtRecver.baseOtCount(), + mOtExtRecver.baseOtCount())); + msg.resize(msg.size() - mOtExtRecver.baseOtCount()); + + MC_AWAIT(nv.send(delta, b, prng, mOtExtRecver, chl, mCtx)); + } + else + { + chl2 = chl.fork(); + prng2.SetSeed(prng.get()); + + MC_AWAIT( + macoro::when_all_ready( + nv.send(delta, b, prng2, mOtExtRecver, chl2, mCtx), + mOtExtSender.send(msg, prng, chl))); + } +#else + +#endif + } + else + { + chl2 = chl.fork(); + prng2.SetSeed(prng.get()); + //MC_AWAIT(baseOt.send(msg, prng, chl)); + //MC_AWAIT(nv.send(delta, b, prng2, baseOt, chl2)); + MC_AWAIT( + macoro::when_all_ready( + nv.send(delta, b, prng2, baseOt, chl2, mCtx), + baseOt.send(msg, prng, chl))); + } + + + setSilentBaseOts(msg, b); + setTimePoint("SilentVoleSender.genSilent.done"); + MC_END(); + } // configure the silent OT extension. This sets // the parameters and figures out how many base OT // will be needed. These can then be ganerated for // a different OT extension or using a base OT protocol. void configure( - u64 n, - SilentBaseType baseType = SilentBaseType::BaseExtend, - u64 secParam = 128); + u64 requestSize, + SilentBaseType type = SilentBaseType::BaseExtend, + u64 secParam = 128, + Ctx ctx = {}) + { + mCtx = std::move(ctx); + mSecParam = secParam; + mRequestSize = requestSize; + mState = State::Configured; + mBaseType = type; + double minDist = 0; + + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + case osuCrypto::MultType::ExConv21x24: + { + u64 _1, _2; + ExConvConfigure(mScaler, mMultType, _1, _2, minDist); + break; + } + case MultType::QuasiCyclic: + QuasiCyclicConfigure(mScaler, minDist); + break; + default: + throw RTE_LOC; + break; + } + + mNumPartitions = getRegNoiseWeight(minDist, secParam); + mSizePer = std::max(4, roundUpTo(divCeil(mRequestSize * mScaler, mNumPartitions), 2)); + mNoiseVecSize = mSizePer * mNumPartitions; + + mGen.configure(mSizePer, mNumPartitions); + } // return true if this instance has been configured. bool isConfigured() const { return mState != State::Default; } // Returns how many base OTs the silent OT extension // protocol will needs. - u64 silentBaseOtCount() const; + u64 silentBaseOtCount() const + { + if (isConfigured() == false) + throw std::runtime_error("configure must be called first"); + + return mGen.baseOtCount(); + } // Set the externally generated base OTs. This choice // bits must be the one return by sampleBaseChoiceBits(...). void setSilentBaseOts( span> sendBaseOts, - span sendBaseVole); + const VecF& b) + { + if ((u64)sendBaseOts.size() != silentBaseOtCount()) + throw RTE_LOC; + + if (b.size() != baseVoleCount()) + throw RTE_LOC; + + mGen.setBase(sendBaseOts); + + // we store the negative of b. This is because + // we need the correlation + // + // mBaseA + mBaseB = mBaseC * delta + // + // for the pprf to expand correctly but the + // input correlation is a vole: + // + // mBaseA = b + mBaseC * delta + // + mCtx.resize(mBaseB, b.size()); + mCtx.zero(mBaseB.begin(), mBaseB.end()); + for (u64 i = 0; i < mBaseB.size(); ++i) + mCtx.minus(mBaseB[i], mBaseB[i], b[i]); + } // The native OT extension interface of silent // OT. The receiver does not get to specify @@ -139,10 +282,22 @@ namespace osuCrypto // the protocol picks them at random. Use the // send(...) interface for the normal behavior. task<> silentSend( - block delta, - span b, + F delta, + VecF& b, PRNG& prng, - Socket& chls); + Socket& chl) + { + MC_BEGIN(task<>, this, delta, &b, &prng, &chl); + + MC_AWAIT(silentSendInplace(delta, b.size(), prng, chl)); + + mCtx.copy(mB.begin(), mB.begin() + b.size(), b.begin()); + //std::memcpy(b.data(), mB.data(), b.size() * mCtx.bytesF); + clear(); + + setTimePoint("SilentVoleSender.expand.ldpc.msgCpy"); + MC_END(); + } // The native OT extension interface of silent // OT. The receiver does not get to specify @@ -150,18 +305,170 @@ namespace osuCrypto // the protocol picks them at random. Use the // send(...) interface for the normal behavior. task<> silentSendInplace( - block delta, + F delta, u64 n, PRNG& prng, - Socket& chls); + Socket& chl) + { + MC_BEGIN(task<>, this, delta, n, &prng, &chl, + deltaShare = block{}, + X = block{}, + hash = std::array{}, + baseB = VecF{} + ); + setTimePoint("SilentVoleSender.ot.enter"); + + + if (isConfigured() == false) + { + // first generate 128 normal base OTs + configure(n, SilentBaseType::BaseExtend); + } + + if (mRequestSize != n) + throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); + + if (mGen.hasBaseOts() == false) + { + // recvs data + MC_AWAIT(genSilentBaseOts(prng, chl, delta)); + } + + setTimePoint("SilentVoleSender.start"); + //gTimer.setTimePoint("SilentVoleSender.iknp.base2"); + + // allocate B + mCtx.resize(mB, 0); + mCtx.resize(mB, mNoiseVecSize); + + if (mTimer) + mGen.setTimer(*mTimer); + + // extract just the first mNumPartitions value of mBaseB. + // the last is for the malicious check (if present). + mCtx.resize(baseB, mNumPartitions); + mCtx.copy(mBaseB.begin(), mBaseB.begin() + mNumPartitions, baseB.begin()); + + // program the output the PPRF to be secret shares of + // our secret share of delta * noiseVals. The receiver + // can then manually add their shares of this to the + // output of the PPRF at the correct locations. + MC_AWAIT(mGen.expand(chl, baseB, prng.get(), mB, + PprfOutputFormat::Interleaved, true, 1)); + setTimePoint("SilentVoleSender.expand.pprf"); + + if (mDebug) + { + MC_AWAIT(checkRT(chl, delta)); + setTimePoint("SilentVoleSender.expand.checkRT"); + } + + if (mMalType == SilentSecType::Malicious) + { + MC_AWAIT(chl.recv(X)); + + if constexpr (MaliciousSupported) + hash = ferretMalCheck(X); + else + throw std::runtime_error("malicious is currently only supported for GF128 block. " LOCATION); + + MC_AWAIT(chl.send(std::move(hash))); + } + + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + case osuCrypto::MultType::ExConv21x24: + { + ExConvCode encoder; + u64 expanderWeight, accumulatorWeight; + double _1; + ExConvConfigure(mScaler, mMultType, expanderWeight, accumulatorWeight, _1); + if (mScaler * mRequestSize > mNoiseVecSize) + throw RTE_LOC; + encoder.config(mRequestSize, mScaler * mRequestSize, expanderWeight, accumulatorWeight); + if (mTimer) + encoder.setTimer(getTimer()); + encoder.dualEncode(mB.begin(), mCtx); + break; + } + case MultType::QuasiCyclic: + { +#ifdef ENABLE_BITPOLYMUL + if constexpr ( + std::is_same_v && + std::is_same_v && + std::is_same_v) + { + QuasiCyclicCode encoder; + encoder.init2(mRequestSize, mNoiseVecSize); + encoder.dualEncode(mB); + } + else + throw std::runtime_error("QuasiCyclic is only supported for GF128, i.e. block. " LOCATION); +#else + throw std::runtime_error("QuasiCyclic requires ENABLE_BITPOLYMUL = true. " LOCATION); +#endif + + break; + } + default: + throw std::runtime_error("Code is not supported. " LOCATION); + break; + } + + mCtx.resize(mB, mRequestSize); + + + mState = State::Default; + mBaseB.clear(); + + MC_END(); + } bool mDebug = false; - task<> checkRT(Socket& chl, block delta) const; + task<> checkRT(Socket& chl, F delta) const + { + MC_BEGIN(task<>, this, &chl, delta); + MC_AWAIT(chl.send(delta)); + MC_AWAIT(chl.send(mB)); + MC_AWAIT(chl.send(mBaseB)); + MC_END(); + } - std::array ferretMalCheck(block X, block deltaShare); + std::array ferretMalCheck(block X) + { - void clear(); + auto xx = X; + block sum0 = ZeroBlock; + block sum1 = ZeroBlock; + for (u64 i = 0; i < (u64)mB.size(); ++i) + { + block low, high; + xx.gf128Mul(mB[i], low, high); + sum0 = sum0 ^ low; + sum1 = sum1 ^ high; + + xx = xx.gf128Mul(X); + } + + block mySum = sum0.gf128Reduce(sum1); + + std::array myHash; + RandomOracle ro(32); + ro.Update(mySum ^ mBaseB.back()); + ro.Final(myHash); + + return myHash; + //chl.send(myHash); + } + + void clear() + { + mB = {}; + mGen.clear(); + } }; } diff --git a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp index f8ad9f5a..ea4d17ea 100644 --- a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp +++ b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp @@ -315,7 +315,7 @@ namespace osuCrypto //mNumVolesPadded = (computeNumVolesPadded(fieldBits_, numVoles_)); mGenerateFn = (selectGenerateImpl(mFieldBits)); if (!mPprf) - mPprf.reset(new SilentMultiPprfSender); + mPprf.reset(new PprfSender); } void SmallFieldVoleReceiver::init(u64 fieldBits_, u64 numVoles_, bool malicious) @@ -323,7 +323,7 @@ namespace osuCrypto SmallFieldVoleBase::init(fieldBits_, numVoles_, malicious); mGenerateFn = (selectGenerateImpl(mFieldBits)); if (!mPprf) - mPprf.reset(new SilentMultiPprfReceiver); + mPprf.reset(new PprfReceiver); } void SmallFieldVoleReceiver::setDelta(BitVector delta_) @@ -347,26 +347,9 @@ namespace osuCrypto std::copy(seeds_.begin(), seeds_.end(), mSeeds.data()); } - // The choice bits (as expected by SilentMultiPprfReceiver) store the locations of the complements - // of the active paths (because they tell which messages were transferred, not which ones weren't), - // in big endian. We want delta, which is the locations of the active paths, in little endian. - static BitVector choicesToDelta(const BitVector& choices, u64 fieldBits, u64 numVoles) - { - if ((u64)choices.size() != numVoles * fieldBits) - throw RTE_LOC; - - BitVector delta(choices.size()); - for (u64 i = 0; i < numVoles; ++i) - for (u64 j = 0; j < fieldBits; ++j) - delta[i * fieldBits + j] = 1 ^ choices[(i + 1) * fieldBits - j - 1]; - return delta; - } - - void SmallFieldVoleReceiver::setBaseOts(span baseMessages, const BitVector& choices) { - setDelta(choicesToDelta(choices, mFieldBits, mNumVoles)); - + setDelta(choices); if (!mPprf) throw RTE_LOC; @@ -379,7 +362,7 @@ namespace osuCrypto throw RTE_LOC; mPprf->setBase(baseMessages); - mPprf->setChoiceBits(PprfOutputFormat::ByTreeIndex, choices); + mPprf->setChoiceBits(choices); } void SmallFieldVoleSender::setBaseOts(span> msgs) @@ -404,16 +387,19 @@ namespace osuCrypto MC_BEGIN(task<>, this, &chl, &prng, numThreads, corrections = std::vector>{}, hashes = std::vector>{}, - seedView = MatrixView{} + seedView = MatrixView{}, + _ = AlignedUnVector{} ); assert(mSeeds.size() == 0 && mNumVoles && mNumVoles <= mNumVolesPadded); mSeeds.resize(mNumVolesPadded * fieldSize()); std::fill(mSeeds.begin(), mSeeds.end(), block(0, 0)); + mSeeds.resize(mNumVoles * fieldSize()); + MC_AWAIT(mPprf->expand(chl, _, prng.get(), mSeeds, PprfOutputFormat::ByTreeIndex, false, 1)); + mSeeds.resize(mNumVolesPadded * fieldSize()); seedView = MatrixView(mSeeds.data(), mNumVoles, fieldSize()); - MC_AWAIT(mPprf->expand(chl, span(), prng.get(), seedView, PprfOutputFormat::ByTreeIndex, false, 1)); // Prove consistency if (mMalicious) @@ -452,7 +438,8 @@ namespace osuCrypto task<> SmallFieldVoleReceiver::expand(Socket& chl, PRNG& prng, u64 numThreads) { MC_BEGIN(task<>, this, &chl, &prng, numThreads, - seedsFull = Matrix{}, + seeds = AlignedUnVector{}, + seedsFull = MatrixView{}, totals = std::vector>{}, entryHashes = std::vector>{}, corrections = std::vector>{}, @@ -464,9 +451,9 @@ namespace osuCrypto mSeeds.resize(mNumVolesPadded * (fieldSize() - 1)); std::fill(mSeeds.begin(), mSeeds.end(), block(0, 0)); - - seedsFull.resize(mNumVoles, fieldSize()); - MC_AWAIT(mPprf->expand(chl, seedsFull, PprfOutputFormat::ByTreeIndex, false, 1)); + seeds.resize(mNumVoles * fieldSize()); + MC_AWAIT(mPprf->expand(chl, seeds, PprfOutputFormat::ByTreeIndex, false, 1)); + seedsFull = MatrixView(seeds.data(), mNumVoles, fieldSize()); // Check consistency if (mMalicious) diff --git a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.h b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.h index b4e02937..68120264 100644 --- a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.h +++ b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.h @@ -16,7 +16,7 @@ #include #include "libOTe/TwoChooseOne/TcoOtDefines.h" #include "libOTe/Tools/Coproto.h" -#include "libOTe/Tools/SilentPprf.h" +#include "libOTe/Tools/Pprf/RegularPprf.h" namespace osuCrypto { @@ -121,9 +121,14 @@ namespace osuCrypto class SmallFieldVoleSender : public SmallFieldVoleBase { + + public: + using PprfSender = RegularPprfSender; + + private: SmallFieldVoleSender(const SmallFieldVoleSender& b) : SmallFieldVoleBase(b) - , mPprf(new SilentMultiPprfSender) + , mPprf(new PprfSender) , mGenerateFn(b.mGenerateFn) {} @@ -134,7 +139,7 @@ namespace osuCrypto // wastes a few AES calls, but saving them wouldn't have helped much because you still have to // pay for the AES latency. - std::unique_ptr mPprf; + std::unique_ptr mPprf; SmallFieldVoleSender() = default; SmallFieldVoleSender(SmallFieldVoleSender&&) = default; @@ -203,17 +208,20 @@ namespace osuCrypto class SmallFieldVoleReceiver : public SmallFieldVoleBase { + public: + using PprfReceiver = RegularPprfReceiver; + private: SmallFieldVoleReceiver(const SmallFieldVoleReceiver& b) : SmallFieldVoleBase(b) - , mPprf(new SilentMultiPprfReceiver) + , mPprf(new PprfReceiver) , mDelta(b.mDelta) , mDeltaUnpacked(b.mDeltaUnpacked) , mGenerateFn(b.mGenerateFn) {} public: - std::unique_ptr mPprf; + std::unique_ptr mPprf; BitVector mDelta; AlignedUnVector mDeltaUnpacked; // Each bit of delta becomes a byte, either 0 or 0xff. diff --git a/libOTe/Vole/Subfield/NoisyVoleReceiver.h b/libOTe/Vole/Subfield/NoisyVoleReceiver.h deleted file mode 100644 index 5587cac3..00000000 --- a/libOTe/Vole/Subfield/NoisyVoleReceiver.h +++ /dev/null @@ -1,148 +0,0 @@ -#pragma once -// © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -#include -#if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) - -#include "cryptoTools/Common/Defines.h" -#include "cryptoTools/Common/Timer.h" -#include "cryptoTools/Crypto/PRNG.h" -#include "libOTe/Tools/Coproto.h" -#include "libOTe/TwoChooseOne/OTExtInterface.h" -#include "libOTe/Tools/Subfield/Subfield.h" - -namespace osuCrypto { - - template < - typename F, - typename G = F, - typename CoeffCtx = DefaultCoeffCtx - > - class NoisySubfieldVoleReceiver : public TimerAdapter - { - public: - - // for chosen c, compute a such htat - // - // a = b + c * delta - // - template - task<> receive(VecG& c, VecF& a, PRNG& prng, - OtSender& ot, Socket& chl) - { - MC_BEGIN(task<>, this, &c, &a, &prng, &ot, &chl, - otMsg = AlignedUnVector>{}); - - setTimePoint("NoisyVoleReceiver.ot.begin"); - otMsg.resize(CoeffCtx::bitSize()); - MC_AWAIT(ot.send(otMsg, prng, chl)); - - setTimePoint("NoisyVoleReceiver.ot.end"); - - MC_AWAIT(receive(c, a, prng, otMsg, chl)); - - MC_END(); - } - - // for chosen c, compute a such htat - // - // a = b + c * delta - // - template - task<> receive(VecG& c, VecF& a, PRNG& _, - span> otMsg, - Socket& chl) - { - MC_BEGIN(task<>, this, &c, &a, otMsg, &chl, - buff = std::vector{}, - msg = typename CoeffCtx::Vec{}, - temp = typename CoeffCtx::Vec{}, - prng = std::move(PRNG{}) - ); - - if (c.size() != a.size()) - throw RTE_LOC; - if (a.size() == 0) - throw RTE_LOC; - - setTimePoint("NoisyVoleReceiver.begin"); - - CoeffCtx::zero(a.begin(), a.end()); - CoeffCtx::resize(msg, otMsg.size() * a.size()); - CoeffCtx::resize(temp, 2); - - for (size_t i = 0, k = 0; i < otMsg.size(); ++i) - { - prng.SetSeed(otMsg[i][0], a.size()); - - // t1 = 2^i - CoeffCtx::powerOfTwo(temp[1], i); - //std::cout << "2^i " << CoeffCtx::str(temp[1]) << "\n"; - - for (size_t j = 0; j < c.size(); ++j, ++k) - { - // msg[i,j] = otMsg[i,j,0] - CoeffCtx::fromBlock(msg[k], prng.get()); - //CoeffCtx::zero(msg.begin() + k, msg.begin() + k + 1); - //std::cout << "m" << i << ",0 = " << CoeffCtx::str(msg[k]) << std::endl; - - // a[j] += otMsg[i,j,0] - CoeffCtx::plus(a[j], a[j], msg[k]); - //std::cout << "z = " << CoeffCtx::str(a[j]) << std::endl; - - // temp = 2^i * c[j] - CoeffCtx::mul(temp[0], temp[1], c[j]); - //std::cout << "2^i y = " << CoeffCtx::str(temp[0]) << std::endl; - - // msg[i,j] = otMsg[i,j,0] + 2^i * c[j] - CoeffCtx::minus(msg[k], msg[k], temp[0]); - //std::cout << "m" << i << ",0 + 2^i y = " << CoeffCtx::str(msg[k]) << std::endl; - } - - k -= c.size(); - prng.SetSeed(otMsg[i][1], a.size()); - - for (size_t j = 0; j < c.size(); ++j, ++k) - { - // temp = otMsg[i,j,1] - CoeffCtx::fromBlock(temp[0], prng.get()); - //CoeffCtx::zero(temp.begin(), temp.begin() + 1); - //std::cout << "m" << i << ",1 = " << CoeffCtx::str(temp[0]) << std::endl; - - // enc one message under the OT msg. - // msg[i,j] = (otMsg[i,j,0] + 2^i * c[j]) - otMsg[i,j,1] - CoeffCtx::minus(msg[k], msg[k], temp[0]); - //std::cout << "m" << i << ",0 + 2^i y - m" << i << ",1 = " << CoeffCtx::str(msg[k]) << std::endl << std::endl; - } - } - - buff.resize(msg.size() * CoeffCtx::byteSize()); - CoeffCtx::serialize(msg.begin(), msg.end(), buff.begin()); - - MC_AWAIT(chl.send(std::move(buff))); - setTimePoint("NoisyVoleReceiver.done"); - - MC_END(); - } - - }; - -} // namespace osuCrypto -#endif diff --git a/libOTe/Vole/Subfield/NoisyVoleSender.h b/libOTe/Vole/Subfield/NoisyVoleSender.h deleted file mode 100644 index 00a75ed9..00000000 --- a/libOTe/Vole/Subfield/NoisyVoleSender.h +++ /dev/null @@ -1,141 +0,0 @@ -#pragma once -// © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -// This code implements features described in [Silver: Silent VOLE and Oblivious -// Transfer from Hardness of Decoding Structured LDPC Codes, -// https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative -// Commons Attribution 4.0 International Public License -// (https://creativecommons.org/licenses/by/4.0/legalcode). - -#include -#if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) - -#include "cryptoTools/Common/BitVector.h" -#include "cryptoTools/Common/Defines.h" -#include "cryptoTools/Common/Timer.h" -#include "cryptoTools/Crypto/PRNG.h" -#include "libOTe/Tools/Coproto.h" -#include "libOTe/TwoChooseOne/OTExtInterface.h" -#include "libOTe/Tools/Subfield/Subfield.h" - -namespace osuCrypto { - template < - typename F, - typename G = F, - typename CoeffCtx = DefaultCoeffCtx - > - class NoisySubfieldVoleSender : public TimerAdapter - { - - public: - - // for chosen delta, compute b such htat - // - // a = b + c * delta - // - template - task<> send(F delta, FVec& b, PRNG& prng, - OtReceiver& ot, Socket& chl) { - MC_BEGIN(task<>, this, delta, &b, &prng, &ot, &chl, - bv = CoeffCtx::binaryDecomposition(delta), - otMsg = AlignedUnVector{ }); - otMsg.resize(bv.size()); - - setTimePoint("NoisyVoleSender.ot.begin"); - - MC_AWAIT(ot.receive(bv, otMsg, prng, chl)); - setTimePoint("NoisyVoleSender.ot.end"); - - MC_AWAIT(send(delta, b, prng, otMsg, chl)); - - MC_END(); - } - - // for chosen delta, compute b such htat - // - // a = b + c * delta - // - template - task<> send(F delta, FVec& b, PRNG& _, - span otMsg, Socket& chl) { - MC_BEGIN(task<>, this, delta, &b, otMsg, &chl, - prng = std::move(PRNG{}), - buffer = std::vector{}, - msg = typename CoeffCtx::Vec{}, - temp = typename CoeffCtx::Vec{}, - xb = BitVector{}); - - xb = CoeffCtx::binaryDecomposition(delta); - - if (otMsg.size() != xb.size()) - throw RTE_LOC; - setTimePoint("NoisyVoleSender.main"); - - // b = 0; - CoeffCtx::zero(b.begin(), b.end()); - - // receive the the excrypted one shares. - buffer.resize(xb.size() * b.size() * CoeffCtx::byteSize()); - MC_AWAIT(chl.recv(buffer)); - CoeffCtx::resize(msg, xb.size() * b.size()); - CoeffCtx::deserialize(buffer.begin(), buffer.end(), msg.begin()); - - setTimePoint("NoisyVoleSender.recvMsg"); - - temp.resize(1); - for (size_t i = 0, k = 0; i < xb.size(); ++i) - { - // expand the zero shares or one share masks - prng.SetSeed(otMsg[i], b.size()); - - // otMsg[i,j, bc[i]] - //auto otMsgi = prng.getBufferSpan(b.size()); - - for (u64 j = 0; j < (u64)b.size(); ++j, ++k) - { - // temp = otMsg[i,j, xb[i]] - CoeffCtx::fromBlock(temp[0], prng.get()); - //CoeffCtx::zero(temp.begin(), temp.begin() + 1); - //std::cout << "m" << i << ","< -#ifdef ENABLE_SILENT_VOLE - -#include -#include -#include -#include "libOTe/Tools/Subfield/SubfieldPprf.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "libOTe/Tools/QuasiCyclicCode.h" -namespace osuCrypto -{ - - - template< - typename F, - typename G = F, - typename Ctx = DefaultCoeffCtx - > - class SilentSubfieldVoleReceiver : public TimerAdapter - { - public: - static constexpr u64 mScaler = 2; - - static constexpr bool MaliciousSupported = - std::is_same_v&& - std::is_same_v; - - enum class State - { - Default, - Configured, - HasBase - }; - - using VecF = typename Ctx::template Vec; - using VecG = typename Ctx::template Vec; - - // The current state of the protocol - State mState = State::Default; - - // the context used to perform F, G operations - Ctx mCtx; - - // The number of correlations the user requested. - u64 mRequestSize = 0; - - // the LPN security parameter - u64 mSecParam = 0; - - // The length of the noisy vectors (2 * mN for the most codes). - u64 mNoiseVecSize = 0; - - // We perform regular LPN, so this is the - // size of the each chunk. - u64 mSizePer = 0; - - // the number of noisy positions - u64 mNumPartitions = 0; - - // The noisy coordinates. - std::vector mS; - - // What type of Base OTs should be performed. - SilentBaseType mBaseType; - - // The matrix multiplication type which compresses - // the sparse vector. - MultType mMultType = DefaultMultType; - - // The multi-point punctured PRF for generating - // the sparse vectors. - SilentSubfieldPprfReceiver mGen; - - // The internal buffers for holding the expanded vectors. - // mA = mB + mC * delta - VecF mA; - - // mA = mB + mC * delta - VecG mC; - - u64 mNumThreads = 1; - - bool mDebug = false; - - BitVector mIknpSendBaseChoice; - - SilentSecType mMalType = SilentSecType::SemiHonest; - - block mMalCheckSeed, mMalCheckX, mMalBaseA; - - // we - VecF mBaseA; - VecG mBaseC; - - -#ifdef ENABLE_SOFTSPOKEN_OT - SoftSpokenMalOtSender mOtExtSender; - SoftSpokenMalOtReceiver mOtExtRecver; -#endif - - // // sets the Iknp base OTs that are then used to extend - // void setBaseOts( - // span> baseSendOts); - // - // // return the number of base OTs IKNP needs - // u64 baseOtCount() const; - - u64 baseVoleCount() const - { - return mNumPartitions + 1 * (mMalType == SilentSecType::Malicious); - } - - // // returns true if the IKNP base OTs are currently set. - // bool hasBaseOts() const; - // - // returns true if the silent base OTs are set. - bool hasSilentBaseOts() const { - return mGen.hasBaseOts(); - }; - // - // // Generate the IKNP base OTs - // task<> genBaseOts(PRNG& prng, Socket& chl) ; - - // Generate the silent base OTs. If the Iknp - // base OTs are set then we do an IKNP extend, - // otherwise we perform a base OT protocol to - // generate the needed OTs. - task<> genSilentBaseOts(PRNG& prng, Socket& chl) - { - using BaseOT = DefaultBaseOT; - - - MC_BEGIN(task<>, this, &prng, &chl, - choice = BitVector{}, - bb = BitVector{}, - msg = AlignedUnVector{}, - baseVole = std::vector{}, - baseOt = BaseOT{}, - chl2 = Socket{}, - prng2 = std::move(PRNG{}), - noiseVals = VecG{}, - baseAs = VecF{}, - nv = NoisySubfieldVoleReceiver{} - - ); - - setTimePoint("SilentVoleReceiver.genSilent.begin"); - if (isConfigured() == false) - throw std::runtime_error("configure must be called first"); - - choice = sampleBaseChoiceBits(prng); - msg.resize(choice.size()); - - // sample the noise vector noiseVals such that we will compute - // - // C = (000 noiseVals[0] 0000 ... 000 noiseVals[p] 000) - // - // and then we want secret shares of C * delta. As a first step - // we will compute secret shares of - // - // delta * noiseVals - // - // and store our share in voleDeltaShares. This party will then - // compute their share of delta * C as what comes out of the PPRF - // plus voleDeltaShares[i] added to the appreciate spot. Similarly, the - // other party will program the PPRF to output their share of delta * noiseVals. - // - noiseVals = sampleBaseVoleVals(prng); - mCtx.resize(baseAs, noiseVals.size()); - - if (mTimer) - nv.setTimer(*mTimer); - - if (mBaseType == SilentBaseType::BaseExtend) - { -#ifdef ENABLE_SOFTSPOKEN_OT - - if (mOtExtSender.hasBaseOts() == false) - { - msg.resize(msg.size() + mOtExtSender.baseOtCount()); - bb.resize(mOtExtSender.baseOtCount()); - bb.randomize(prng); - choice.append(bb); - - MC_AWAIT(mOtExtRecver.receive(choice, msg, prng, chl)); - - mOtExtSender.setBaseOts( - span(msg).subspan( - msg.size() - mOtExtSender.baseOtCount(), - mOtExtSender.baseOtCount()), - bb); - - msg.resize(msg.size() - mOtExtSender.baseOtCount()); - MC_AWAIT(nv.receive(noiseVals, baseAs, prng, mOtExtSender, chl)); - } - else - { - chl2 = chl.fork(); - prng2.SetSeed(prng.get()); - - - MC_AWAIT( - macoro::when_all_ready( - nv.receive(noiseVals, baseAs, prng2, mOtExtSender, chl2), - mOtExtRecver.receive(choice, msg, prng, chl) - )); - } -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - else - { - chl2 = chl.fork(); - prng2.SetSeed(prng.get()); - - MC_AWAIT( - macoro::when_all_ready( - baseOt.receive(choice, msg, prng, chl), - nv.receive(noiseVals, baseAs, prng2, baseOt, chl2)) - ); - } - - setSilentBaseOts(msg, baseAs); - setTimePoint("SilentVoleReceiver.genSilent.done"); - MC_END(); - }; - - // configure the silent OT extension. This sets - // the parameters and figures out how many base OT - // will be needed. These can then be ganerated for - // a different OT extension or using a base OT protocol. - void configure( - u64 requestSize, - SilentBaseType type = SilentBaseType::BaseExtend, - u64 secParam = 128, - Ctx ctx = {}) - { - mCtx = std::move(ctx); - mSecParam = secParam; - mRequestSize = requestSize; - mState = State::Configured; - mBaseType = type; - double minDist = 0; - switch (mMultType) - { - case osuCrypto::MultType::ExConv7x24: - case osuCrypto::MultType::ExConv21x24: - { - u64 _1, _2; - ExConvConfigure(mScaler, mMultType, _1, _2, minDist); - break; - } - case MultType::QuasiCyclic: - QuasiCyclicConfigure(mScaler, minDist); - break; - default: - throw RTE_LOC; - break; - } - - mNumPartitions = getRegNoiseWeight(minDist, secParam); - mSizePer = std::max(4, roundUpTo(divCeil(mRequestSize * mScaler, mNumPartitions), 2)); - mNoiseVecSize = mSizePer * mNumPartitions; - - //std::cout << "n " << mRequestSize << " -> " << mNoiseVecSize << " = " << mSizePer << " * " << mNumPartitions << std::endl; - - mGen.configure(mSizePer, mNumPartitions); - } - - // return true if this instance has been configured. - bool isConfigured() const { return mState != State::Default; } - - // Returns how many base OTs the silent OT extension - // protocol will needs. - u64 silentBaseOtCount() const - { - if (isConfigured() == false) - throw std::runtime_error("configure must be called first"); - - return mGen.baseOtCount(); - - } - - // The silent base OTs must have specially set base OTs. - // This returns the choice bits that should be used. - // Call this is you want to use a specific base OT protocol - // and then pass the OT messages back using setSilentBaseOts(...). - BitVector sampleBaseChoiceBits(PRNG& prng) { - - if (isConfigured() == false) - throw std::runtime_error("configure(...) must be called first"); - - auto choice = mGen.sampleChoiceBits(prng); - - return choice; - } - - VecG sampleBaseVoleVals(PRNG& prng) - { - if (isConfigured() == false) - throw RTE_LOC; - - // sample the values of the noisy coordinate of c - // and perform a noicy vole to get a = b + mD * c - - - VecG zero, one; - mCtx.resize(zero, 1); - mCtx.zero(zero.begin(), zero.end()); - mCtx.one(one.begin(), one.end()); - mCtx.resize(mBaseC, mNumPartitions + (mMalType == SilentSecType::Malicious)); - for (size_t i = 0; i < mNumPartitions; i++) - { - mCtx.fromBlock(mBaseC[i], prng.get()); - - // must not be zero. - while(mCtx.eq(zero[0], mBaseC[i])) - mCtx.fromBlock(mBaseC[i], prng.get()); - - // if we are not a field, then the noise should be odd. - if (mCtx.isField() == false) - { - auto odd = mCtx.binaryDecomposition(mBaseC[i])[0]; - if (odd) - mCtx.plus(mBaseC[i], mBaseC[i], one[0]); - } - } - - - mS.resize(mNumPartitions); - mGen.getPoints(mS, PprfOutputFormat::Interleaved); - - if (mMalType == SilentSecType::Malicious) - { - if constexpr (MaliciousSupported) - { - mMalCheckSeed = prng.get(); - - auto yIter = mBaseC.begin(); - mCtx.zero(mBaseC.end() - 1, mBaseC.end()); - for (u64 i = 0; i < mNumPartitions; ++i) - { - auto s = mS[i]; - auto xs = mMalCheckSeed.gf128Pow(s + 1); - mBaseC[mNumPartitions] = mBaseC[mNumPartitions] ^ xs.gf128Mul(*yIter); - ++yIter; - } - } - else - { - throw std::runtime_error("malicious is currently only supported for GF128 block. " LOCATION); - } - } - - return mBaseC; - } - - // Set the externally generated base OTs. This choice - // bits must be the one return by sampleBaseChoiceBits(...). - void setSilentBaseOts( - span recvBaseOts, - VecF& baseA) - { - if (isConfigured() == false) - throw std::runtime_error("configure(...) must be called first."); - - if (static_cast(recvBaseOts.size()) != silentBaseOtCount()) - throw std::runtime_error("wrong number of silent base OTs"); - - mGen.setBase(recvBaseOts); - - mCtx.resize(mBaseA, baseA.size()); - mCtx.copy(baseA.begin(), baseA.end(), mBaseA.begin()); - mState = State::HasBase; - } - - // Perform the actual OT extension. If silent - // base OTs have been generated or set, then - // this function is non-interactive. Otherwise - // the silent base OTs will automatically be performed. - task<> silentReceive( - VecG& c, - VecF& a, - PRNG& prng, - Socket& chl) - { - MC_BEGIN(task<>, this, &c, &a, &prng, &chl); - if (c.size() != a.size()) - throw RTE_LOC; - - MC_AWAIT(silentReceiveInplace(c.size(), prng, chl)); - - mCtx.copy(mC.begin(), mC.begin() + c.size(), c.begin()); - mCtx.copy(mA.begin(), mA.begin() + a.size(), a.begin()); - - clear(); - MC_END(); - } - - // Perform the actual OT extension. If silent - // base OTs have been generated or set, then - // this function is non-interactive. Otherwise - // the silent base OTs will automatically be performed. - task<> silentReceiveInplace( - u64 n, - PRNG& prng, - Socket& chl) - { - MC_BEGIN(task<>, this, n, &prng, &chl, - myHash = std::array{}, - theirHash = std::array{} - ); - gTimer.setTimePoint("SilentVoleReceiver.ot.enter"); - - if (isConfigured() == false) - { - // first generate 128 normal base OTs - configure(n, SilentBaseType::BaseExtend); - } - - if (mRequestSize != n) - throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); - - if (hasSilentBaseOts() == false) - { - MC_AWAIT(genSilentBaseOts(prng, chl)); - } - - // allocate mA - mCtx.resize(mA, 0); - mCtx.resize(mA, mNoiseVecSize); - - setTimePoint("SilentVoleReceiver.alloc"); - - // allocate the space for mC - mCtx.resize(mC, 0); - mCtx.resize(mC, mNoiseVecSize); - mCtx.zero(mC.begin(), mC.end()); - setTimePoint("SilentVoleReceiver.alloc.zero"); - - if (mTimer) - mGen.setTimer(*mTimer); - - // As part of the setup, we have generated - // - // mBaseA + mBaseB = mBaseC * mDelta - // - // We have mBaseA, mBaseC, - // they have mBaseB, mDelta - // This was done with a small (noisy) vole. - // - // We use the Pprf to expand as - // - // mA' = mB + mS(mBaseB) - // = mB + mS(mBaseC * mDelta - mBaseA) - // = mB + mS(mBaseC * mDelta) - mS(mBaseA) - // - // Therefore if we add mS(mBaseA) to mA' we will get - // - // mA = mB + mS(mBaseC * mDelta) - // - MC_AWAIT(mGen.expand(chl, mA, PprfOutputFormat::Interleaved, true, mNumThreads)); - - setTimePoint("SilentVoleReceiver.expand.pprf_transpose"); - - // populate the noisy coordinates of mC and - // update mA to be a secret share of mC * delta - for (u64 i = 0; i < mNumPartitions; ++i) - { - auto pnt = mS[i]; - mCtx.copy(mC[pnt], mBaseC[i]); - mCtx.plus(mA[pnt], mA[pnt], mBaseA[i]); - } - - if (mDebug) - { - MC_AWAIT(checkRT(chl)); - setTimePoint("SilentVoleReceiver.expand.checkRT"); - } - - - if (mMalType == SilentSecType::Malicious) - { - MC_AWAIT(chl.send(std::move(mMalCheckSeed))); - - if constexpr (MaliciousSupported) - myHash = ferretMalCheck(); - else - throw std::runtime_error("malicious is currently only supported for GF128 block. " LOCATION); - - MC_AWAIT(chl.recv(theirHash)); - - if (theirHash != myHash) - throw RTE_LOC; - } - - switch (mMultType) - { - case osuCrypto::MultType::ExConv7x24: - case osuCrypto::MultType::ExConv21x24: - { - u64 expanderWeight, accumulatorWeight; - double _; - ExConvConfigure(mScaler, mMultType, expanderWeight, accumulatorWeight, _); - ExConvCode2 encoder; - if (mScaler * mRequestSize > mNoiseVecSize) - throw RTE_LOC; - encoder.config(mRequestSize, mScaler * mRequestSize, expanderWeight, accumulatorWeight); - - if (mTimer) - encoder.setTimer(getTimer()); - - encoder.dualEncode2( - mA.begin(), - mC.begin() - ); - break; - } - case osuCrypto::MultType::QuasiCyclic: - { -#ifdef ENABLE_BITPOLYMUL - if constexpr ( - std::is_same_v && - std::is_same_v && - std::is_same_v) - { - QuasiCyclicCode encoder; - encoder.init2(mRequestSize, mNoiseVecSize); - encoder.dualEncode(mA); - encoder.dualEncode(mC); - } - else - throw std::runtime_error("QuasiCyclic is only supported for GF128, i.e. block. " LOCATION); -#else - throw std::runtime_error("QuasiCyclic requires ENABLE_BITPOLYMUL = true. " LOCATION); -#endif - break; - } - default: - throw std::runtime_error("Code is not supported. " LOCATION); - break; - } - - // resize the buffers down to only contain the real elements. - mCtx.resize(mA, mRequestSize); - mCtx.resize(mC, mRequestSize); - - mBaseC = {}; - mBaseA = {}; - - // make the protocol as done and that - // mA,mC are ready to be consumed. - mState = State::Default; - - MC_END(); - } - - - - // internal. - task<> checkRT(Socket& chl) const - { - MC_BEGIN(task<>, this, &chl, - B = VecF{}, - sparseNoiseDelta = VecF{}, - baseB = VecF{}, - delta = VecF{}, - tempF = VecF{}, - tempG = VecG{}, - buffer = std::vector{} - ); - - // recv delta - buffer.resize(mCtx.byteSize()); - mCtx.resize(delta, 1); - MC_AWAIT(chl.recv(buffer)); - mCtx.deserialize(buffer.begin(), buffer.end(), delta.begin()); - - // recv B - buffer.resize(mCtx.byteSize() * mA.size()); - mCtx.resize(B, mA.size()); - MC_AWAIT(chl.recv(buffer)); - mCtx.deserialize(buffer.begin(), buffer.end(), B.begin()); - - // recv the noisy values. - buffer.resize(mCtx.byteSize() * mBaseA.size()); - mCtx.resize(baseB, mBaseA.size()); - MC_AWAIT(chl.recvResize(buffer)); - mCtx.deserialize(buffer.begin(), buffer.end(), baseB.begin()); - - // it shoudl hold that - // - // mBaseA = baseB + mBaseC * mDelta - // - // and - // - // mA = mB + mC * mDelta - // - { - bool verbose = false; - bool failed = false; - std::vector index(mS.size()); - std::iota(index.begin(), index.end(), 0); - std::sort(index.begin(), index.end(), - [&](std::size_t i, std::size_t j) { return mS[i] < mS[j]; }); - - mCtx.resize(tempF, 2); - mCtx.resize(tempG, 1); - mCtx.zero(tempG.begin(), tempG.end()); - - - // check the correlation that - // - // mBaseA + mBaseB = mBaseC * mDelta - for (auto i : rng(mBaseA.size())) - { - // temp[0] = baseB[i] + mBaseA[i] - mCtx.plus(tempF[0], baseB[i], mBaseA[i]); - - // temp[1] = mBaseC[i] * delta[0] - mCtx.mul(tempF[1], delta[0], mBaseC[i]); - - if (!mCtx.eq(tempF[0], tempF[1])) - throw RTE_LOC; - - if (i < mNumPartitions) - { - //auto idx = index[i]; - auto point = mS[i]; - if (!mCtx.eq(mBaseC[i], mC[point])) - throw RTE_LOC; - - if (i && mS[index[i - 1]] >= mS[index[i]]) - throw RTE_LOC; - } - } - - - auto iIter = index.begin(); - auto leafIdx = mS[*iIter]; - F act = tempF[0]; - G zero = tempG[0]; - mCtx.zero(tempG.begin(), tempG.end()); - - for (u64 j = 0; j < mA.size(); ++j) - { - mCtx.mul(act, delta[0], mC[j]); - mCtx.plus(act, act, B[j]); - - bool active = false; - if (j == leafIdx) - { - active = true; - } - else if (!mCtx.eq(zero, mC[j])) - throw RTE_LOC; - - if (mA[j] != act) - { - failed = true; - if (verbose) - std::cout << Color::Red; - } - - if (verbose) - { - std::cout << j << " act " << mCtx.str(act) - << " a " << mCtx.str(mA[j]) << " b " << mCtx.str(B[j]); - - if (active) - std::cout << " < " << mCtx.str(delta[0]); - - std::cout << std::endl << Color::Default; - } - - if (j == leafIdx) - { - ++iIter; - if (iIter != index.end()) - { - leafIdx = mS[*iIter]; - } - } - } - - if (failed) - throw RTE_LOC; - } - - MC_END(); - } - - std::array ferretMalCheck() - { - - block xx = mMalCheckSeed; - block sum0 = ZeroBlock; - block sum1 = ZeroBlock; - - - for (u64 i = 0; i < (u64)mA.size(); ++i) - { - block low, high; - xx.gf128Mul(mA[i], low, high); - sum0 = sum0 ^ low; - sum1 = sum1 ^ high; - //mySum = mySum ^ xx.gf128Mul(mA[i]); - - // xx = mMalCheckSeed^{i+1} - xx = xx.gf128Mul(mMalCheckSeed); - } - - // = < - block mySum = sum0.gf128Reduce(sum1); - - std::array myHash; - RandomOracle ro(32); - ro.Update(mySum ^ mBaseA.back()); - ro.Final(myHash); - return myHash; - } - - void clear() - { - mS = {}; - mA = {}; - mC = {}; - mGen.clear(); - } - }; -} -#endif \ No newline at end of file diff --git a/libOTe/Vole/Subfield/SilentVoleSender.h b/libOTe/Vole/Subfield/SilentVoleSender.h deleted file mode 100644 index dd2496b0..00000000 --- a/libOTe/Vole/Subfield/SilentVoleSender.h +++ /dev/null @@ -1,475 +0,0 @@ -#pragma once -// © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). - -#include -#ifdef ENABLE_SILENT_VOLE - -#include -#include -#include -#include -#include "libOTe/Tools/Subfield/SubfieldPprf.h" -#include -#include -#include -#include -#include -#include -#include -//#define NO_HASH - -namespace osuCrypto -{ - template< - typename F, - typename G = F, - typename Ctx = DefaultCoeffCtx - > - class SilentSubfieldVoleSender : public TimerAdapter - { - public: - static constexpr u64 mScaler = 2; - - static constexpr bool MaliciousSupported = - std::is_same_v&& - std::is_same_v; - - enum class State - { - Default, - Configured, - HasBase - }; - - using VecF = typename Ctx::template Vec; - using VecG = typename Ctx::template Vec; - - State mState = State::Default; - - // the context used to perform F, G operations - Ctx mCtx; - - // the pprf used to generate the noise vector. - SilentSubfieldPprfSender mGen; - - // the number of correlations requested. - u64 mRequestSize = 0; - - // the length of the noisy vector. - u64 mNoiseVecSize = 0; - - // the weight of the nosy vector - u64 mNumPartitions = 0; - - // the size of each regular, weight 1, subvector - // of the noisy vector. mNoiseVecSize = mNumPartions * mSizePer - u64 mSizePer = 0; - - // the lpn security parameters - u64 mSecParam = 0; - - // the type of base OT OT that should be performed. - // Base requires more work but less communication. - SilentBaseType mBaseType = SilentBaseType::BaseExtend; - - // the base Vole correlation. To generate the silent vole, - // we must first create a small vole - // mBaseA + mBaseB = mBaseC * mDelta. - // These will be used to initialize the non-zeros of the noisy - // vector. mBaseB is the b in this corrlations. - VecF mBaseB; - - // the full sized noisy vector. This will initalially be - // sparse with the corrlations - // mA = mB + mC * mDelta - // before it is compressed. - VecF mB; - - // determines if the malicious checks are performed. - SilentSecType mMalType = SilentSecType::SemiHonest; - - // A flag to specify the linear code to use - MultType mMultType = DefaultMultType; - - - block mDeltaShare; - -#ifdef ENABLE_SOFTSPOKEN_OT - SoftSpokenMalOtSender mOtExtSender; - SoftSpokenMalOtReceiver mOtExtRecver; -#endif - - - - - u64 baseVoleCount() const - { - return mNumPartitions + 1 * (mMalType == SilentSecType::Malicious); - } - - // Generate the silent base OTs. If the Iknp - // base OTs are set then we do an IKNP extend, - // otherwise we perform a base OT protocol to - // generate the needed OTs. - task<> genSilentBaseOts(PRNG& prng, Socket& chl, F delta) - { - using BaseOT = DefaultBaseOT; - - MC_BEGIN(task<>, this, delta, &prng, &chl, - msg = AlignedUnVector>(silentBaseOtCount()), - baseOt = BaseOT{}, - prng2 = std::move(PRNG{}), - xx = BitVector{}, - chl2 = Socket{}, - nv = NoisySubfieldVoleSender{}, - b = VecF{} - ); - setTimePoint("SilentVoleSender.genSilent.begin"); - - if (isConfigured() == false) - throw std::runtime_error("configure must be called first"); - - xx = mCtx.binaryDecomposition(delta); - - // compute the correlation for the noisy coordinates. - b.resize(baseVoleCount()); - - - if (mBaseType == SilentBaseType::BaseExtend) - { -#ifdef ENABLE_SOFTSPOKEN_OT - - if (mOtExtRecver.hasBaseOts() == false) - { - msg.resize(msg.size() + mOtExtRecver.baseOtCount()); - MC_AWAIT(mOtExtSender.send(msg, prng, chl)); - - mOtExtRecver.setBaseOts( - span>(msg).subspan( - msg.size() - mOtExtRecver.baseOtCount(), - mOtExtRecver.baseOtCount())); - msg.resize(msg.size() - mOtExtRecver.baseOtCount()); - - MC_AWAIT(nv.send(delta, b, prng, mOtExtRecver, chl)); - } - else - { - chl2 = chl.fork(); - prng2.SetSeed(prng.get()); - - MC_AWAIT( - macoro::when_all_ready( - nv.send(delta, b, prng2, mOtExtRecver, chl2), - mOtExtSender.send(msg, prng, chl))); - } -#else - -#endif - } - else - { - chl2 = chl.fork(); - prng2.SetSeed(prng.get()); - //MC_AWAIT(baseOt.send(msg, prng, chl)); - //MC_AWAIT(nv.send(delta, b, prng2, baseOt, chl2)); - MC_AWAIT( - macoro::when_all_ready( - nv.send(delta, b, prng2, baseOt, chl2), - baseOt.send(msg, prng, chl))); - } - - - setSilentBaseOts(msg, b); - setTimePoint("SilentVoleSender.genSilent.done"); - MC_END(); - } - - // configure the silent OT extension. This sets - // the parameters and figures out how many base OT - // will be needed. These can then be ganerated for - // a different OT extension or using a base OT protocol. - void configure( - u64 requestSize, - SilentBaseType type = SilentBaseType::BaseExtend, - u64 secParam = 128, - Ctx ctx = {}) - { - mCtx = std::move(ctx); - mSecParam = secParam; - mRequestSize = requestSize; - mState = State::Configured; - mBaseType = type; - double minDist = 0; - - switch (mMultType) - { - case osuCrypto::MultType::ExConv7x24: - case osuCrypto::MultType::ExConv21x24: - { - u64 _1, _2; - ExConvConfigure(mScaler, mMultType, _1, _2, minDist); - break; - } - case MultType::QuasiCyclic: - QuasiCyclicConfigure(mScaler, minDist); - break; - default: - throw RTE_LOC; - break; - } - - mNumPartitions = getRegNoiseWeight(minDist, secParam); - mSizePer = std::max(4, roundUpTo(divCeil(mRequestSize * mScaler, mNumPartitions), 2)); - mNoiseVecSize = mSizePer * mNumPartitions; - - mGen.configure(mSizePer, mNumPartitions); - } - - // return true if this instance has been configured. - bool isConfigured() const { return mState != State::Default; } - - // Returns how many base OTs the silent OT extension - // protocol will needs. - u64 silentBaseOtCount() const - { - if (isConfigured() == false) - throw std::runtime_error("configure must be called first"); - - return mGen.baseOtCount(); - } - - // Set the externally generated base OTs. This choice - // bits must be the one return by sampleBaseChoiceBits(...). - void setSilentBaseOts( - span> sendBaseOts, - const VecF& b) - { - if ((u64)sendBaseOts.size() != silentBaseOtCount()) - throw RTE_LOC; - - if (b.size() != baseVoleCount()) - throw RTE_LOC; - - mGen.setBase(sendBaseOts); - - // we store the negative of b. This is because - // we need the correlation - // - // mBaseA + mBaseB = mBaseC * delta - // - // for the pprf to expand correctly but the - // input correlation is a vole: - // - // mBaseA = b + mBaseC * delta - // - mCtx.resize(mBaseB, b.size()); - mCtx.zero(mBaseB.begin(), mBaseB.end()); - for (u64 i = 0; i < mBaseB.size(); ++i) - mCtx.minus(mBaseB[i], mBaseB[i], b[i]); - } - - // The native OT extension interface of silent - // OT. The receiver does not get to specify - // which OT message they receiver. Instead - // the protocol picks them at random. Use the - // send(...) interface for the normal behavior. - task<> silentSend( - F delta, - VecF& b, - PRNG& prng, - Socket& chl) - { - MC_BEGIN(task<>, this, delta, &b, &prng, &chl); - - MC_AWAIT(silentSendInplace(delta, b.size(), prng, chl)); - - mCtx.copy(mB.begin(), mB.begin() + b.size(), b.begin()); - //std::memcpy(b.data(), mB.data(), b.size() * mCtx.bytesF); - clear(); - - setTimePoint("SilentVoleSender.expand.ldpc.msgCpy"); - MC_END(); - } - - // The native OT extension interface of silent - // OT. The receiver does not get to specify - // which OT message they receiver. Instead - // the protocol picks them at random. Use the - // send(...) interface for the normal behavior. - task<> silentSendInplace( - F delta, - u64 n, - PRNG& prng, - Socket& chl) - { - MC_BEGIN(task<>, this, delta, n, &prng, &chl, - deltaShare = block{}, - X = block{}, - hash = std::array{}, - baseB = VecF{} - ); - setTimePoint("SilentVoleSender.ot.enter"); - - - if (isConfigured() == false) - { - // first generate 128 normal base OTs - configure(n, SilentBaseType::BaseExtend); - } - - if (mRequestSize != n) - throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); - - if (mGen.hasBaseOts() == false) - { - // recvs data - MC_AWAIT(genSilentBaseOts(prng, chl, delta)); - } - - setTimePoint("SilentVoleSender.start"); - //gTimer.setTimePoint("SilentVoleSender.iknp.base2"); - - // allocate B - mCtx.resize(mB, 0); - mCtx.resize(mB, mNoiseVecSize); - - if (mTimer) - mGen.setTimer(*mTimer); - - // extract just the first mNumPartitions value of mBaseB. - // the last is for the malicious check (if present). - mCtx.resize(baseB, mNumPartitions); - mCtx.copy(mBaseB.begin(), mBaseB.begin() + mNumPartitions, baseB.begin()); - - // program the output the PPRF to be secret shares of - // our secret share of delta * noiseVals. The receiver - // can then manually add their shares of this to the - // output of the PPRF at the correct locations. - MC_AWAIT(mGen.expand(chl, baseB, prng.get(), mB, - PprfOutputFormat::Interleaved, true, 1)); - setTimePoint("SilentVoleSender.expand.pprf"); - - if (mDebug) - { - MC_AWAIT(checkRT(chl, delta)); - setTimePoint("SilentVoleSender.expand.checkRT"); - } - - if (mMalType == SilentSecType::Malicious) - { - MC_AWAIT(chl.recv(X)); - - if constexpr (MaliciousSupported) - hash = ferretMalCheck(X); - else - throw std::runtime_error("malicious is currently only supported for GF128 block. " LOCATION); - - MC_AWAIT(chl.send(std::move(hash))); - } - - switch (mMultType) - { - case osuCrypto::MultType::ExConv7x24: - case osuCrypto::MultType::ExConv21x24: - { - ExConvCode2 encoder; - u64 expanderWeight, accumulatorWeight; - double _1; - ExConvConfigure(mScaler, mMultType, expanderWeight, accumulatorWeight, _1); - if (mScaler * mRequestSize > mNoiseVecSize) - throw RTE_LOC; - encoder.config(mRequestSize, mScaler * mRequestSize, expanderWeight, accumulatorWeight); - if (mTimer) - encoder.setTimer(getTimer()); - encoder.dualEncode(mB.begin()); - break; - } - case MultType::QuasiCyclic: - { -#ifdef ENABLE_BITPOLYMUL - if constexpr ( - std::is_same_v && - std::is_same_v && - std::is_same_v) - { - QuasiCyclicCode encoder; - encoder.init2(mRequestSize, mNoiseVecSize); - encoder.dualEncode(mB); - } - else - throw std::runtime_error("QuasiCyclic is only supported for GF128, i.e. block. " LOCATION); -#else - throw std::runtime_error("QuasiCyclic requires ENABLE_BITPOLYMUL = true. " LOCATION); -#endif - - break; - } - default: - throw std::runtime_error("Code is not supported. " LOCATION); - break; - } - - mCtx.resize(mB, mRequestSize); - - - mState = State::Default; - mBaseB.clear(); - - MC_END(); - } - - bool mDebug = false; - - task<> checkRT(Socket& chl, F delta) const - { - MC_BEGIN(task<>, this, &chl, delta); - MC_AWAIT(chl.send(delta)); - MC_AWAIT(chl.send(mB)); - MC_AWAIT(chl.send(mBaseB)); - MC_END(); - } - - std::array ferretMalCheck(block X) - { - - auto xx = X; - block sum0 = ZeroBlock; - block sum1 = ZeroBlock; - for (u64 i = 0; i < (u64)mB.size(); ++i) - { - block low, high; - xx.gf128Mul(mB[i], low, high); - sum0 = sum0 ^ low; - sum1 = sum1 ^ high; - - xx = xx.gf128Mul(X); - } - - block mySum = sum0.gf128Reduce(sum1); - - std::array myHash; - RandomOracle ro(32); - ro.Update(mySum ^ mBaseB.back()); - ro.Final(myHash); - - return myHash; - //chl.send(myHash); - } - - void clear() - { - mB = {}; - mGen.clear(); - } - }; - -} - -#endif \ No newline at end of file diff --git a/libOTe_Tests/EACode_Tests.cpp b/libOTe_Tests/EACode_Tests.cpp index 2aa06829..872ba6f9 100644 --- a/libOTe_Tests/EACode_Tests.cpp +++ b/libOTe_Tests/EACode_Tests.cpp @@ -1,6 +1,7 @@ #include "EACode_Tests.h" #include "libOTe/Tools/EACode/EACode.h" #include +#include "libOTe/Tools/CoeffCtx.h" namespace osuCrypto { @@ -18,76 +19,67 @@ namespace osuCrypto EACode code; code.config(k, n, bw); - auto A = code.getA(); - auto B = code.getB(); - auto G = B * A; + //auto A = code.getA(); + //auto B = code.getB(); + //auto G = B * A; std::vector m0(k), m1(k), c(n), c0(n), c1(n), a1(n); std::vector c2(n), m2(k); - if (v) - { - std::cout << "B\n" << B << std::endl << std::endl; - std::cout << "A'\n" << code.getAPar() << std::endl << std::endl; - std::cout << "A\n" << A << std::endl << std::endl; - std::cout << "G\n" << G << std::endl; + //if (v) + //{ + // std::cout << "B\n" << B << std::endl << std::endl; + // std::cout << "A'\n" << code.getAPar() << std::endl << std::endl; + // std::cout << "A\n" << A << std::endl << std::endl; + // std::cout << "G\n" << G << std::endl; - } + //} PRNG prng(ZeroBlock); prng.get(c0.data(), c0.size()); auto a0 = c0; - code.accumulate(a0); - A.multAdd(c0, a1); - //A.leftMultAdd(c0, c1); - if (a0 != a1) - { - if (v) - { + code.accumulate(a0, {}); - for (u64 i = 0; i < k; ++i) - std::cout << std::hex << std::setw(2) << std::setfill('0') << (a0[i]) << " "; - std::cout << "\n"; - for (u64 i = 0; i < k; ++i) - std::cout << std::hex << std::setw(2) << std::setfill('0') << (c1[i]) << " "; - std::cout << "\n"; - } + block sum = c0[0]; + for (u64 i = 0; i < a0.size(); ++i) + { + if (a0[i] != sum) + throw RTE_LOC; - throw RTE_LOC; + if(i+1 < a0.size()) + sum += c0[i + 1]; } - auto cc = c0; - B.multAdd(cc, m0); - code.expand(cc, m1); - - if (m0 != m1) - throw RTE_LOC; - - m0.resize(0); - m0.resize(k); - - G.multAdd(c0, m0); - - - cc = c0; - code.dualEncode(cc, m1); - - if (m0 != m1) - throw RTE_LOC; + u64 i = 0; + detail::ExpanderModd expanderCoeff(code.mSeed, code.mCodeSize); + auto main = k / 8 * 8; + for (; i < main; i += 8) + { + for (u64 j = 0; j < code.mExpanderWeight; ++j) + { + for (u64 p = 0; p < 8; ++p) + { + auto idx = expanderCoeff.get(); + m0[i + p] = m0[i + p] ^ a0[idx]; + } + } + } - cc = c0; - for (u64 i = 0; i < code.mCodeSize; ++i) - c2[i] = c0[i].get(0); + for (; i < k; ++i) + { + for (u64 j = 0; j < code.mExpanderWeight; ++j) + { + auto idx = expanderCoeff.get(); + m0[i] = m0[i] ^ a0[idx]; + } + } - code.dualEncode2(cc, m1, c2, m2); + code.dualEncode(c0, m1, {}); if (m0 != m1) throw RTE_LOC; - for (u64 i = 0; i < code.mMessageSize; ++i) - m2[i] = m0[i].get(0); - } } \ No newline at end of file diff --git a/libOTe_Tests/ExConvCode_Tests.cpp b/libOTe_Tests/ExConvCode_Tests.cpp index fa1785a7..0a35e1d7 100644 --- a/libOTe_Tests/ExConvCode_Tests.cpp +++ b/libOTe_Tests/ExConvCode_Tests.cpp @@ -1,8 +1,8 @@ #include "ExConvCode_Tests.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" -#include "libOTe/Tools/ExConvCode/ExConvCode2.h" +#include "libOTe/Tools/ExConvCode/ExConvCode.h" #include -#include "libOTe/Tools/Subfield/Subfield.h" +#include "libOTe/Tools/CoeffCtx.h" namespace osuCrypto { @@ -49,7 +49,7 @@ namespace osuCrypto void exConvTest(u64 k, u64 n, u64 bw, u64 aw, bool sys) { - ExConvCode2 code; + ExConvCode code; code.config(k, n, bw, aw, sys); auto accOffset = sys * k; @@ -60,23 +60,23 @@ namespace osuCrypto { x1[i] = x2[i] = x3[i] = prng.get(); } - + CoeffCtx ctx; std::vector rand(divCeil(aw, 8)); for (i64 i = 0; i < x1.size() - aw - 1; ++i) { prng.get(rand.data(), rand.size()); - code.accOne(x1.begin() + i, x1.end(), rand.data(), std::integral_constant{}); + code.accOne(x1.begin() + i, x1.end(), rand.data(), ctx, std::integral_constant{}); if (aw == 16) - code.accOne(x2.begin() + i, x2.end(), rand.data(), std::integral_constant{}); + code.accOne(x2.begin() + i, x2.end(), rand.data(), ctx, std::integral_constant{}); - CoeffCtx::plus(x3[i + 1], x3[i + 1], x3[i]); + ctx.plus(x3[i + 1], x3[i + 1], x3[i]); for (u64 j = 0; j < aw && (i + j + 2) < x3.size(); ++j) { if (*BitIterator(rand.data(), j)) { - CoeffCtx::plus(x3[i + j + 2], x3[i + j + 2], x3[i]); + ctx.plus(x3[i + j + 2], x3[i + j + 2], x3[i]); } } @@ -100,11 +100,11 @@ namespace osuCrypto x4 = x1; //std::cout << std::endl; - code.accumulateFixed(x1.begin() + accOffset); + code.accumulateFixed(x1.begin() + accOffset, ctx); if (aw == 16) { - code.accumulateFixed(x2.begin() + accOffset); + code.accumulateFixed(x2.begin() + accOffset, ctx); if (x1 != x2) { @@ -128,7 +128,7 @@ namespace osuCrypto if (mtxCoeffIter > mtxCoeffEnd) { // generate more mtx coefficients - ExConvCode2::refill(coeffGen); + ExConvCode::refill(coeffGen); mtxCoeffIter = (u8*)coeffGen.mBuffer.data(); } @@ -136,14 +136,14 @@ namespace osuCrypto auto xj = xi + 1; if (xj != end) { - CoeffCtx::plus(*xj, *xj, *xi); + ctx.plus(*xj, *xj, *xi); ++xj; } for (u64 j = 0; j < aw && xj != end; ++j, ++xj) { if (*BitIterator(mtxCoeffIter, j)) { - CoeffCtx::plus(*xj, *xj, *xi); + ctx.plus(*xj, *xj, *xi); } } ++mtxCoeffIter; @@ -190,7 +190,7 @@ namespace osuCrypto for (u64 p = 0; p < 8; ++p) { auto idx = expanderCoeff.get(); - CoeffCtx::plus(y2[i + p], y2[i + p], x1[idx + accOffset]); + ctx.plus(y2[i + p], y2[i + p], x1[idx + accOffset]); } } } @@ -200,14 +200,14 @@ namespace osuCrypto for (u64 j = 0; j < code.mExpander.mExpanderWeight; ++j) { auto idx = expanderCoeff.get(); - CoeffCtx::plus(y2[i], y2[i], x1[idx + accOffset]); + ctx.plus(y2[i], y2[i], x1[idx + accOffset]); } } if (y1 != y2) throw RTE_LOC; - code.dualEncode(x4.begin()); + code.dualEncode(x4.begin(), {}); x4.resize(k); if (x4 != y1) @@ -224,7 +224,7 @@ namespace osuCrypto //std::vector i1, o1; //std::vector i2, o2; - //ExpanderCode2 ex; + //ExpanderCode ex; //ex.expandMany( // std::tuple{ // std::pair{i0.begin(), o0.begin()}, diff --git a/libOTe_Tests/Pprf_Tests.cpp b/libOTe_Tests/Pprf_Tests.cpp index 6d4a5a62..2e95de2e 100644 --- a/libOTe_Tests/Pprf_Tests.cpp +++ b/libOTe_Tests/Pprf_Tests.cpp @@ -1,6 +1,6 @@ #include "Pprf_Tests.h" -#include "libOTe/Tools/Subfield/SubfieldPprf.h" +#include "libOTe/Tools/Pprf/RegularPprf.h" #include "cryptoTools/Common/Log.h" #include "Common.h" #include @@ -17,8 +17,8 @@ void Tools_Pprf_expandOne_test_impl(u64 domain, bool program) PRNG prng(CCBlock); auto format = PprfOutputFormat::Interleaved; - SilentSubfieldPprfSender sender; - SilentSubfieldPprfReceiver recver; + RegularPprfSender sender; + RegularPprfReceiver recver; sender.configure(domain, pntCount); recver.configure(domain, pntCount); @@ -40,40 +40,38 @@ void Tools_Pprf_expandOne_test_impl(u64 domain, bool program) sender.setBase(sendOTs); recver.setBase(recvOTs); - std::vector points(8); - recver.getPoints(points, PprfOutputFormat::ByLeafIndex); block seed = CCBlock; - auto sTree = span>{}; auto sLevels = std::vector>>{}; - auto rTree = span>{}; auto rLevels = std::vector>>{}; auto sBuff = std::vector{}; - auto sSums = span, 2>>{}; + auto sSums = span>{}; auto sLast = span{}; - TreeAllocator mTreeAlloc; + pprf::TreeAllocator mTreeAlloc; sLevels.resize(depth); rLevels.resize(depth); mTreeAlloc.reserve(2, (1ull << depth) + 2); - allocateExpandTree(depth, mTreeAlloc, sTree, sLevels); - allocateExpandTree(depth, mTreeAlloc, rTree, rLevels); + + + pprf::allocateExpandTree(mTreeAlloc, sLevels); + pprf::allocateExpandTree(mTreeAlloc, rLevels); Ctx::Vec sLeafLevel(8ull << depth); Ctx::Vec rLeafLevel(8ull << depth); u64 leafOffset = 0; Ctx ctx; - allocateExpandBuffer(depth - 1, program, sBuff, sSums, sLast, ctx); + pprf::allocateExpandBuffer(depth - 1, pntCount, program, sBuff, sSums, sLast, ctx); - recver.mPoints.resize(roundUpTo(recver.mPntCount, 8)); - recver.getPoints(recver.mPoints, PprfOutputFormat::ByLeafIndex); + std::vector points(recver.mPntCount); + recver.getPoints(points, PprfOutputFormat::ByLeafIndex); sender.expandOne(seed, 0, program, sLevels, sLeafLevel, leafOffset, sSums, sLast, ctx); - recver.expandOne(0, program, rLevels, rLeafLevel, leafOffset, sSums, sLast, ctx); + recver.expandOne(0, program, rLevels, rLeafLevel, leafOffset, sSums, sLast, points, ctx); bool failed = false; for (u64 i = 0; i < pntCount; ++i) @@ -125,10 +123,10 @@ void Tools_Pprf_expandOne_test_impl(u64 domain, bool program) if (j == leafIdx) { F exp; - Ctx::plus(exp, sLeaves(j, i), value); + ctx.plus(exp, sLeaves(j, i), value); if (program && exp != rLeaves(j, i)) { - std::cout << i << " exp " << Ctx::str(exp) << " " << Ctx::str(rLeaves(j, i)) << std::endl; + std::cout << i << " exp " << ctx.str(exp) << " " << ctx.str(rLeaves(j, i)) << std::endl; throw RTE_LOC; } } @@ -149,7 +147,7 @@ void Tools_Pprf_expandOne_test(const oc::CLP& cmd) #if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - for (u64 domain : { 4, 128, 4522}) for (bool program : {true, false}) + for (u64 domain : { 2, 128, 4522}) for (bool program : {true, false}) { Tools_Pprf_expandOne_test_impl(domain, program); @@ -180,13 +178,14 @@ void Tools_Pprf_test_impl( auto sockets = cp::LocalAsyncSocket::makePair(); - SilentSubfieldPprfSender sender; - SilentSubfieldPprfReceiver recver; + RegularPprfSender sender; + RegularPprfReceiver recver; Vec delta; + Ctx ctx; auto seed = prng.get(); - Ctx::resize(delta, numPoints * program); + ctx.resize(delta, numPoints * program); for (u64 i = 0; i < delta.size(); ++i) - Ctx::fromBlock(delta[i], seed); + ctx.fromBlock(delta[i], seed); sender.configure(domain, numPoints); recver.configure(domain, numPoints); @@ -270,9 +269,9 @@ void Tools_Pprf_test_impl( if (points[j] == i) { if (program) - Ctx::plus(exp, b[idx], delta[j]); + ctx.plus(exp, b[idx], delta[j]); else - Ctx::zero(&exp, &exp + 1); + ctx.zero(&exp, &exp + 1); } else exp = b[idx]; @@ -286,7 +285,7 @@ void Tools_Pprf_test_impl( } if (verbose) { - std::cout << "r[" << j << "][" << i << "] " << exp << " " << Ctx::str(a[idx]); + std::cout << "r[" << j << "][" << i << "] " << exp << " " << ctx.str(a[idx]); if (points[j] == i) std::cout << " < "; @@ -315,7 +314,7 @@ void Tools_Pprf_test_impl( auto iIter = index.begin(); auto leafIdx = points[*iIter]; F deltaVal; - Ctx::zero(&deltaVal, &deltaVal + 1); + ctx.zero(&deltaVal, &deltaVal + 1); if(program) deltaVal = delta[*iIter]; @@ -328,16 +327,16 @@ void Tools_Pprf_test_impl( // act = a - b // = point * delta - Ctx::minus(act, a[j], b[j]); - Ctx::zero(&exp, &exp + 1); + ctx.minus(act, a[j], b[j]); + ctx.zero(&exp, &exp + 1); bool active = false; if (j == leafIdx) { active = true; if (program) - Ctx::copy(exp, deltaVal); + ctx.copy(exp, deltaVal); else - Ctx::minus(exp, exp, b[j]); + ctx.minus(exp, exp, b[j]); } if (exp != act) @@ -349,8 +348,8 @@ void Tools_Pprf_test_impl( if (verbose) { - std::cout << j << " exp " << Ctx::str(exp) << " " << Ctx::str(act) - << " a " << Ctx::str(a[j]) << " b " << Ctx::str(b[j]); + std::cout << j << " exp " << ctx.str(exp) << " " << ctx.str(act) + << " a " << ctx.str(a[j]) << " b " << ctx.str(b[j]); if (active) std::cout << " < " << deltaVal; @@ -419,7 +418,7 @@ void Tools_Pprf_ByTreeIndex_test(const oc::CLP& cmd) auto f = PprfOutputFormat::ByTreeIndex; auto v = cmd.isSet("v"); - for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true/*, false*/ }) + for (auto d : { 32,3242 }) for (auto n : { 8, 19}) for (auto p : { true/*, false*/ }) { Tools_Pprf_test_impl(d, n, p, f, v); Tools_Pprf_test_impl(d, n, p, f, v); diff --git a/libOTe_Tests/SilentOT_Tests.cpp b/libOTe_Tests/SilentOT_Tests.cpp index 3262dd78..e95bba79 100644 --- a/libOTe_Tests/SilentOT_Tests.cpp +++ b/libOTe_Tests/SilentOT_Tests.cpp @@ -1,6 +1,5 @@ #include "SilentOT_Tests.h" -#include "libOTe/Tools/SilentPprf.h" #include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" #include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" #include diff --git a/libOTe_Tests/Subfield_Test.h b/libOTe_Tests/Subfield_Test.h deleted file mode 100644 index 2aa90901..00000000 --- a/libOTe_Tests/Subfield_Test.h +++ /dev/null @@ -1,12 +0,0 @@ -#include "cryptoTools/Common/CLP.h" - - -namespace osuCrypto -{ - - - void Subfield_Tools_Pprf_test(const oc::CLP& cmd); - void Subfield_Noisy_Vole_test(const oc::CLP& cmd); - void Subfield_Silent_Vole_test(const oc::CLP& cmd); - -} \ No newline at end of file diff --git a/libOTe_Tests/Subfield_Tests.cpp b/libOTe_Tests/Subfield_Tests.cpp deleted file mode 100644 index 3a7ea0c0..00000000 --- a/libOTe_Tests/Subfield_Tests.cpp +++ /dev/null @@ -1,224 +0,0 @@ -#include "Subfield_Test.h" -#include "libOTe/Tools/Subfield/Subfield.h" -#include "libOTe/Tools/ExConvCode/ExConvCode2.h" -#include "libOTe/Vole/Subfield/NoisyVoleSender.h" -#include "libOTe/Vole/Subfield/NoisyVoleReceiver.h" -#include "libOTe/Vole/Subfield/SilentVoleSender.h" -#include "libOTe/Vole/Subfield/SilentVoleReceiver.h" - -#include "Common.h" - -namespace osuCrypto -{ - static_assert(std::is_trivially_copyable_v>); - static_assert(std::is_trivially_copyable_v); - - using tests_libOTe::eval; - - template - void subfield_vole_test(u64 n) - { - PRNG prng(CCBlock); - - F delta = prng.get(); - std::vector c(n); - std::vector a(n), b(n); - prng.get(c.data(), c.size()); - - NoisySubfieldVoleReceiver recv; - NoisySubfieldVoleSender send; - - auto chls = cp::LocalAsyncSocket::makePair(); - - BitVector recvChoice = Trait::binaryDecomposition(delta); - std::vector otRecvMsg(recvChoice.size()); - std::vector> otSendMsg(recvChoice.size()); - prng.get>(otSendMsg); - for (u64 i = 0; i < recvChoice.size(); ++i) - otRecvMsg[i] = otSendMsg[i][recvChoice[i]]; - - // compute a,b such that - // - // a = b + c * delta - // - auto p0 = recv.receive(c, a, prng, otSendMsg, chls[0]); - auto p1 = send.send(delta, b, prng, otRecvMsg, chls[1]); - - eval(p0, p1); - - for (u64 i = 0; i < n; ++i) - { - F prod, sum; - - Trait::mul(prod, delta, c[i]); - Trait::minus(sum, a[i], b[i]); - - if (prod != sum) - { - throw RTE_LOC; - } - } - } - - void Subfield_Noisy_Vole_test(const oc::CLP& cmd) - { - - for (u64 n : {1, 8, 433}) - { - subfield_vole_test(n); - subfield_vole_test(n); - subfield_vole_test(n); - subfield_vole_test, u32, CoeffCtxArray>(n); - } - } - - void Subfield_Silent_Vole_test(const oc::CLP& cmd) { - using namespace oc; -#if defined(ENABLE_SILENTOT) - Timer timer; - timer.setTimePoint("start"); - u64 n = cmd.getOr("n", 102043); - u64 nt = cmd.getOr("nt", std::thread::hardware_concurrency()); - block seed = block(0, cmd.getOr("seed", 0)); - - //{ - // PRNG prng(seed); - // u64 x = prng.get(); - // std::vector c(n), z0(n), z1(n); - - // SilentSubfieldVoleReceiver recv; - // SilentSubfieldVoleSender send; - - // recv.mMultType = MultType::ExConv7x24; - // send.mMultType = MultType::ExConv7x24; - - // recv.setTimer(timer); - // send.setTimer(timer); - - // // recv.mDebug = true; - // // send.mDebug = true; - - // auto chls = cp::LocalAsyncSocket::makePair(); - - // timer.setTimePoint("net"); - - // timer.setTimePoint("ot"); - // // fakeBase(n, nt, prng, delta, recv, send); - - // auto p0 = send.silentSend(x, span(z0), prng, chls[0]); - // auto p1 = recv.silentReceive(span(c), span(z1), prng, chls[1]); - - // eval(p0, p1); - // timer.setTimePoint("send"); - // for (u64 i = 0; i < n; ++i) { - // u64 left = c[i] * x; - // u64 right = z1[i] - z0[i]; - // if (left != right) { - // std::cout << "bad " << i << "\n c[i] " << c[i] << " * x " << x << " = " << left << std::endl; - // std::cout << "z0[i] " << z0[i] << " - z1 " << z1[i] << " = " << right << std::endl; - // throw RTE_LOC; - // } - // } - //} - - //{ - // PRNG prng(seed); - // constexpr size_t N = 10; - // using G = u32; - // using F = std::array; - // using CoeffCtx = CoeffCtxArray; - // F x; - // CoeffCtx::fromBlock(x, prng.get()); - // std::vector c(n); - // std::vector a(n), b(n); - - // SilentSubfieldVoleReceiver recv; - // SilentSubfieldVoleSender send; - - // recv.mMultType = MultType::ExConv7x24; - // send.mMultType = MultType::ExConv7x24; - - // recv.setTimer(timer); - // send.setTimer(timer); - - // // recv.mDebug = true; - // // send.mDebug = true; - - // auto chls = cp::LocalAsyncSocket::makePair(); - - // timer.setTimePoint("net"); - - // timer.setTimePoint("ot"); - // // fakeBase(n, nt, prng, delta, recv, send); - - // auto p0 = send.silentSend(x, span(b), prng, chls[0]); - // auto p1 = recv.silentReceive(span(c), span(a), prng, chls[1]); - - // eval(p0, p1); - // // std::cout << "transferred " << (chls[0].bytesSent() + chls[0].bytesReceived()) << std::endl; - // timer.setTimePoint("verify"); - - // timer.setTimePoint("send"); - // for (u64 i = 0; i < n; i++) { - // for (u64 j = 0; j < N; j++) { - // throw RTE_LOC;// fix this - // // c = a delta + b - // // c - b = a delta - // //G left = a[i] * delta[j]; - // //G right = c[i][j] - b[i][j]; - // //if (left != right) { - // // std::cout << "bad " << i << "\n a[i] " << a[i] << " * delta[j] " << delta[j] << " = " << left << std::endl; - // // std::cout << "c[i][j] " << c[i][j] << " - b " << b[i][j] << " = " << right << std::endl; - // // throw RTE_LOC; - // //} - // } - // } - //} - - //{ - // PRNG prng(seed); - // block x = prng.get(); - // std::vector c(n), z0(n), z1(n); - - // SilentSubfieldVoleReceiver recv; - // SilentSubfieldVoleSender send; - - // recv.mMultType = MultType::ExConv7x24; - // send.mMultType = MultType::ExConv7x24; - - // recv.setTimer(timer); - // send.setTimer(timer); - - // // recv.mDebug = true; - // // send.mDebug = true; - - // auto chls = cp::LocalAsyncSocket::makePair(); - - // timer.setTimePoint("net"); - - // timer.setTimePoint("ot"); - // // fakeBase(n, nt, prng, delta, recv, send); - - // auto p0 = send.silentSend(x, span(z0), prng, chls[0]); - // auto p1 = recv.silentReceive(span(c), span(z1), prng, chls[1]); - - // eval(p0, p1); - // timer.setTimePoint("send"); - // for (u64 i = 0; i < n; ++i) { - // block left = x.gf128Mul(c[i]); - // block right = z1[i] ^ z0[i]; - // if (left != right) { - // std::cout << "bad " << i << "\n c[i] " << c[i] << " * x " << x << " = " << left << std::endl; - // std::cout << "z0[i] " << z0[i] << " - z1 " << z1[i] << " = " << right << std::endl; - // throw RTE_LOC; - // } - // } - //} - //timer.setTimePoint("done"); - // std::cout << timer << std::endl; -#else - throw UnitTestSkipped("not defined." LOCATION); -#endif - } - -} \ No newline at end of file diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 3864bd9b..f2f1bb50 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -14,8 +14,8 @@ #include "libOTe_Tests/ExConvCode_Tests.h" #include "libOTe_Tests/EACode_Tests.h" #include "libOTe/Tools/LDPC/Mtx.h" -#include "libOTe_Tests/Subfield_Test.h" #include "libOTe_Tests/Pprf_Tests.h" + using namespace osuCrypto; namespace tests_libOTe { @@ -92,7 +92,6 @@ namespace tests_libOTe tc.add("OtExt_SoftSpokenMalicious21_Split_Test ", OtExt_SoftSpokenMalicious21_Split_Test); tc.add("DotExt_SoftSpokenMaliciousLeaky_Test ", DotExt_SoftSpokenMaliciousLeaky_Test); - tc.add("Subfield_Silent_Vole_test ", Subfield_Silent_Vole_test); tc.add("Vole_Noisy_test ", Vole_Noisy_test); tc.add("Vole_Silent_QuasiCyclic_test ", Vole_Silent_QuasiCyclic_test); tc.add("Vole_Silent_paramSweep_test ", Vole_Silent_paramSweep_test); diff --git a/libOTe_Tests/Vole_Tests.cpp b/libOTe_Tests/Vole_Tests.cpp index 155b2f9b..6d38d06e 100644 --- a/libOTe_Tests/Vole_Tests.cpp +++ b/libOTe_Tests/Vole_Tests.cpp @@ -15,13 +15,11 @@ using namespace oc; #include -#include "libOTe/Tools/Subfield/Subfield.h" -#include "libOTe/Vole/Subfield/SilentVoleSender.h" -#include "libOTe/Vole/Subfield/SilentVoleReceiver.h" +#include "libOTe/Tools/CoeffCtx.h" using namespace tests_libOTe; -template +template void Vole_Noisy_test_impl(u64 n) { PRNG prng(CCBlock); @@ -31,12 +29,13 @@ void Vole_Noisy_test_impl(u64 n) std::vector a(n), b(n); prng.get(c.data(), c.size()); - NoisySubfieldVoleReceiver recv; - NoisySubfieldVoleSender send; + NoisyVoleReceiver recv; + NoisyVoleSender send; auto chls = cp::LocalAsyncSocket::makePair(); - BitVector recvChoice = Trait::binaryDecomposition(delta); + Ctx ctx; + BitVector recvChoice = ctx.binaryDecomposition(delta); std::vector otRecvMsg(recvChoice.size()); std::vector> otSendMsg(recvChoice.size()); prng.get>(otSendMsg); @@ -47,8 +46,8 @@ void Vole_Noisy_test_impl(u64 n) // // a = b + c * delta // - auto p0 = recv.receive(c, a, prng, otSendMsg, chls[0]); - auto p1 = send.send(delta, b, prng, otRecvMsg, chls[1]); + auto p0 = recv.receive(c, a, prng, otSendMsg, chls[0], ctx); + auto p1 = send.send(delta, b, prng, otRecvMsg, chls[1], ctx); eval(p0, p1); @@ -56,8 +55,8 @@ void Vole_Noisy_test_impl(u64 n) { F prod, sum; - Trait::mul(prod, delta, c[i]); - Trait::minus(sum, a[i], b[i]); + ctx.mul(prod, delta, c[i]); + ctx.minus(sum, a[i], b[i]); if (prod != sum) { From 158d2d14a410e9b1ba09d37ac5972f938e10011d Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Sun, 21 Jan 2024 23:55:20 -0800 Subject: [PATCH 14/23] expander opt fix --- libOTe/Tools/ExConvCode/Expander.h | 10 +++++----- libOTe_Tests/EACode_Tests.cpp | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/libOTe/Tools/ExConvCode/Expander.h b/libOTe/Tools/ExConvCode/Expander.h index c952b044..55558627 100644 --- a/libOTe/Tools/ExConvCode/Expander.h +++ b/libOTe/Tools/ExConvCode/Expander.h @@ -103,11 +103,11 @@ namespace osuCrypto auto main = mMessageSize / 8 * 8; u64 i = 0; - for (; i < main; i += 8, output += 8) + for (; i < main; i += 8, rOutput+= 8) { if constexpr (Add == false) { - ctx.zero(output, output + 8); + ctx.zero(rOutput, rOutput + 8); } for (auto j = 0ull; j < mExpanderWeight; ++j) @@ -135,14 +135,14 @@ namespace osuCrypto if constexpr (Add == false) { - ctx.zero(output, output + (mMessageSize - i)); + ctx.zero(rOutput, rOutput + (mMessageSize - i)); } - for (; i < mMessageSize; ++i, ++output) + for (; i < mMessageSize; ++i, ++rOutput) { for (auto j = 0ull; j < mExpanderWeight; ++j) { - ctx.plus(*output, *output, *(input + prng.get())); + ctx.plus(*rOutput, *rOutput, *(input + prng.get())); } } } diff --git a/libOTe_Tests/EACode_Tests.cpp b/libOTe_Tests/EACode_Tests.cpp index 872ba6f9..1737e735 100644 --- a/libOTe_Tests/EACode_Tests.cpp +++ b/libOTe_Tests/EACode_Tests.cpp @@ -48,7 +48,7 @@ namespace osuCrypto throw RTE_LOC; if(i+1 < a0.size()) - sum += c0[i + 1]; + sum ^= c0[i + 1]; } u64 i = 0; From 58be0942c3ab120c4c8cb3eb42b973ac96e50bb9 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Sun, 21 Jan 2024 23:56:18 -0800 Subject: [PATCH 15/23] cryptoTools bump --- cryptoTools | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cryptoTools b/cryptoTools index 5f90354b..0fe05285 160000 --- a/cryptoTools +++ b/cryptoTools @@ -1 +1 @@ -Subproject commit 5f90354b499adddbcf6861a3b4463e0724e5f719 +Subproject commit 0fe05285f4f22d520a31ed226ea757d7c3dac49c From bc74c3589b87a4483ca22d29a7075150c0f1d462 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Mon, 22 Jan 2024 11:49:11 -0800 Subject: [PATCH 16/23] gf2 --- CMakePresets.json | 2 +- frontend/benchmark.h | 13 +- libOTe/Tools/CoeffCtx.h | 175 ++++++++++++------ libOTe/Tools/EACode/EACode.h | 4 + libOTe/Tools/ExConvCode/ExConvCode.h | 12 +- .../Silent/SilentOtExtReceiver.cpp | 10 +- .../TwoChooseOne/Silent/SilentOtExtReceiver.h | 2 +- .../TwoChooseOne/Silent/SilentOtExtSender.cpp | 6 +- .../TwoChooseOne/Silent/SilentOtExtSender.h | 2 +- libOTe/Vole/Silent/SilentVoleReceiver.h | 4 +- libOTe/Vole/Silent/SilentVoleSender.h | 4 +- libOTe/Vole/SoftSpokenOT/SmallFieldVole.h | 4 +- libOTe_Tests/EACode_Tests.cpp | 11 +- libOTe_Tests/ExConvCode_Tests.cpp | 87 +++++---- libOTe_Tests/Pprf_Tests.cpp | 2 +- libOTe_Tests/Vole_Tests.cpp | 15 +- 16 files changed, 223 insertions(+), 130 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index 69e16ea3..20225f0a 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -59,7 +59,7 @@ "VERBOSE_FETCH": true, "ENABLE_SSE": true, "ENABLE_AVX": true, - "ENABLE_ASAN": true, + "ENABLE_ASAN": false, "COPROTO_ENABLE_BOOST": true, "CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}", "CMAKE_PREFIX_PATH": "${sourceDir}/../out/install/${presetName}" diff --git a/frontend/benchmark.h b/frontend/benchmark.h index bd50c7b2..b67327cc 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -99,7 +99,7 @@ namespace osuCrypto timer.setTimePoint("_____________________"); for (u64 i = 0; i < trials; ++i) { - code.dualEncode(x, y, {}); + code.dualEncode(x, y, {}); timer.setTimePoint("encode"); } @@ -131,6 +131,8 @@ namespace osuCrypto // size for the accumulator (# random transitions) u64 a = cmd.getOr("a", roundUpTo(log2ceil(n), 8)); + bool gf128 = cmd.isSet("gf128"); + // verbose flag. bool v = cmd.isSet("v"); bool sys = cmd.isSet("sys"); @@ -154,7 +156,10 @@ namespace osuCrypto timer.setTimePoint("_____________________"); for (u64 i = 0; i < trials; ++i) { - code.dualEncode(x.begin(), {}); + if(gf128) + code.dualEncode(x.begin(), {}); + else + code.dualEncode(x.begin(), {}); timer.setTimePoint("encode"); } @@ -385,8 +390,8 @@ namespace osuCrypto try { - SilentSubfieldVoleSender sender; - SilentSubfieldVoleReceiver recver; + SilentSubfieldVoleSender sender; + SilentSubfieldVoleReceiver recver; u64 trials = cmd.getOr("t", 10); diff --git a/libOTe/Tools/CoeffCtx.h b/libOTe/Tools/CoeffCtx.h index 5de952ef..a0fc481e 100644 --- a/libOTe/Tools/CoeffCtx.h +++ b/libOTe/Tools/CoeffCtx.h @@ -8,44 +8,59 @@ namespace osuCrypto { /* * Primitive CoeffCtx for integers-like types + * + * This class implements the required functions to perform a vole + * + * The core functions are plus, minus, mul. However, additional function + * and types are required. */ struct CoeffCtxInteger { template - OC_FORCEINLINE void plus(R&& ret, F1&& lhs, F2&& rhs) { + OC_FORCEINLINE void plus(R&& ret, F1&& lhs, F2&& rhs) { ret = lhs + rhs; } template - OC_FORCEINLINE void minus(R&& ret, F1&& lhs, F2&& rhs) { + OC_FORCEINLINE void minus(R&& ret, F1&& lhs, F2&& rhs) { ret = lhs - rhs; } template - OC_FORCEINLINE void mul(R&& ret, F1&& lhs, F2&& rhs) { + OC_FORCEINLINE void mul(R&& ret, F1&& lhs, F2&& rhs) { ret = lhs * rhs; } template - OC_FORCEINLINE bool eq(F&& lhs, F&& rhs) { + OC_FORCEINLINE bool eq(F&& lhs, F&& rhs) { return lhs == rhs; } - // is F a field? - template - OC_FORCEINLINE bool isField() { + // is G a field? + template + OC_FORCEINLINE bool isField() { return false; // default. } - - - - + // For the base field G is an extension fields, + // mulConst should multiply x by some constant in G to linearly + // mix the components. Most of the LPN codes this library + // uses are binary and so for extension field this would + // result in the componets not interactive. This can lead + // to a splitting attack. To fix this we multiply by some + // non-zero G element. + // + // If your type is a scaler, e.g. Fp or Z2k, just return x. + template + OC_FORCEINLINE void mulConst(F& ret, const F& x) + { + ret = x; + } // the bit size require to prepresent F // the protocol will perform binary decomposition // of F using this many bits template - u64 bitSize() + u64 bitSize() { return sizeof(F) * 8; } @@ -56,15 +71,15 @@ namespace osuCrypto { // x = sum_{i = 0,...,n} 2^i * binaryDecomposition(x)[i] // template - OC_FORCEINLINE BitVector binaryDecomposition(F& x) { - static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + OC_FORCEINLINE BitVector binaryDecomposition(F& x) { + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); return { (u8*)&x, sizeof(F) * 8 }; } - // sample an F using the randomness b. + // derive an F using the randomness b. template - OC_FORCEINLINE void fromBlock(F& ret, const block& b) { - static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + OC_FORCEINLINE void fromBlock(F& ret, const block& b) { + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); if constexpr (sizeof(F) <= sizeof(block)) { @@ -84,8 +99,8 @@ namespace osuCrypto { // return the F element with value 2^power template - OC_FORCEINLINE void powerOfTwo(F& ret, u64 power) { - static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + OC_FORCEINLINE void powerOfTwo(F& ret, u64 power) { + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); memset(&ret, 0, sizeof(F)); *BitIterator((u8*)&ret, power) = 1; } @@ -103,21 +118,21 @@ namespace osuCrypto { // resize Vec template - void resize(VecF& f, u64 size) + void resize(VecF& f, u64 size) { f.resize(size); } // the size of F when serialized. template - u64 byteSize() + u64 byteSize() { return sizeof(F); } // copy a single F element. template - OC_FORCEINLINE void copy(F& dst, const F& src) + OC_FORCEINLINE void copy(F& dst, const F& src) { dst = src; } @@ -125,7 +140,7 @@ namespace osuCrypto { // copy [begin,...,end) into [dstBegin, ...) // the iterators will point to the same types, i.e. F template - OC_FORCEINLINE void copy( + OC_FORCEINLINE void copy( SrcIter begin, SrcIter end, DstIter dstBegin) @@ -139,10 +154,10 @@ namespace osuCrypto { } // deserialize [begin,...,end) into [dstBegin, ...) - // begin will be a byte pointer/iterator. + // begin will be a u8 pointer/iterator. // dstBegin will be an F pointer/iterator template - void deserialize(SrcIter&& begin, SrcIter&& end, DstIter&& dstBegin) + void deserialize(SrcIter&& begin, SrcIter&& end, DstIter&& dstBegin) { // as written this function is a bit more general than strictly neccessary // due to serialize(...) redirecting here. @@ -151,12 +166,6 @@ namespace osuCrypto { static_assert(std::is_trivially_copyable::value, "source serialization types must be trivially_copyable."); static_assert(std::is_trivially_copyable::value, "destination serialization types must be trivially_copyable."); -#if __cplusplus >= 202002L - //std::contiguous_iterator<> - // _assert contigous iter in cpp20 -#endif - - // how many source elem do we have? auto srcN = std::distance(begin, end); if (srcN) @@ -177,8 +186,8 @@ namespace osuCrypto { auto beginU8 = (u8*)&*begin; auto dstBeginU8 = (u8*)&*dstBegin; - auto dstBackPtr = dstBeginU8 + (n -sizeof(DstType)); - auto dstBackIter = dstBegin + (dstN -1); + auto dstBackPtr = dstBeginU8 + (n - sizeof(DstType)); + auto dstBackIter = dstBegin + (dstN - 1); // try to deref the back. might bounds check. // And check that the pointer math works @@ -208,15 +217,22 @@ namespace osuCrypto { // begin will be an F pointer/iterator // dstBegin will be a byte pointer/iterator. template - void serialize(SrcIter&& begin, SrcIter&& end, DstIter&& dstBegin) + void serialize(SrcIter&& begin, SrcIter&& end, DstIter&& dstBegin) { // for primitive types serialization and deserializaion // are the same, a memcpy. deserialize(begin, end, dstBegin); } + + // If unsure you can just return iter. iter will be + // a Vec::iterator or const of it. + // This function allows for some compiler optimziations/ + // The idea is to return a pointer with the __restrict + // attibute. If this does not make sense for your Vec::iterator, + // just return the iterator. template - F* __restrict restrictPtr(Iter iter) + F* __restrict restrictPtr(Iter iter) { return &*iter; } @@ -225,10 +241,10 @@ namespace osuCrypto { // fill the range [begin,..., end) with zeros. // begin will be an F pointer/iterator. template - void zero(Iter begin, Iter end) + void zero(Iter begin, Iter end) { using F = std::remove_reference_t; - static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + static_assert(std::is_trivially_copyable::value, "memset is used so must be trivially_copyable."); if (begin != end) { @@ -241,10 +257,10 @@ namespace osuCrypto { // fill the range [begin,..., end) with ones. // begin will be an F pointer/iterator. template - void one(Iter begin, Iter end) + void one(Iter begin, Iter end) { using F = std::remove_reference_t; - static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + static_assert(std::is_trivially_copyable::value, "memset is used so must be trivially_copyable."); if (begin != end) @@ -260,10 +276,9 @@ namespace osuCrypto { } } - // convert F into a string template - std::string str(F&& f) + std::string str(F&& f) { std::stringstream ss; if constexpr (std::is_same_v, u8>) @@ -276,34 +291,71 @@ namespace osuCrypto { }; - // CoeffCtx for GF fields. - // ^ operator is used for addition. - struct CoeffCtxGF : CoeffCtxInteger - { + + // block does not use operator* + struct CoeffCtxGF2 : CoeffCtxInteger + { template - OC_FORCEINLINE void plus(F& ret, const F& lhs, const F& rhs) { + OC_FORCEINLINE void plus(F& ret, const F& lhs, const F& rhs) { ret = lhs ^ rhs; } template - OC_FORCEINLINE void minus(F& ret, const F& lhs, const F& rhs) { + OC_FORCEINLINE void minus(F& ret, const F& lhs, const F& rhs) { ret = lhs ^ rhs; } + template + OC_FORCEINLINE void mul(F& ret, const F& lhs, const F& rhs) { + ret = lhs & rhs; + } + + //template + //OC_FORCEINLINE void mul(F& ret, const F& lhs, const bool& rhs) { + // ret = rhs ? lhs : zeroElem(); + //} // is F a field? template - OC_FORCEINLINE bool isField() { + OC_FORCEINLINE bool isField() { return true; // default. } + //template + //static OC_FORCEINLINE constexpr F zeroElem() + //{ + // static_assert(std::is_trivially_copyable::value, "memset is used so must be trivially_copyable."); + // F r; + // memset(&r, 0, sizeof(F)); + // return r; + //} }; + // block does not use operator* - struct CoeffCtxGFBlock : CoeffCtxGF + struct CoeffCtxGF128 : CoeffCtxGF2 { - OC_FORCEINLINE void mul(block& ret, const block& lhs, const block& rhs) { + + OC_FORCEINLINE void mul(block& ret, const block& lhs, const block& rhs) { ret = lhs.gf128Mul(rhs); } + + // ret = x * 4234123421 mod 2^127 - 135 + OC_FORCEINLINE void mulConst(block& ret, const block& x) + { + // multiplication y modulo mod + block y(0, 4234123421); + static const constexpr std::uint64_t mod = 0b10000111; + const __m128i modulus = _mm_loadl_epi64((const __m128i*) & (mod)); + + block xy1 = _mm_clmulepi64_si128(x, y, (int)0x00); + block xy2 = _mm_clmulepi64_si128(x, y, 0x01); + xy1 = xy1 ^ _mm_slli_si128(xy2, 8); + xy2 = _mm_srli_si128(xy2, 8); + + /* reduce w.r.t. high half of mul256_high */ + auto tmp = _mm_clmulepi64_si128(xy2, modulus, 0x00); + ret = _mm_xor_si128(xy1, tmp); + } }; @@ -312,35 +364,35 @@ namespace osuCrypto { { using F = std::array; - OC_FORCEINLINE void plus(F& ret, const F& lhs, const F& rhs) { + OC_FORCEINLINE void plus(F& ret, const F& lhs, const F& rhs) { for (u64 i = 0; i < lhs.size(); ++i) { ret[i] = lhs[i] + rhs[i]; } } - OC_FORCEINLINE void plus(G& ret, const G& lhs, const G& rhs) { + OC_FORCEINLINE void plus(G& ret, const G& lhs, const G& rhs) { ret = lhs + rhs; } - OC_FORCEINLINE void minus(F& ret, const F& lhs, const F& rhs) + OC_FORCEINLINE void minus(F& ret, const F& lhs, const F& rhs) { for (u64 i = 0; i < lhs.size(); ++i) { ret[i] = lhs[i] - rhs[i]; } } - OC_FORCEINLINE void minus(G& ret, const G& lhs, const G& rhs) { + OC_FORCEINLINE void minus(G& ret, const G& lhs, const G& rhs) { ret = lhs - rhs; } - OC_FORCEINLINE void mul(F& ret, const F& lhs, const G& rhs) + OC_FORCEINLINE void mul(F& ret, const F& lhs, const G& rhs) { for (u64 i = 0; i < lhs.size(); ++i) { ret[i] = lhs[i] * rhs; } } - OC_FORCEINLINE bool eq(const F& lhs, const F& rhs) + OC_FORCEINLINE bool eq(const F& lhs, const F& rhs) { for (u64 i = 0; i < lhs.size(); ++i) { if (lhs[i] != rhs[i]) @@ -349,13 +401,13 @@ namespace osuCrypto { return true; } - OC_FORCEINLINE bool eq(const G& lhs, const G& rhs) + OC_FORCEINLINE bool eq(const G& lhs, const G& rhs) { return lhs == rhs; } // convert F into a string - std::string str(const F& f) + std::string str(const F& f) { auto delim = "{ "; std::stringstream ss; @@ -375,7 +427,7 @@ namespace osuCrypto { } // convert G into a string - std::string str(const G& g) + std::string str(const G& g) { std::stringstream ss; if constexpr (std::is_same_v, u8>) @@ -385,6 +437,7 @@ namespace osuCrypto { return ss.str(); } + }; template @@ -392,8 +445,8 @@ namespace osuCrypto { }; // GF128 vole - template<> struct DefaultCoeffCtx : CoeffCtxGFBlock {}; + template<> struct DefaultCoeffCtx : CoeffCtxGF128 {}; // OT - template<> struct DefaultCoeffCtx : CoeffCtxGFBlock {}; + template<> struct DefaultCoeffCtx : CoeffCtxGF2 {}; } diff --git a/libOTe/Tools/EACode/EACode.h b/libOTe/Tools/EACode/EACode.h index 77022721..e99ddf67 100644 --- a/libOTe/Tools/EACode/EACode.h +++ b/libOTe/Tools/EACode/EACode.h @@ -120,6 +120,7 @@ namespace osuCrypto for (u64 i = 0; i < main; ++i) { ctx.plus(xx[i + 1], xx[i + 1], xx[i]); + ctx.mulConst(xx[i + 1], xx[i + 1]); } } @@ -139,6 +140,9 @@ namespace osuCrypto { ctx.plus(xx1[i + 1], xx1[i + 1], xx1[i]); ctx.plus(xx2[i + 1], xx2[i + 1], xx2[i]); + ctx.mulConst(xx1[i + 1], xx1[i + 1]); + ctx.mulConst(xx2[i + 1], xx2[i + 1]); + } } diff --git a/libOTe/Tools/ExConvCode/ExConvCode.h b/libOTe/Tools/ExConvCode/ExConvCode.h index 277702e3..401e1420 100644 --- a/libOTe/Tools/ExConvCode/ExConvCode.h +++ b/libOTe/Tools/ExConvCode/ExConvCode.h @@ -250,12 +250,14 @@ namespace osuCrypto // Compute e[0,...,k-1] = G * e. template void ExConvCode::dualEncode( - Iter&& e, + Iter&& e_, CoeffCtx ctx) { static_assert(is_iterator::value, "must pass in an iterator to the data"); - (void)*(e + mCodeSize - 1); + (void)*(e_ + mCodeSize - 1); + + auto e = ctx.restrictPtr(e_); if (mSystematic) { @@ -275,7 +277,9 @@ namespace osuCrypto CoeffCtx::template Vec w; ctx.resize(w, mMessageSize); - mExpander.expand(e, w.begin(), ctx); + auto wIter = ctx.restrictPtr(w.begin()); + + mExpander.expand(e, wIter, ctx); setTimePoint("ExConv.encode.expand"); ctx.copy(w.begin(), w.end(), e); @@ -407,6 +411,7 @@ namespace osuCrypto if (!rangeCheck || xj < end) { ctx.plus(*xj, *xj, *xi); + ctx.mulConst(*xj, *xj); ++xj; } @@ -470,6 +475,7 @@ namespace osuCrypto if (!rangeCheck || xj < end) { ctx.plus(*xj, *xj, *xi); + ctx.mulConst(*xj, *xj); ++xj; } diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp index d9d4fba6..c473459f 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp @@ -476,7 +476,7 @@ namespace osuCrypto b = AlignedUnVector(1), //deltaShare = block{}, i = u64{}, - sender = NoisyVoleSender{}, + sender = NoisyVoleSender{}, theirHash = std::array{}, myHash = std::array{}, ro = RandomOracle(32) @@ -658,7 +658,7 @@ namespace osuCrypto case osuCrypto::MultType::ExAcc40: { AlignedUnVector A2(mEAEncoder.mMessageSize); - mEAEncoder.dualEncode(mA.subspan(0, mEAEncoder.mCodeSize), A2, {}); + mEAEncoder.dualEncode(mA.subspan(0, mEAEncoder.mCodeSize), A2, {}); std::swap(mA, A2); break; } @@ -666,7 +666,7 @@ namespace osuCrypto case osuCrypto::MultType::ExConv21x24: if (mTimer) mExConvEncoder.setTimer(getTimer()); - mExConvEncoder.dualEncode(mA.begin(), {}); + mExConvEncoder.dualEncode(mA.begin(), {}); break; default: throw RTE_LOC; @@ -712,7 +712,7 @@ namespace osuCrypto { AlignedUnVector A2(mEAEncoder.mMessageSize); AlignedUnVector C2(mEAEncoder.mMessageSize); - mEAEncoder.dualEncode2( + mEAEncoder.dualEncode2( mA.subspan(0, mEAEncoder.mCodeSize), A2, mC.subspan(0, mEAEncoder.mCodeSize), C2, {}); @@ -726,7 +726,7 @@ namespace osuCrypto case osuCrypto::MultType::ExConv21x24: if (mTimer) mExConvEncoder.setTimer(getTimer()); - mExConvEncoder.dualEncode2( + mExConvEncoder.dualEncode2( mA.begin(), mC.begin(), {}); diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h index 4a5e2683..0ac227e5 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h @@ -91,7 +91,7 @@ namespace osuCrypto block mMalCheckX = ZeroBlock; // The ggm tree thats used to generate the sparse vectors. - RegularPprfReceiver mGen; + RegularPprfReceiver mGen; // The type of compress we will use to generate the // dense vectors from the sparse vectors. diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp index 148bc1c0..73a927a5 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp @@ -428,7 +428,7 @@ namespace osuCrypto c = AlignedUnVector(1), //deltaShare = block{}, i = u64{}, - recver = NoisyVoleReceiver{}, + recver = NoisyVoleReceiver{}, myHash = std::array{}, ro = RandomOracle(32) ); @@ -485,7 +485,7 @@ namespace osuCrypto if (mTimer) mEAEncoder.setTimer(getTimer()); AlignedUnVector B2(mEAEncoder.mMessageSize); - mEAEncoder.dualEncode(mB.subspan(0, mEAEncoder.mCodeSize), B2, {}); + mEAEncoder.dualEncode(mB.subspan(0, mEAEncoder.mCodeSize), B2, {}); std::swap(mB, B2); break; } @@ -493,7 +493,7 @@ namespace osuCrypto case osuCrypto::MultType::ExConv21x24: if (mTimer) mExConvEncoder.setTimer(getTimer()); - mExConvEncoder.dualEncode(mB.begin(), {}); + mExConvEncoder.dualEncode(mB.begin(), {}); break; default: throw RTE_LOC; diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h index 1ab10edb..88ee73e5 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h @@ -110,7 +110,7 @@ namespace osuCrypto #endif // The ggm tree thats used to generate the sparse vectors. - RegularPprfSender mGen; + RegularPprfSender mGen; // The type of compress we will use to generate the // dense vectors from the sparse vectors. diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.h b/libOTe/Vole/Silent/SilentVoleReceiver.h index 5f8ace0a..74980f78 100644 --- a/libOTe/Vole/Silent/SilentVoleReceiver.h +++ b/libOTe/Vole/Silent/SilentVoleReceiver.h @@ -41,7 +41,7 @@ namespace osuCrypto static constexpr bool MaliciousSupported = std::is_same_v&& - std::is_same_v; + std::is_same_v; enum class State { @@ -543,7 +543,7 @@ namespace osuCrypto if constexpr ( std::is_same_v && std::is_same_v && - std::is_same_v) + std::is_same_v) { QuasiCyclicCode encoder; encoder.init2(mRequestSize, mNoiseVecSize); diff --git a/libOTe/Vole/Silent/SilentVoleSender.h b/libOTe/Vole/Silent/SilentVoleSender.h index aa091c60..15a0a4a9 100644 --- a/libOTe/Vole/Silent/SilentVoleSender.h +++ b/libOTe/Vole/Silent/SilentVoleSender.h @@ -40,7 +40,7 @@ namespace osuCrypto static constexpr bool MaliciousSupported = std::is_same_v&& - std::is_same_v; + std::is_same_v; enum class State { @@ -398,7 +398,7 @@ namespace osuCrypto if constexpr ( std::is_same_v && std::is_same_v && - std::is_same_v) + std::is_same_v) { QuasiCyclicCode encoder; encoder.init2(mRequestSize, mNoiseVecSize); diff --git a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.h b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.h index 68120264..acb12481 100644 --- a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.h +++ b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.h @@ -123,7 +123,7 @@ namespace osuCrypto { public: - using PprfSender = RegularPprfSender; + using PprfSender = RegularPprfSender; private: SmallFieldVoleSender(const SmallFieldVoleSender& b) @@ -209,7 +209,7 @@ namespace osuCrypto class SmallFieldVoleReceiver : public SmallFieldVoleBase { public: - using PprfReceiver = RegularPprfReceiver; + using PprfReceiver = RegularPprfReceiver; private: SmallFieldVoleReceiver(const SmallFieldVoleReceiver& b) diff --git a/libOTe_Tests/EACode_Tests.cpp b/libOTe_Tests/EACode_Tests.cpp index 1737e735..b1eaa5ce 100644 --- a/libOTe_Tests/EACode_Tests.cpp +++ b/libOTe_Tests/EACode_Tests.cpp @@ -39,16 +39,19 @@ namespace osuCrypto prng.get(c0.data(), c0.size()); auto a0 = c0; - code.accumulate(a0, {}); - + code.accumulate(a0, {}); + CoeffCtxGF128 ctx; block sum = c0[0]; for (u64 i = 0; i < a0.size(); ++i) { if (a0[i] != sum) throw RTE_LOC; - if(i+1 < a0.size()) + if (i + 1 < a0.size()) + { sum ^= c0[i + 1]; + ctx.mulConst(sum, sum); + } } u64 i = 0; @@ -75,7 +78,7 @@ namespace osuCrypto } } - code.dualEncode(c0, m1, {}); + code.dualEncode(c0, m1, {}); if (m0 != m1) throw RTE_LOC; diff --git a/libOTe_Tests/ExConvCode_Tests.cpp b/libOTe_Tests/ExConvCode_Tests.cpp index 0a35e1d7..5a2ed345 100644 --- a/libOTe_Tests/ExConvCode_Tests.cpp +++ b/libOTe_Tests/ExConvCode_Tests.cpp @@ -65,32 +65,39 @@ namespace osuCrypto for (i64 i = 0; i < x1.size() - aw - 1; ++i) { prng.get(rand.data(), rand.size()); - code.accOne(x1.begin() + i, x1.end(), rand.data(), ctx, std::integral_constant{}); + code.accOne(x1.data() + i, x1.data()+n, rand.data(), ctx, std::integral_constant{}); if (aw == 16) - code.accOne(x2.begin() + i, x2.end(), rand.data(), ctx, std::integral_constant{}); + code.accOne(x2.data() + i, x2.data()+n, rand.data(), ctx, std::integral_constant{}); ctx.plus(x3[i + 1], x3[i + 1], x3[i]); + //std::cout << "x" << i + 1 << " " << x3[i + 1] << " -> "; + ctx.mulConst(x3[i + 1], x3[i + 1]); + //std::cout << x3[i + 1] << std::endl;; + assert(aw <= 64); + u64 bits = 0; + memcpy(&bits, rand.data(), std::min(rand.size(), 8)); for (u64 j = 0; j < aw && (i + j + 2) < x3.size(); ++j) { - if (*BitIterator(rand.data(), j)) + if (bits & 1) { ctx.plus(x3[i + j + 2], x3[i + j + 2], x3[i]); } + bits >>= 1; } for (u64 j = i; j < x1.size() && j < i + aw + 2; ++j) { if (aw == 16 && x1[j] != x2[j]) { - std::cout << j << " " << (x1[j]) << " " << (x2[j]) << std::endl; + std::cout << j << " " << ctx.str(x1[j]) << " " << ctx.str(x2[j]) << std::endl; throw RTE_LOC; } if (x1[j] != x3[j]) { - std::cout << j << " " << (x1[j]) << " " << (x3[j]) << std::endl; + std::cout << j << " " << ctx.str(x1[j]) << " " << ctx.str(x3[j]) << std::endl; throw RTE_LOC; } } @@ -100,17 +107,17 @@ namespace osuCrypto x4 = x1; //std::cout << std::endl; - code.accumulateFixed(x1.begin() + accOffset, ctx); + code.accumulateFixed(x1.data() + accOffset, ctx); if (aw == 16) { - code.accumulateFixed(x2.begin() + accOffset, ctx); + code.accumulateFixed(x2.data() + accOffset, ctx); if (x1 != x2) { for (u64 i = 0; i < x1.size(); ++i) { - std::cout << i << " " << (x1[i]) << " " << (x2[i]) << std::endl; + std::cout << i << " " << ctx.str(x1[i]) << " " << ctx.str(x2[i]) << std::endl; } throw RTE_LOC; } @@ -121,8 +128,8 @@ namespace osuCrypto u8* mtxCoeffIter = (u8*)coeffGen.mBuffer.data(); auto mtxCoeffEnd = mtxCoeffIter + coeffGen.mBuffer.size() * sizeof(block) - divCeil(aw, 8); - auto xi = x3.begin() + accOffset; - auto end = x3.end(); + auto xi = x3.data() + accOffset; + auto end = x3.data() + n; while (xi < end) { if (mtxCoeffIter > mtxCoeffEnd) @@ -137,14 +144,19 @@ namespace osuCrypto if (xj != end) { ctx.plus(*xj, *xj, *xi); + ctx.mulConst(*xj, *xj); ++xj; } + //assert((mtxCoeffEnd - mtxCoeffIter) * 8 >= aw); + u64 bits = 0; + memcpy(&bits, mtxCoeffIter, divCeil(aw,8)); for (u64 j = 0; j < aw && xj != end; ++j, ++xj) { - if (*BitIterator(mtxCoeffIter, j)) + if (bits &1) { ctx.plus(*xj, *xj, *xi); } + bits >>= 1; } ++mtxCoeffIter; @@ -156,7 +168,7 @@ namespace osuCrypto { for (u64 i = 0; i < x1.size(); ++i) { - std::cout << i << " " << (x1[i]) << " " << (x3[i]) << std::endl; + std::cout << i << " " << ctx.str(x1[i]) << " " << ctx.str(x3[i]) << std::endl; } throw RTE_LOC; } @@ -167,9 +179,9 @@ namespace osuCrypto if (sys) { - std::copy(x1.begin(), x1.begin() + k, y1.begin()); + std::copy(x1.data(), x1.data() + k, y1.data()); y2 = y1; - code.mExpander.expand(x1.cbegin() + accOffset, y1.begin()); + code.mExpander.expand(x1.data() + accOffset, y1.data()); //using P = std::pair::const_iterator, typename std::vector::iterator>; //auto p = P{ x1.cbegin() + accOffset, y1.begin() }; //code.mExpander.expandMany( @@ -178,7 +190,7 @@ namespace osuCrypto } else { - code.mExpander.expand(x1.cbegin() + accOffset, y1.begin()); + code.mExpander.expand(x1.data() + accOffset, y1.data()); } u64 i = 0; @@ -214,26 +226,35 @@ namespace osuCrypto throw RTE_LOC; } - + //block mult2(block x, int imm8) + //{ + // assert(imm8 < 2); + // if (imm8) + // { + // // mult x[1] * 2 + + // } + // else + // { + // // x[0] * 2 + // __m128i carry = _mm_slli_si128(x, 8); + // carry = _mm_srli_epi64(carry, 63); + // x = _mm_slli_epi64(x, 1); + // return _mm_or_si128(x, carry); + + // //return _mm_slli_si128(x, 8); + // } + // //TEMP[i] : = (TEMP1[0] and TEMP2[i]) + // // FOR j : = 1 to i + // // TEMP[i] : = TEMP[i] XOR(TEMP1[j] AND TEMP2[i - j]) + // // ENDFOR + // //dst[i] : = TEMP[i] + //} void ExConvCode_encode_basic_test(const oc::CLP& cmd) { - - //std::vector i0, o0; - //std::vector i1, o1; - //std::vector i2, o2; - - //ExpanderCode ex; - //ex.expandMany( - // std::tuple{ - // std::pair{i0.begin(), o0.begin()}, - // std::pair{i1.begin(), o1.begin()}, - // std::pair{i2.begin(), o2.begin()} - // }, {}); - - - auto K = cmd.getManyOr("k", { 16ul, 64, 4353 }); + auto K = cmd.getManyOr("k", { 32ul, 333 }); auto R = cmd.getManyOr("R", { 2.0, 3.0 }); auto Bw = cmd.getManyOr("bw", { 7, 21 }); auto Aw = cmd.getManyOr("aw", { 16, 24, 29 }); @@ -244,8 +265,8 @@ namespace osuCrypto auto n = k * r; exConvTest(k, n, bw, aw, sys); exConvTest(k, n, bw, aw, sys); - exConvTest(k, n, bw, aw, sys); - exConvTest, CoeffCtxArray>(k, n, bw, aw, sys); + exConvTest(k, n, bw, aw, sys); + exConvTest, CoeffCtxArray>(k, n, bw, aw, sys); } } diff --git a/libOTe_Tests/Pprf_Tests.cpp b/libOTe_Tests/Pprf_Tests.cpp index 2e95de2e..7c028c25 100644 --- a/libOTe_Tests/Pprf_Tests.cpp +++ b/libOTe_Tests/Pprf_Tests.cpp @@ -151,7 +151,7 @@ void Tools_Pprf_expandOne_test(const oc::CLP& cmd) { Tools_Pprf_expandOne_test_impl(domain, program); - Tools_Pprf_expandOne_test_impl(domain, program); + Tools_Pprf_expandOne_test_impl(domain, program); Tools_Pprf_expandOne_test_impl, u32, CoeffCtxArray>(domain, program); } diff --git a/libOTe_Tests/Vole_Tests.cpp b/libOTe_Tests/Vole_Tests.cpp index 6d38d06e..8eff6728 100644 --- a/libOTe_Tests/Vole_Tests.cpp +++ b/libOTe_Tests/Vole_Tests.cpp @@ -71,7 +71,7 @@ void Vole_Noisy_test(const oc::CLP& cmd) { Vole_Noisy_test_impl(n); Vole_Noisy_test_impl(n); - Vole_Noisy_test_impl(n); + Vole_Noisy_test_impl(n); Vole_Noisy_test_impl, u32, CoeffCtxArray>(n); } } @@ -184,7 +184,8 @@ void Vole_Silent_paramSweep_test(const oc::CLP& cmd) for (u64 n : {128, 45364}) { Vole_Silent_test_impl(n, DefaultMultType, debug, false, false); - Vole_Silent_test_impl(n, DefaultMultType, debug, false, false); + Vole_Silent_test_impl(n, DefaultMultType, debug, false, false); + //Vole_Silent_test_impl(n, DefaultMultType, debug, false, false); Vole_Silent_test_impl, u32, CoeffCtxArray>(n, DefaultMultType, debug, false, false); } } @@ -194,7 +195,7 @@ void Vole_Silent_QuasiCyclic_test(const oc::CLP& cmd) #if defined(ENABLE_SILENTOT) && defined(ENABLE_BITPOLYMUL) auto debug = cmd.isSet("debug"); for (u64 n : {128, 333}) - Vole_Silent_test_impl(n, MultType::QuasiCyclic, debug, false, false); + Vole_Silent_test_impl(n, MultType::QuasiCyclic, debug, false, false); #else throw UnitTestSkipped("ENABLE_BITPOLYMUL not defined." LOCATION); #endif @@ -206,7 +207,7 @@ void Vole_Silent_baseOT_test(const oc::CLP& cmd) auto debug = cmd.isSet("debug"); u64 n = 128; Vole_Silent_test_impl(n, DefaultMultType, debug, true, false); - Vole_Silent_test_impl(n, DefaultMultType, debug, true, false); + Vole_Silent_test_impl(n, DefaultMultType, debug, true, false); Vole_Silent_test_impl, u32, CoeffCtxArray>(n, DefaultMultType, debug, true, false); } @@ -217,7 +218,7 @@ void Vole_Silent_mal_test(const oc::CLP& cmd) auto debug = cmd.isSet("debug"); for (u64 n : {45364}) { - Vole_Silent_test_impl(n, DefaultMultType, debug, false, true); + Vole_Silent_test_impl(n, DefaultMultType, debug, false, true); } } @@ -282,8 +283,8 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) cp::BufferingSocket chls[2]; - SilentSubfieldVoleReceiver recv; - SilentSubfieldVoleSender send; + SilentSubfieldVoleReceiver recv; + SilentSubfieldVoleSender send; send.mMalType = SilentSecType::SemiHonest; recv.mMalType = SilentSecType::SemiHonest; From eeb3a240e439088f7aef2cb293b63031ef66f5ea Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Mon, 22 Jan 2024 12:17:49 -0800 Subject: [PATCH 17/23] cleanup and gf2 vole --- frontend/benchmark.h | 17 +++--- libOTe/Tools/CoeffCtx.h | 38 +++++++----- libOTe/Vole/Silent/SilentVoleReceiver.h | 77 +++++++++++++++---------- libOTe/Vole/Silent/SilentVoleSender.h | 43 ++++++++------ libOTe_Tests/Vole_Tests.cpp | 10 ++-- 5 files changed, 109 insertions(+), 76 deletions(-) diff --git a/frontend/benchmark.h b/frontend/benchmark.h index b67327cc..63984a5e 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -390,8 +390,8 @@ namespace osuCrypto try { - SilentSubfieldVoleSender sender; - SilentSubfieldVoleReceiver recver; + SilentVoleSender sender; + SilentVoleReceiver recver; u64 trials = cmd.getOr("t", 10); @@ -412,11 +412,14 @@ namespace osuCrypto baseSend[i] = prng.get(); baseRecv[i] = baseSend[i][baseChoice[i]]; } - - sender.mOtExtRecver.setBaseOts(baseSend); - recver.mOtExtRecver.setBaseOts(baseSend); - sender.mOtExtSender.setBaseOts(baseRecv, baseChoice); - recver.mOtExtSender.setBaseOts(baseRecv, baseChoice); + sender.mOtExtRecver.emplace(); + sender.mOtExtSender.emplace(); + recver.mOtExtRecver.emplace(); + recver.mOtExtSender.emplace(); + sender.mOtExtRecver->setBaseOts(baseSend); + recver.mOtExtRecver->setBaseOts(baseSend); + sender.mOtExtSender->setBaseOts(baseRecv, baseChoice); + recver.mOtExtSender->setBaseOts(baseRecv, baseChoice); PRNG prng0(ZeroBlock), prng1(ZeroBlock); block delta = prng0.get(); diff --git a/libOTe/Tools/CoeffCtx.h b/libOTe/Tools/CoeffCtx.h index a0fc481e..9b5423d5 100644 --- a/libOTe/Tools/CoeffCtx.h +++ b/libOTe/Tools/CoeffCtx.h @@ -73,7 +73,7 @@ namespace osuCrypto { template OC_FORCEINLINE BitVector binaryDecomposition(F& x) { static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); - return { (u8*)&x, sizeof(F) * 8 }; + return { (u8*)&x, bitSize() }; } // derive an F using the randomness b. @@ -309,10 +309,22 @@ namespace osuCrypto { ret = lhs & rhs; } - //template - //OC_FORCEINLINE void mul(F& ret, const F& lhs, const bool& rhs) { - // ret = rhs ? lhs : zeroElem(); - //} + template + OC_FORCEINLINE void mul(F& ret, const F& lhs, const bool& rhs) { + ret = rhs ? lhs : zeroElem(); + } + + // the bit size require to prepresent F + // the protocol will perform binary decomposition + // of F using this many bits + template + u64 bitSize() + { + if (std::is_same::value) + return 1; + else + return sizeof(F) * 8; + } // is F a field? template @@ -320,14 +332,14 @@ namespace osuCrypto { return true; // default. } - //template - //static OC_FORCEINLINE constexpr F zeroElem() - //{ - // static_assert(std::is_trivially_copyable::value, "memset is used so must be trivially_copyable."); - // F r; - // memset(&r, 0, sizeof(F)); - // return r; - //} + template + static OC_FORCEINLINE constexpr F zeroElem() + { + static_assert(std::is_trivially_copyable::value, "memset is used so must be trivially_copyable."); + F r; + memset(&r, 0, sizeof(F)); + return r; + } }; diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.h b/libOTe/Vole/Silent/SilentVoleReceiver.h index 74980f78..2f7e8226 100644 --- a/libOTe/Vole/Silent/SilentVoleReceiver.h +++ b/libOTe/Vole/Silent/SilentVoleReceiver.h @@ -34,14 +34,14 @@ namespace osuCrypto typename G = F, typename Ctx = DefaultCoeffCtx > - class SilentSubfieldVoleReceiver : public TimerAdapter + class SilentVoleReceiver : public TimerAdapter { public: static constexpr u64 mScaler = 2; static constexpr bool MaliciousSupported = - std::is_same_v&& - std::is_same_v; + std::is_same_v && std::is_same_v; + enum class State { @@ -112,8 +112,8 @@ namespace osuCrypto #ifdef ENABLE_SOFTSPOKEN_OT - SoftSpokenMalOtSender mOtExtSender; - SoftSpokenMalOtReceiver mOtExtRecver; + macoro::optional mOtExtSender; + macoro::optional mOtExtRecver; #endif // // sets the Iknp base OTs that are then used to extend @@ -192,24 +192,29 @@ namespace osuCrypto if (mBaseType == SilentBaseType::BaseExtend) { #ifdef ENABLE_SOFTSPOKEN_OT + if (!mOtExtRecver) + mOtExtRecver.emplace(); + + if (!mOtExtSender) + mOtExtSender.emplace(); - if (mOtExtSender.hasBaseOts() == false) + if (mOtExtSender->hasBaseOts() == false) { - msg.resize(msg.size() + mOtExtSender.baseOtCount()); - bb.resize(mOtExtSender.baseOtCount()); + msg.resize(msg.size() + mOtExtSender->baseOtCount()); + bb.resize(mOtExtSender->baseOtCount()); bb.randomize(prng); choice.append(bb); - MC_AWAIT(mOtExtRecver.receive(choice, msg, prng, chl)); + MC_AWAIT(mOtExtRecver->receive(choice, msg, prng, chl)); - mOtExtSender.setBaseOts( + mOtExtSender->setBaseOts( span(msg).subspan( - msg.size() - mOtExtSender.baseOtCount(), - mOtExtSender.baseOtCount()), + msg.size() - mOtExtSender->baseOtCount(), + mOtExtSender->baseOtCount()), bb); - msg.resize(msg.size() - mOtExtSender.baseOtCount()); - MC_AWAIT(nv.receive(noiseVals, baseAs, prng, mOtExtSender, chl, mCtx)); + msg.resize(msg.size() - mOtExtSender->baseOtCount()); + MC_AWAIT(nv.receive(noiseVals, baseAs, prng, *mOtExtSender, chl, mCtx)); } else { @@ -219,8 +224,8 @@ namespace osuCrypto MC_AWAIT( macoro::when_all_ready( - nv.receive(noiseVals, baseAs, prng2, mOtExtSender, chl2, mCtx), - mOtExtRecver.receive(choice, msg, prng, chl) + nv.receive(noiseVals, baseAs, prng2, *mOtExtSender, chl2, mCtx), + mOtExtRecver->receive(choice, msg, prng, chl) )); } #else @@ -323,26 +328,34 @@ namespace osuCrypto // and perform a noicy vole to get a = b + mD * c - VecG zero, one; - mCtx.resize(zero, 1); - mCtx.resize(one, 1); - mCtx.zero(zero.begin(), zero.end()); - mCtx.one(one.begin(), one.end()); mCtx.resize(mBaseC, mNumPartitions + (mMalType == SilentSecType::Malicious)); - for (size_t i = 0; i < mNumPartitions; i++) + + if (mCtx.bitSize() == 1) { - mCtx.fromBlock(mBaseC[i], prng.get()); - - // must not be zero. - while(mCtx.eq(zero[0], mBaseC[i])) + mCtx.one(mBaseC.begin(), mBaseC.begin() + mNumPartitions); + } + else + { + VecG zero, one; + mCtx.resize(zero, 1); + mCtx.resize(one, 1); + mCtx.zero(zero.begin(), zero.end()); + mCtx.one(one.begin(), one.end()); + for (size_t i = 0; i < mNumPartitions; i++) + { mCtx.fromBlock(mBaseC[i], prng.get()); - // if we are not a field, then the noise should be odd. - if (mCtx.isField() == false) - { - u8 odd = mCtx.binaryDecomposition(mBaseC[i])[0]; - if (odd) - mCtx.plus(mBaseC[i], mBaseC[i], one[0]); + // must not be zero. + while(mCtx.eq(zero[0], mBaseC[i])) + mCtx.fromBlock(mBaseC[i], prng.get()); + + // if we are not a field, then the noise should be odd. + if (mCtx.isField() == false) + { + u8 odd = mCtx.binaryDecomposition(mBaseC[i])[0]; + if (odd) + mCtx.plus(mBaseC[i], mBaseC[i], one[0]); + } } } diff --git a/libOTe/Vole/Silent/SilentVoleSender.h b/libOTe/Vole/Silent/SilentVoleSender.h index 15a0a4a9..bc79c379 100644 --- a/libOTe/Vole/Silent/SilentVoleSender.h +++ b/libOTe/Vole/Silent/SilentVoleSender.h @@ -33,14 +33,13 @@ namespace osuCrypto typename G = F, typename Ctx = DefaultCoeffCtx > - class SilentSubfieldVoleSender : public TimerAdapter + class SilentVoleSender : public TimerAdapter { public: static constexpr u64 mScaler = 2; static constexpr bool MaliciousSupported = - std::is_same_v&& - std::is_same_v; + std::is_same_v && std::is_same_v; enum class State { @@ -103,11 +102,14 @@ namespace osuCrypto block mDeltaShare; #ifdef ENABLE_SOFTSPOKEN_OT - SoftSpokenMalOtSender mOtExtSender; - SoftSpokenMalOtReceiver mOtExtRecver; + macoro::optional mOtExtSender; + macoro::optional mOtExtRecver; #endif - + bool hasSilentBaseOts()const + { + return mGen.hasBaseOts(); + } u64 baseVoleCount() const @@ -147,18 +149,23 @@ namespace osuCrypto { #ifdef ENABLE_SOFTSPOKEN_OT - if (mOtExtRecver.hasBaseOts() == false) + if (!mOtExtSender) + mOtExtSender = SoftSpokenMalOtSender{}; + if (!mOtExtRecver) + mOtExtRecver = SoftSpokenMalOtReceiver{}; + + if (mOtExtRecver->hasBaseOts() == false) { - msg.resize(msg.size() + mOtExtRecver.baseOtCount()); - MC_AWAIT(mOtExtSender.send(msg, prng, chl)); + msg.resize(msg.size() + mOtExtRecver->baseOtCount()); + MC_AWAIT(mOtExtSender->send(msg, prng, chl)); - mOtExtRecver.setBaseOts( + mOtExtRecver->setBaseOts( span>(msg).subspan( - msg.size() - mOtExtRecver.baseOtCount(), - mOtExtRecver.baseOtCount())); - msg.resize(msg.size() - mOtExtRecver.baseOtCount()); + msg.size() - mOtExtRecver->baseOtCount(), + mOtExtRecver->baseOtCount())); + msg.resize(msg.size() - mOtExtRecver->baseOtCount()); - MC_AWAIT(nv.send(delta, b, prng, mOtExtRecver, chl, mCtx)); + MC_AWAIT(nv.send(delta, b, prng, *mOtExtRecver, chl, mCtx)); } else { @@ -167,19 +174,17 @@ namespace osuCrypto MC_AWAIT( macoro::when_all_ready( - nv.send(delta, b, prng2, mOtExtRecver, chl2, mCtx), - mOtExtSender.send(msg, prng, chl))); + nv.send(delta, b, prng2, *mOtExtRecver, chl2, mCtx), + mOtExtSender->send(msg, prng, chl))); } #else - + throw RTE_LOC; #endif } else { chl2 = chl.fork(); prng2.SetSeed(prng.get()); - //MC_AWAIT(baseOt.send(msg, prng, chl)); - //MC_AWAIT(nv.send(delta, b, prng2, baseOt, chl2)); MC_AWAIT( macoro::when_all_ready( nv.send(delta, b, prng2, baseOt, chl2, mCtx), diff --git a/libOTe_Tests/Vole_Tests.cpp b/libOTe_Tests/Vole_Tests.cpp index 8eff6728..1016d54c 100644 --- a/libOTe_Tests/Vole_Tests.cpp +++ b/libOTe_Tests/Vole_Tests.cpp @@ -139,8 +139,8 @@ void Vole_Silent_test_impl(u64 n, MultType type, bool debug, bool doFakeBase, bo auto chls = cp::LocalAsyncSocket::makePair(); - SilentSubfieldVoleReceiver recv; - SilentSubfieldVoleSender send; + SilentVoleReceiver recv; + SilentVoleSender send; recv.mMultType = type; send.mMultType = type; recv.mDebug = debug; @@ -185,7 +185,7 @@ void Vole_Silent_paramSweep_test(const oc::CLP& cmd) { Vole_Silent_test_impl(n, DefaultMultType, debug, false, false); Vole_Silent_test_impl(n, DefaultMultType, debug, false, false); - //Vole_Silent_test_impl(n, DefaultMultType, debug, false, false); + Vole_Silent_test_impl(n, DefaultMultType, debug, false, false); Vole_Silent_test_impl, u32, CoeffCtxArray>(n, DefaultMultType, debug, false, false); } } @@ -283,8 +283,8 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) cp::BufferingSocket chls[2]; - SilentSubfieldVoleReceiver recv; - SilentSubfieldVoleSender send; + SilentVoleReceiver recv; + SilentVoleSender send; send.mMalType = SilentSecType::SemiHonest; recv.mMalType = SilentSecType::SemiHonest; From 1fd141a3d64c992a9dcf9af4a05ccc0b7ba93a58 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Mon, 22 Jan 2024 13:20:26 -0800 Subject: [PATCH 18/23] linux compile and fixes --- libOTe/Tools/CoeffCtx.h | 18 ++++-- libOTe/Tools/ExConvCode/ExConvCode.h | 80 +++++++++++++++---------- libOTe/Tools/ExConvCode/Expander.h | 4 +- libOTe/Tools/Pprf/PprfUtil.h | 2 +- libOTe/Tools/Pprf/RegularPprf.h | 22 +++---- libOTe/Vole/Noisy/NoisyVoleReceiver.h | 6 +- libOTe/Vole/Noisy/NoisyVoleSender.h | 8 +-- libOTe/Vole/Silent/SilentVoleReceiver.h | 14 ++--- libOTe/Vole/Silent/SilentVoleSender.h | 2 +- libOTe_Tests/EACode_Tests.cpp | 1 - libOTe_Tests/ExConvCode_Tests.cpp | 8 +-- libOTe_Tests/Pprf_Tests.cpp | 11 ++-- libOTe_Tests/Vole_Tests.cpp | 6 +- 13 files changed, 105 insertions(+), 77 deletions(-) diff --git a/libOTe/Tools/CoeffCtx.h b/libOTe/Tools/CoeffCtx.h index 9b5423d5..7b8312ac 100644 --- a/libOTe/Tools/CoeffCtx.h +++ b/libOTe/Tools/CoeffCtx.h @@ -3,6 +3,7 @@ #include "cryptoTools/Common/BitIterator.h" #include "cryptoTools/Common/BitVector.h" #include +#include namespace osuCrypto { @@ -453,12 +454,21 @@ namespace osuCrypto { }; template - struct DefaultCoeffCtx : CoeffCtxInteger { + struct DefaultCoeffCtx_t { + using type = CoeffCtxInteger; }; // GF128 vole - template<> struct DefaultCoeffCtx : CoeffCtxGF128 {}; + template<> + struct DefaultCoeffCtx_t { + using type = CoeffCtxGF128; + }; + + // OT, gf2 + template<> struct DefaultCoeffCtx_t { + using type = CoeffCtxGF2; + }; - // OT - template<> struct DefaultCoeffCtx : CoeffCtxGF2 {}; + template + using DefaultCoeffCtx = typename DefaultCoeffCtx_t::type; } diff --git a/libOTe/Tools/ExConvCode/ExConvCode.h b/libOTe/Tools/ExConvCode/ExConvCode.h index 401e1420..0a574fcc 100644 --- a/libOTe/Tools/ExConvCode/ExConvCode.h +++ b/libOTe/Tools/ExConvCode/ExConvCode.h @@ -173,8 +173,7 @@ namespace osuCrypto Iter&& xi, Iter&& end, u8* matrixCoeff, - CoeffCtx& ctx, - std::integral_constant); + CoeffCtx& ctx); // accumulating row i. generic version template< @@ -183,12 +182,11 @@ namespace osuCrypto bool rangeCheck, typename Iter > - OC_FORCEINLINE void accOne( + OC_FORCEINLINE void accOneGen( Iter&& xi, Iter&& end, u8* matrixCoeff, - CoeffCtx& ctx, - std::integral_constant); + CoeffCtx& ctx); // accumulate x onto itself. @@ -257,7 +255,7 @@ namespace osuCrypto (void)*(e_ + mCodeSize - 1); - auto e = ctx.restrictPtr(e_); + auto e = ctx.template restrictPtr(e_); if (mSystematic) { @@ -275,9 +273,9 @@ namespace osuCrypto accumulate(e, ctx); setTimePoint("ExConv.encode.accumulate"); - CoeffCtx::template Vec w; + typename CoeffCtx::template Vec w; ctx.resize(w, mMessageSize); - auto wIter = ctx.restrictPtr(w.begin()); + auto wIter = ctx.template restrictPtr(w.begin()); mExpander.expand(e, wIter, ctx); setTimePoint("ExConv.encode.expand"); @@ -358,14 +356,22 @@ namespace osuCrypto assert((((b >> 7) & 1) ? *xi : ZeroBlock) == tt[7]); // xj += bj * xi - if (rangeCheck && xj + 0 == end) return; ctx.plus(*(xj + 0), *(xj + 0), tt[0]); - if (rangeCheck && xj + 1 == end) return; ctx.plus(*(xj + 1), *(xj + 1), tt[1]); - if (rangeCheck && xj + 2 == end) return; ctx.plus(*(xj + 2), *(xj + 2), tt[2]); - if (rangeCheck && xj + 3 == end) return; ctx.plus(*(xj + 3), *(xj + 3), tt[3]); - if (rangeCheck && xj + 4 == end) return; ctx.plus(*(xj + 4), *(xj + 4), tt[4]); - if (rangeCheck && xj + 5 == end) return; ctx.plus(*(xj + 5), *(xj + 5), tt[5]); - if (rangeCheck && xj + 6 == end) return; ctx.plus(*(xj + 6), *(xj + 6), tt[6]); - if (rangeCheck && xj + 7 == end) return; ctx.plus(*(xj + 7), *(xj + 7), tt[7]); + if (rangeCheck && xj + 0 == end) return; + ctx.plus(*(xj + 0), *(xj + 0), tt[0]); + if (rangeCheck && xj + 1 == end) return; + ctx.plus(*(xj + 1), *(xj + 1), tt[1]); + if (rangeCheck && xj + 2 == end) return; + ctx.plus(*(xj + 2), *(xj + 2), tt[2]); + if (rangeCheck && xj + 3 == end) return; + ctx.plus(*(xj + 3), *(xj + 3), tt[3]); + if (rangeCheck && xj + 4 == end) return; + ctx.plus(*(xj + 4), *(xj + 4), tt[4]); + if (rangeCheck && xj + 5 == end) return; + ctx.plus(*(xj + 5), *(xj + 5), tt[5]); + if (rangeCheck && xj + 6 == end) return; + ctx.plus(*(xj + 6), *(xj + 6), tt[6]); + if (rangeCheck && xj + 7 == end) return; + ctx.plus(*(xj + 7), *(xj + 7), tt[7]); } else #endif @@ -379,14 +385,22 @@ namespace osuCrypto auto b6 = b & 64; auto b7 = b & 128; - if (rangeCheck && xj + 0 == end) return; if (b0) ctx.plus(*(xj + 0), *(xj + 0), *xi); - if (rangeCheck && xj + 1 == end) return; if (b1) ctx.plus(*(xj + 1), *(xj + 1), *xi); - if (rangeCheck && xj + 2 == end) return; if (b2) ctx.plus(*(xj + 2), *(xj + 2), *xi); - if (rangeCheck && xj + 3 == end) return; if (b3) ctx.plus(*(xj + 3), *(xj + 3), *xi); - if (rangeCheck && xj + 4 == end) return; if (b4) ctx.plus(*(xj + 4), *(xj + 4), *xi); - if (rangeCheck && xj + 5 == end) return; if (b5) ctx.plus(*(xj + 5), *(xj + 5), *xi); - if (rangeCheck && xj + 6 == end) return; if (b6) ctx.plus(*(xj + 6), *(xj + 6), *xi); - if (rangeCheck && xj + 7 == end) return; if (b7) ctx.plus(*(xj + 7), *(xj + 7), *xi); + if (rangeCheck && xj + 0 == end) return; + if (b0) ctx.plus(*(xj + 0), *(xj + 0), *xi); + if (rangeCheck && xj + 1 == end) return; + if (b1) ctx.plus(*(xj + 1), *(xj + 1), *xi); + if (rangeCheck && xj + 2 == end) return; + if (b2) ctx.plus(*(xj + 2), *(xj + 2), *xi); + if (rangeCheck && xj + 3 == end) return; + if (b3) ctx.plus(*(xj + 3), *(xj + 3), *xi); + if (rangeCheck && xj + 4 == end) return; + if (b4) ctx.plus(*(xj + 4), *(xj + 4), *xi); + if (rangeCheck && xj + 5 == end) return; + if (b5) ctx.plus(*(xj + 5), *(xj + 5), *xi); + if (rangeCheck && xj + 6 == end) return; + if (b6) ctx.plus(*(xj + 6), *(xj + 6), *xi); + if (rangeCheck && xj + 7 == end) return; + if (b7) ctx.plus(*(xj + 7), *(xj + 7), *xi); } } @@ -398,12 +412,11 @@ namespace osuCrypto bool rangeCheck, typename Iter > - OC_FORCEINLINE void ExConvCode::accOne( + OC_FORCEINLINE void ExConvCode::accOneGen( Iter&& xi, Iter&& end, u8* matrixCoeff, - CoeffCtx& ctx, - std::integral_constant _) + CoeffCtx& ctx) { // xj += xi @@ -464,8 +477,7 @@ namespace osuCrypto Iter&& xi, Iter&& end, u8* matrixCoeff, - CoeffCtx& ctx, - std::integral_constant) + CoeffCtx& ctx) { static_assert(AccumulatorSize, "should have called the other overload"); static_assert(AccumulatorSize % 8 == 0, "must be a multiple of 8"); @@ -530,7 +542,10 @@ namespace osuCrypto } // add xi to the next positions - accOne(xi, end, mtxCoeffIter++, ctx, std::integral_constant{}); + if constexpr(AccumulatorSize == 0) + accOneGen(xi, end, mtxCoeffIter++, ctx); + else + accOne(xi, end, mtxCoeffIter++, ctx); ++xi; } @@ -544,7 +559,10 @@ namespace osuCrypto } // add xi to the next positions - accOne(xi, end, mtxCoeffIter++, ctx, std::integral_constant{}); + if constexpr (AccumulatorSize == 0) + accOneGen(xi, end, mtxCoeffIter++, ctx); + else + accOne(xi, end, mtxCoeffIter++, ctx); ++xi; } } diff --git a/libOTe/Tools/ExConvCode/Expander.h b/libOTe/Tools/ExConvCode/Expander.h index 55558627..1bd9803e 100644 --- a/libOTe/Tools/ExConvCode/Expander.h +++ b/libOTe/Tools/ExConvCode/Expander.h @@ -97,8 +97,8 @@ namespace osuCrypto detail::ExpanderModd prng(mSeed, mCodeSize); - auto rInput = ctx.restrictPtr(input); - auto rOutput = ctx.restrictPtr(output); + auto rInput = ctx.template restrictPtr(input); + auto rOutput = ctx.template restrictPtr(output); auto main = mMessageSize / 8 * 8; u64 i = 0; diff --git a/libOTe/Tools/Pprf/PprfUtil.h b/libOTe/Tools/Pprf/PprfUtil.h index 03b70f21..4d294a15 100644 --- a/libOTe/Tools/Pprf/PprfUtil.h +++ b/libOTe/Tools/Pprf/PprfUtil.h @@ -134,7 +134,7 @@ namespace osuCrypto CoeffCtx& ctx) { - u64 elementSize = ctx.byteSize(); + u64 elementSize = ctx.template byteSize(); // num of bytes they will take up. u64 numBytes = diff --git a/libOTe/Tools/Pprf/RegularPprf.h b/libOTe/Tools/Pprf/RegularPprf.h index fc37f009..ebede7f9 100644 --- a/libOTe/Tools/Pprf/RegularPprf.h +++ b/libOTe/Tools/Pprf/RegularPprf.h @@ -310,7 +310,7 @@ namespace osuCrypto std::array child; // clear the sums - std::array, 2> leafSums; + std::array leafSums; ctx.resize(leafSums[0], 8); ctx.resize(leafSums[1], 8); ctx.zero(leafSums[0].begin(), leafSums[0].end()); @@ -368,7 +368,7 @@ namespace osuCrypto // active child should be the correct value XOR the delta. // This will be done by sending the sums and the sums plus // delta and ensure that they can only decrypt the correct ones. - CoeffCtx::template Vec leafOts; + VecF leafOts; ctx.resize(leafOts, 2); PRNG otMasker; @@ -397,7 +397,7 @@ namespace osuCrypto } // copy m0 into the output buffer. - span buff = leafMsgs.subspan(0, 2 * ctx.byteSize()); + span buff = leafMsgs.subspan(0, 2 * ctx.template byteSize()); leafMsgs = leafMsgs.subspan(buff.size()); ctx.serialize(leafOts.begin(), leafOts.end(), buff.begin()); @@ -411,7 +411,7 @@ namespace osuCrypto } else { - CoeffCtx::template Vec leafOts; + VecF leafOts; ctx.resize(leafOts, 1); PRNG otMasker; @@ -421,7 +421,7 @@ namespace osuCrypto { // copy the sum k into the output buffer. ctx.copy(leafOts[0], leafSums[k][j]); - span buff = leafMsgs.subspan(0, ctx.byteSize()); + span buff = leafMsgs.subspan(0, ctx.template byteSize()); leafMsgs = leafMsgs.subspan(buff.size()); ctx.serialize(leafOts.begin(), leafOts.end(), buff.begin()); @@ -716,7 +716,7 @@ namespace osuCrypto // We change the hash function for the leaf so lets update // inactiveChildValues to use the new hash and subtract // these from the leafSums - std::array, 2> leafSums; + std::array leafSums; if (mDepth > 1) { auto theirSumsIter = theirSums.begin(); @@ -885,7 +885,7 @@ namespace osuCrypto // overwrite whatever the value was. This is an optimization. auto width = divCeil(mDomain, 1ull << (mDepth - d)); - CoeffCtx::template Vec temp; + VecF temp; ctx.resize(temp, 2); for (u64 k = 0; k < 2; ++k) { @@ -969,11 +969,11 @@ namespace osuCrypto auto notAi = inactiveChildIdx & 1; // offset to the first or second ot message, based on the one we want - auto offset = ctx.byteSize() * 2 * notAi; + auto offset = ctx.template byteSize() * 2 * notAi; // decrypt the ot string - span buff = leafMsg.subspan(offset, ctx.byteSize() * 2); + span buff = leafMsg.subspan(offset, ctx.template byteSize() * 2); leafMsg = leafMsg.subspan(buff.size() * 2); otMasker.SetSeed(mBaseOTs[j + treeIdx][0], divCeil(buff.size(), sizeof(block))); for (u64 i = 0; i < buff.size(); ++i) @@ -1006,10 +1006,10 @@ namespace osuCrypto auto notAi = inactiveChildIdx & 1; // offset to the first or second ot message, based on the one we want - auto offset = ctx.byteSize() * notAi; + auto offset = ctx.template byteSize() * notAi; // decrypt the ot string - span buff = leafMsg.subspan(offset, ctx.byteSize()); + span buff = leafMsg.subspan(offset, ctx.template byteSize()); leafMsg = leafMsg.subspan(buff.size() * 2); otMasker.SetSeed(mBaseOTs[j + treeIdx][0], divCeil(buff.size(), sizeof(block))); for (u64 i = 0; i < buff.size(); ++i) diff --git a/libOTe/Vole/Noisy/NoisyVoleReceiver.h b/libOTe/Vole/Noisy/NoisyVoleReceiver.h index 49d45007..3c83b2ff 100644 --- a/libOTe/Vole/Noisy/NoisyVoleReceiver.h +++ b/libOTe/Vole/Noisy/NoisyVoleReceiver.h @@ -53,7 +53,7 @@ namespace osuCrypto { otMsg = AlignedUnVector>{}); setTimePoint("NoisyVoleReceiver.ot.begin"); - otMsg.resize(ctx.bitSize()); + otMsg.resize(ctx.template bitSize()); MC_AWAIT(ot.send(otMsg, prng, chl)); setTimePoint("NoisyVoleReceiver.ot.end"); @@ -101,7 +101,7 @@ namespace osuCrypto { for (size_t j = 0; j < c.size(); ++j, ++k) { // msg[i,j] = otMsg[i,j,0] - ctx.fromBlock(msg[k], prng.get()); + ctx.fromBlock(msg[k], prng.get()); //ctx.zero(msg.begin() + k, msg.begin() + k + 1); //std::cout << "m" << i << ",0 = " << ctx.str(msg[k]) << std::endl; @@ -135,7 +135,7 @@ namespace osuCrypto { } } - buff.resize(msg.size() * ctx.byteSize()); + buff.resize(msg.size() * ctx.template byteSize()); ctx.serialize(msg.begin(), msg.end(), buff.begin()); MC_AWAIT(chl.send(std::move(buff))); diff --git a/libOTe/Vole/Noisy/NoisyVoleSender.h b/libOTe/Vole/Noisy/NoisyVoleSender.h index 6868b453..d42e9fcb 100644 --- a/libOTe/Vole/Noisy/NoisyVoleSender.h +++ b/libOTe/Vole/Noisy/NoisyVoleSender.h @@ -38,8 +38,8 @@ namespace osuCrypto { template < typename F, - typename G = F, - typename CoeffCtx = DefaultCoeffCtx + typename G, + typename CoeffCtx > class NoisyVoleSender : public TimerAdapter { @@ -83,7 +83,7 @@ namespace osuCrypto { temp = VecF{}, xb = BitVector{}); - xb = ctx.binaryDecomposition(delta); + xb = ctx.binaryDecomposition(delta); if (otMsg.size() != xb.size()) throw RTE_LOC; @@ -93,7 +93,7 @@ namespace osuCrypto { ctx.zero(b.begin(), b.end()); // receive the the excrypted one shares. - buffer.resize(xb.size() * b.size() * ctx.byteSize()); + buffer.resize(xb.size() * b.size() * ctx.template byteSize()); MC_AWAIT(chl.recv(buffer)); ctx.resize(msg, xb.size() * b.size()); ctx.deserialize(buffer.begin(), buffer.end(), msg.begin()); diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.h b/libOTe/Vole/Silent/SilentVoleReceiver.h index 2f7e8226..a7fcd311 100644 --- a/libOTe/Vole/Silent/SilentVoleReceiver.h +++ b/libOTe/Vole/Silent/SilentVoleReceiver.h @@ -330,7 +330,7 @@ namespace osuCrypto mCtx.resize(mBaseC, mNumPartitions + (mMalType == SilentSecType::Malicious)); - if (mCtx.bitSize() == 1) + if (mCtx.template bitSize() == 1) { mCtx.one(mBaseC.begin(), mBaseC.begin() + mNumPartitions); } @@ -343,14 +343,14 @@ namespace osuCrypto mCtx.one(one.begin(), one.end()); for (size_t i = 0; i < mNumPartitions; i++) { - mCtx.fromBlock(mBaseC[i], prng.get()); + mCtx.fromBlock(mBaseC[i], prng.get()); // must not be zero. while(mCtx.eq(zero[0], mBaseC[i])) - mCtx.fromBlock(mBaseC[i], prng.get()); + mCtx.fromBlock(mBaseC[i], prng.get()); // if we are not a field, then the noise should be odd. - if (mCtx.isField() == false) + if (mCtx.template isField() == false) { u8 odd = mCtx.binaryDecomposition(mBaseC[i])[0]; if (odd) @@ -605,19 +605,19 @@ namespace osuCrypto ); // recv delta - buffer.resize(mCtx.byteSize()); + buffer.resize(mCtx.template byteSize()); mCtx.resize(delta, 1); MC_AWAIT(chl.recv(buffer)); mCtx.deserialize(buffer.begin(), buffer.end(), delta.begin()); // recv B - buffer.resize(mCtx.byteSize() * mA.size()); + buffer.resize(mCtx.template byteSize() * mA.size()); mCtx.resize(B, mA.size()); MC_AWAIT(chl.recv(buffer)); mCtx.deserialize(buffer.begin(), buffer.end(), B.begin()); // recv the noisy values. - buffer.resize(mCtx.byteSize() * mBaseA.size()); + buffer.resize(mCtx.template byteSize() * mBaseA.size()); mCtx.resize(baseB, mBaseA.size()); MC_AWAIT(chl.recvResize(buffer)); mCtx.deserialize(buffer.begin(), buffer.end(), baseB.begin()); diff --git a/libOTe/Vole/Silent/SilentVoleSender.h b/libOTe/Vole/Silent/SilentVoleSender.h index bc79c379..682e2b5b 100644 --- a/libOTe/Vole/Silent/SilentVoleSender.h +++ b/libOTe/Vole/Silent/SilentVoleSender.h @@ -139,7 +139,7 @@ namespace osuCrypto if (isConfigured() == false) throw std::runtime_error("configure must be called first"); - xx = mCtx.binaryDecomposition(delta); + xx = mCtx.template binaryDecomposition(delta); // compute the correlation for the noisy coordinates. b.resize(baseVoleCount()); diff --git a/libOTe_Tests/EACode_Tests.cpp b/libOTe_Tests/EACode_Tests.cpp index b1eaa5ce..6347d4d3 100644 --- a/libOTe_Tests/EACode_Tests.cpp +++ b/libOTe_Tests/EACode_Tests.cpp @@ -13,7 +13,6 @@ namespace osuCrypto auto n = cmd.getOr("n", k * R); auto bw = cmd.getOr("bw", 7); - bool v = cmd.isSet("v"); EACode code; diff --git a/libOTe_Tests/ExConvCode_Tests.cpp b/libOTe_Tests/ExConvCode_Tests.cpp index 5a2ed345..49be53ef 100644 --- a/libOTe_Tests/ExConvCode_Tests.cpp +++ b/libOTe_Tests/ExConvCode_Tests.cpp @@ -62,13 +62,13 @@ namespace osuCrypto } CoeffCtx ctx; std::vector rand(divCeil(aw, 8)); - for (i64 i = 0; i < x1.size() - aw - 1; ++i) + for (i64 i = 0; i < i64(x1.size() - aw - 1); ++i) { prng.get(rand.data(), rand.size()); - code.accOne(x1.data() + i, x1.data()+n, rand.data(), ctx, std::integral_constant{}); + code.accOneGen(x1.data() + i, x1.data()+n, rand.data(), ctx); if (aw == 16) - code.accOne(x2.data() + i, x2.data()+n, rand.data(), ctx, std::integral_constant{}); + code.accOne(x2.data() + i, x2.data()+n, rand.data(), ctx); ctx.plus(x3[i + 1], x3[i + 1], x3[i]); @@ -259,7 +259,7 @@ namespace osuCrypto auto Bw = cmd.getManyOr("bw", { 7, 21 }); auto Aw = cmd.getManyOr("aw", { 16, 24, 29 }); - bool v = cmd.isSet("v"); + //bool v = cmd.isSet("v"); for (auto k : K) for (auto r : R) for (auto bw : Bw) for (auto aw : Aw) for (auto sys : { false, true }) { auto n = k * r; diff --git a/libOTe_Tests/Pprf_Tests.cpp b/libOTe_Tests/Pprf_Tests.cpp index 7c028c25..fd3622c9 100644 --- a/libOTe_Tests/Pprf_Tests.cpp +++ b/libOTe_Tests/Pprf_Tests.cpp @@ -16,7 +16,6 @@ void Tools_Pprf_expandOne_test_impl(u64 domain, bool program) auto pntCount = 8ull; PRNG prng(CCBlock); - auto format = PprfOutputFormat::Interleaved; RegularPprfSender sender; RegularPprfReceiver recver; @@ -59,9 +58,9 @@ void Tools_Pprf_expandOne_test_impl(u64 domain, bool program) pprf::allocateExpandTree(mTreeAlloc, sLevels); pprf::allocateExpandTree(mTreeAlloc, rLevels); - - Ctx::Vec sLeafLevel(8ull << depth); - Ctx::Vec rLeafLevel(8ull << depth); + using VecF = typename Ctx::template Vec; + VecF sLeafLevel(8ull * domain); + VecF rLeafLevel(8ull * domain); u64 leafOffset = 0; Ctx ctx; @@ -133,7 +132,10 @@ void Tools_Pprf_expandOne_test_impl(u64 domain, bool program) else { if (sLeaves(j, i) != rLeaves(j, i)) + { + std::cout << "j " << j << " i " << i << " sender " << ctx.str(sLeaves(j, i)) << " recver " << ctx.str(rLeaves(j, i)) << std::endl; throw RTE_LOC; + } } } } @@ -171,7 +173,6 @@ void Tools_Pprf_test_impl( bool verbose) { - u64 depth = log2ceil(domain); auto threads = 1; PRNG prng(CCBlock); using Vec = typename Ctx::Vec; diff --git a/libOTe_Tests/Vole_Tests.cpp b/libOTe_Tests/Vole_Tests.cpp index 1016d54c..26dafae6 100644 --- a/libOTe_Tests/Vole_Tests.cpp +++ b/libOTe_Tests/Vole_Tests.cpp @@ -112,7 +112,7 @@ namespace // the sender gets b, d // the recver gets a, c auto c = recver.sampleBaseVoleVals(prng); - Ctx::template Vec a(c.size()), b(c.size()); + typename Ctx::template Vec a(c.size()), b(c.size()); prng.get(b.data(), b.size()); for (auto i : rng(c.size())) @@ -130,8 +130,8 @@ namespace template void Vole_Silent_test_impl(u64 n, MultType type, bool debug, bool doFakeBase, bool mal) { - using VecF = Ctx::Vec; - using VecG = Ctx::Vec; + using VecF = typename Ctx::template Vec; + using VecG = typename Ctx::template Vec; Ctx ctx; block seed = CCBlock; From f9ba771a49d5a334b01bc8d56e368b4a23e97264 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Mon, 22 Jan 2024 15:10:22 -0800 Subject: [PATCH 19/23] reworked examples and ENABLED fixes --- CMakePresets.json | 15 +- cmake/buildOptions.cmake | 19 +- cmake/buildOptions.cmake.in | 1 + frontend/ExampleBase.cpp | 154 ++++++++++ frontend/ExampleBase.h | 59 +--- frontend/ExampleNChooseOne.cpp | 245 ++++++++++++++++ frontend/ExampleNChooseOne.h | 226 +-------------- frontend/ExampleSilent.cpp | 189 +++++++++++++ frontend/ExampleSilent.h | 183 +----------- frontend/ExampleTwoChooseOne.cpp | 356 ++++++++++++++++++++++++ frontend/ExampleTwoChooseOne.h | 313 +-------------------- frontend/ExampleVole.cpp | 111 ++++++++ frontend/ExampleVole.h | 136 +-------- frontend/benchmark.h | 77 +---- frontend/main.cpp | 138 +-------- frontend/util.cpp | 3 + frontend/util.h | 27 +- libOTe/Tools/Pprf/RegularPprf.h | 7 +- libOTe/Vole/Silent/SilentVoleReceiver.h | 10 +- libOTe/Vole/Silent/SilentVoleSender.h | 4 + libOTe/config.h.in | 2 + libOTe_Tests/Pprf_Tests.cpp | 22 ++ libOTe_Tests/Vole_Tests.cpp | 31 ++- 23 files changed, 1188 insertions(+), 1140 deletions(-) create mode 100644 frontend/ExampleBase.cpp create mode 100644 frontend/ExampleNChooseOne.cpp create mode 100644 frontend/ExampleSilent.cpp create mode 100644 frontend/ExampleTwoChooseOne.cpp create mode 100644 frontend/ExampleVole.cpp diff --git a/CMakePresets.json b/CMakePresets.json index 20225f0a..5b7eb95d 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -13,6 +13,7 @@ "ENABLE_ALL_OT": true, "ENABLE_SSE": true, "ENABLE_AVX": true, + "ENABLE_BOOST": true, "ENABLE_BITPOLYMUL": false, "ENABLE_CIRCUITS": true, "LIBOTE_STD_VER": "17", @@ -43,17 +44,13 @@ "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", "ENABLE_INSECURE_SILVER": false, - "ENABLE_LDPC": false, + "ENABLE_PPRF": true, + "ENABLE_SILENT_VOLE": true, "LIBOTE_STD_VER": "17", - "ENABLE_ALL_OT": true, - "ENABLE_KKRT": "ON", - "ENABLE_IKNP": "ON", - "ENABLE_MR": "ON", - "ENABLE_SIMPLESTOT": "ON", "ENABLE_RELIC": false, "ENABLE_SODIUM": true, - "ENABLE_BOOST": false, - "ENABLE_BITPOLYMUL": true, + "ENABLE_BOOST": true, + "ENABLE_BITPOLYMUL": false, "FETCH_AUTO": "ON", "ENABLE_CIRCUITS": true, "VERBOSE_FETCH": true, @@ -88,7 +85,7 @@ "ENABLE_ALL_OT": true, "ENABLE_RELIC": true, "ENABLE_SODIUM": false, - "ENABLE_BOOST": false, + "ENABLE_BOOST": true, "ENABLE_OPENSSL": false, "FETCH_AUTO": true, "ENABLE_CIRCUITS": true, diff --git a/cmake/buildOptions.cmake b/cmake/buildOptions.cmake index 4d05ed90..cefdd3c4 100644 --- a/cmake/buildOptions.cmake +++ b/cmake/buildOptions.cmake @@ -94,9 +94,8 @@ option(ENABLE_DELTA_IKNP "Build the IKNP Delta-OT-Ext protocol." OFF) option(ENABLE_OOS "Build the OOS 1-oo-N OT-Ext protocol." OFF) option(ENABLE_KKRT "Build the KKRT 1-oo-N OT-Ext protocol." OFF) +option(ENABLE_PPRF "Build the PPRF protocol." OFF) option(ENABLE_SILENT_VOLE "Build the Silent Vole protocol." OFF) -#option(COPROTO_ENABLE_BOOST "Build with coproto boost support." OFF) -#option(COPROTO_ENABLE_OPENSSL "Build with coproto boost open ssl support." OFF) option(ENABLE_INSECURE_SILVER "Build with silver codes." OFF) option(ENABLE_LDPC "Build with ldpc functions." OFF) @@ -105,21 +104,14 @@ if(ENABLE_INSECURE_SILVER) endif() option(NO_KOS_WARNING "Build with no kos security warning." OFF) -#option(FETCH_BITPOLYMUL "download and build bitpolymul" OFF)) EVAL(FETCH_BITPOLYMUL_IMPL (DEFINED FETCH_BITPOLYMUL AND FETCH_BITPOLYMUL) OR ((NOT DEFINED FETCH_BITPOLYMUL) AND (FETCH_AUTO AND ENABLE_BITPOLYMUL))) - -#option(FETCH_BITPOLYMUL "download and build bitpolymul" OFF)) -EVAL(FETCH_BITPOLYMUL_IMPL - (DEFINED FETCH_BITPOLYMUL AND FETCH_BITPOLYMUL) OR - ((NOT DEFINED FETCH_BITPOLYMUL) AND (FETCH_AUTO AND ENABLE_BITPOLYMUL))) - - - - +if(ENABLE_SILENT_VOLE OR ENABLE_SILENTOT OR ENABLE_SOFTSPOKEN_OT) + set(ENABLE_PPRF true) +endif() option(VERBOSE_FETCH "Print build info for fetched libraries" ON) @@ -159,7 +151,8 @@ message(STATUS "Option: ENABLE_KKRT = ${ENABLE_KKRT}\n\n") message(STATUS "other \n=======================================================") -message(STATUS "Option: NO_KOS_WARNING = ${NO_KOS_WARNING}\n\n") +message(STATUS "Option: NO_KOS_WARNING = ${NO_KOS_WARNING}") +message(STATUS "Option: ENABLE_PPRF = ${ENABLE_PPRF}\n\n") ############################################# # Config Checks # diff --git a/cmake/buildOptions.cmake.in b/cmake/buildOptions.cmake.in index 1c8cf600..f3329773 100644 --- a/cmake/buildOptions.cmake.in +++ b/cmake/buildOptions.cmake.in @@ -70,6 +70,7 @@ set(ENABLE_KKRT @ENABLE_KKRT@) set(ENABLE_SILENT_VOLE @ENABLE_SILENT_VOLE@) set(NO_SILVER_WARNING @NO_SILVER_WARNING@) +set(ENABLE_PPRF @ENABLE_PPRF@) set(libOTe_boost_FOUND ${ENABLE_BOOST}) set(libOTe_relic_FOUND ${ENABLE_RELIC}) diff --git a/frontend/ExampleBase.cpp b/frontend/ExampleBase.cpp new file mode 100644 index 00000000..13b2726d --- /dev/null +++ b/frontend/ExampleBase.cpp @@ -0,0 +1,154 @@ + +#include "libOTe/Base/SimplestOT.h" +#include "libOTe/Base/McRosRoyTwist.h" +#include "libOTe/Base/McRosRoy.h" +#include "libOTe/Tools/Popf/EKEPopf.h" +#include "libOTe/Tools/Popf/MRPopf.h" +#include "libOTe/Tools/Popf/FeistelPopf.h" +#include "libOTe/Tools/Popf/FeistelMulPopf.h" +#include "libOTe/Tools/Popf/FeistelRistPopf.h" +#include "libOTe/Tools/Popf/FeistelMulRistPopf.h" +#include "libOTe/Base/MasnyRindal.h" +#include "libOTe/Base/MasnyRindalKyber.h" + +#include "cryptoTools/Common/BitVector.h" +#include "cryptoTools/Common/CLP.h" +#include "util.h" +#include "coproto/Socket/AsioSocket.h" +#include "cryptoTools/Common/Timer.h" + +namespace osuCrypto +{ + + template + void baseOT_example_from_ot(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP&, BaseOT ot) + { +#ifdef COPROTO_ENABLE_BOOST + PRNG prng(sysRandomSeed()); + + if (totalOTs == 0) + totalOTs = 128; + + if (numThreads > 1) + std::cout << "multi threading for the base OT example is not implemented.\n" << std::flush; + + Timer t; + Timer::timeUnit s; + if (role == Role::Receiver) + { + auto sock = cp::asioConnect(ip, false); + BaseOT recv = ot; + + AlignedVector msg(totalOTs); + BitVector choice(totalOTs); + choice.randomize(prng); + + + s = t.setTimePoint("base OT start"); + + coproto::sync_wait(recv.receive(choice, msg, prng, sock)); + + // make sure all messages are sent. + cp::sync_wait(sock.flush()); + } + else + { + auto sock = cp::asioConnect(ip, true); + + BaseOT send = ot; + + AlignedVector> msg(totalOTs); + + s = t.setTimePoint("base OT start"); + + coproto::sync_wait(send.send(msg, prng, sock)); + + + // make sure all messages are sent. + cp::sync_wait(sock.flush()); + } + + + + auto e = t.setTimePoint("base OT end"); + auto milli = std::chrono::duration_cast(e - s).count(); + + std::cout << tag << (role == Role::Receiver ? " (receiver)" : " (sender)") + << " n=" << totalOTs << " " << milli << " ms" << std::endl; +#endif + } + + template + void baseOT_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) + { + return baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, BaseOT()); + } + + bool baseOT_examples(const CLP& cmd) + { + bool flagSet = false; + +#ifdef ENABLE_SIMPLESTOT + flagSet |= runIf(baseOT_example, cmd, simple); +#endif + +#ifdef ENABLE_SIMPLESTOT_ASM + flagSet |= runIf(baseOT_example, cmd, simpleasm); +#endif + +#ifdef ENABLE_MRR_TWIST +#ifdef ENABLE_SSE + flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) { + DomainSepEKEPopf factory; + const char* domain = "EKE POPF OT example"; + factory.Update(domain, std::strlen(domain)); + baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwist(factory)); + }, cmd, moellerpopf, { "eke" }); +#endif + + flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) { + DomainSepMRPopf factory; + const char* domain = "MR POPF OT example"; + factory.Update(domain, std::strlen(domain)); + baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwistMR(factory)); + }, cmd, moellerpopf, { "mrPopf" }); + + flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) { + DomainSepFeistelPopf factory; + const char* domain = "Feistel POPF OT example"; + factory.Update(domain, std::strlen(domain)); + baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwistFeistel(factory)); + }, cmd, moellerpopf, { "feistel" }); + + flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) { + DomainSepFeistelMulPopf factory; + const char* domain = "Feistel With Multiplication POPF OT example"; + factory.Update(domain, std::strlen(domain)); + baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwistMul(factory)); + }, cmd, moellerpopf, { "feistelMul" }); +#endif + +#ifdef ENABLE_MRR + flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) { + DomainSepFeistelRistPopf factory; + const char* domain = "Feistel POPF OT example (Risretto)"; + factory.Update(domain, std::strlen(domain)); + baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoy(factory)); + }, cmd, ristrettopopf, { "feistel" }); + + flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) { + DomainSepFeistelMulRistPopf factory; + const char* domain = "Feistel With Multiplication POPF OT example (Risretto)"; + factory.Update(domain, std::strlen(domain)); + baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyMul(factory)); + }, cmd, ristrettopopf, { "feistelMul" }); +#endif + +#ifdef ENABLE_MR + flagSet |= runIf(baseOT_example, cmd, mr); +#endif + + return flagSet; + } + +} diff --git a/frontend/ExampleBase.h b/frontend/ExampleBase.h index a4684405..762d3b78 100644 --- a/frontend/ExampleBase.h +++ b/frontend/ExampleBase.h @@ -14,65 +14,8 @@ #include "cryptoTools/Common/BitVector.h" #include "cryptoTools/Common/CLP.h" -#include "util.h" -#include "coproto/Socket/AsioSocket.h" namespace osuCrypto { - - template - void baseOT_example_from_ot(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP&, BaseOT ot) - { -#ifdef COPROTO_ENABLE_BOOST - PRNG prng(sysRandomSeed()); - - if (totalOTs == 0) - totalOTs = 128; - - if (numThreads > 1) - std::cout << "multi threading for the base OT example is not implemented.\n" << std::flush; - - Timer t; - Timer::timeUnit s; - if (role == Role::Receiver) - { - auto sock = cp::asioConnect(ip, false); - BaseOT recv = ot; - - AlignedVector msg(totalOTs); - BitVector choice(totalOTs); - choice.randomize(prng); - - - s = t.setTimePoint("base OT start"); - - coproto::sync_wait(recv.receive(choice, msg, prng, sock)); - - } - else - { - auto sock = cp::asioConnect(ip, true); - - BaseOT send = ot; - - AlignedVector> msg(totalOTs); - - s = t.setTimePoint("base OT start"); - - coproto::sync_wait(send.send(msg, prng, sock)); - } - - auto e = t.setTimePoint("base OT end"); - auto milli = std::chrono::duration_cast(e - s).count(); - - std::cout << tag << (role == Role::Receiver ? " (receiver)" : " (sender)") - << " n=" << totalOTs << " " << milli << " ms" << std::endl; -#endif - } - - template - void baseOT_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) - { - return baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, BaseOT()); - } + bool baseOT_examples(const CLP& clp); } diff --git a/frontend/ExampleNChooseOne.cpp b/frontend/ExampleNChooseOne.cpp new file mode 100644 index 00000000..86bd1d5f --- /dev/null +++ b/frontend/ExampleNChooseOne.cpp @@ -0,0 +1,245 @@ + +#include "cryptoTools/Common/Matrix.h" +#include "libOTe/NChooseOne/Oos/OosNcoOtReceiver.h" +#include "libOTe/NChooseOne/Oos/OosNcoOtSender.h" +#include "libOTe/NChooseOne/Kkrt/KkrtNcoOtReceiver.h" +#include "libOTe/NChooseOne/Kkrt/KkrtNcoOtSender.h" +#include "cryptoTools/Common/Matrix.h" +#include "libOTe/Tools/Coproto.h" +#include "util.h" +#include "coproto/Socket/AsioSocket.h" + +namespace osuCrypto +{ + + + + auto chls = cp::LocalAsyncSocket::makePair(); + + template + void NChooseOne_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP&) + { +#ifdef COPROTO_ENABLE_BOOST + const u64 step = 1024; + + if (totalOTs == 0) + totalOTs = 1 << 20; + + bool randomOT = true; + u64 numOTs = (u64)totalOTs; + auto numChosenMsgs = 256; + + // get up the networking + auto chl = cp::asioConnect(ip, role == Role::Sender); + //auto chl = role == Role::Sender ? chls[0] : chls[1]; + + PRNG prng(ZeroBlock);// sysRandomSeed()); + + NcoOtSender sender; + NcoOtReceiver recver; + + // all Nco Ot extenders must have configure called first. This determines + // a variety of parameters such as how many base OTs are required. + bool maliciousSecure = false; + u64 statSecParam = 40; + u64 inputBitCount = 76; // the kkrt protocol default to 128 but oos can only do 76. + + // create a lambda function that performs the computation of a single receiver thread. + auto recvRoutine = [&]() -> task<> + { + MC_BEGIN(task<>, &, + i = u64{}, min = u64{}, + recvMsgs = std::vector{}, + choices = std::vector{} + ); + + recver.configure(maliciousSecure, statSecParam, inputBitCount); + //MC_AWAIT(sync(chl, Role::Receiver)); + + if (randomOT) + { + // once configure(...) and setBaseOts(...) are called, + // we can compute many batches of OTs. First we need to tell + // the instance how many OTs we want in this batch. This is done here. + MC_AWAIT(recver.init(numOTs, prng, chl)); + + // now we can iterate over the OTs and actually retrieve the desired + // messages. However, for efficiency we will do this in steps where + // we do some computation followed by sending off data. This is more + // efficient since data will be sent in the background :). + for (i = 0; i < numOTs; ) + { + // figure out how many OTs we want to do in this step. + min = std::min(numOTs - i, step); + + //// iterate over this step. + for (u64 j = 0; j < min; ++j, ++i) + { + // For the OT index by i, we need to pick which + // one of the N OT messages that we want. For this + // example we simply pick a random one. Note only the + // first log2(N) bits of choice is considered. + block choice = prng.get(); + + // this will hold the (random) OT message of our choice + block otMessage; + + // retrieve the desired message. + recver.encode(i, &choice, &otMessage); + + // do something cool with otMessage + //otMessage; + } + + // Note that all OTs in this region must be encode. If there are some + // that you don't actually care about, then you can skip them by calling + // + // recver.zeroEncode(i); + // + + // Now that we have gotten out the OT mMessages for this step, + // we are ready to send over network some information that + // allows the sender to also compute the OT mMessages. Since we just + // encoded "min" OT mMessages, we will tell the class to send the + // next min "correction" values. + MC_AWAIT(recver.sendCorrection(chl, min)); + } + + // once all numOTs have been encoded and had their correction values sent + // we must call check. This allows to sender to make sure we did not cheat. + // For semi-honest protocols, this can and will be skipped. + MC_AWAIT(recver.check(chl, prng.get())); + } + else + { + recvMsgs.resize(numOTs); + choices.resize(numOTs); + + // define which messages the receiver should learn. + for (i = 0; i < numOTs; ++i) + choices[i] = prng.get(); + + // the messages that were learned are written to recvMsgs. + MC_AWAIT(recver.receiveChosen(numChosenMsgs, recvMsgs, choices, prng, chl)); + } + + MC_AWAIT(chl.flush()); + MC_END(); + }; + + // create a lambda function that performs the computation of a single sender thread. + auto sendRoutine = [&]() + { + MC_BEGIN(task<>, &, + sendMessages = Matrix{}, + i = u64{}, min = u64{} + ); + + sender.configure(maliciousSecure, statSecParam, inputBitCount); + //MC_AWAIT(sync(chl, Role::Sender)); + + if (randomOT) + { + // Same explanation as above. + MC_AWAIT(sender.init(numOTs, prng, chl)); + + // Same explanation as above. + for (i = 0; i < numOTs; ) + { + // Same explanation as above. + min = std::min(numOTs - i, step); + + // unlike for the receiver, before we call encode to get + // some desired OT message, we must call recvCorrection(...). + // This receivers some information that the receiver had sent + // and allows the sender to compute any OT message that they desired. + // Note that the step size must match what the receiver used. + // If this is unknown you can use recvCorrection(chl) -> u64 + // which will tell you how many were sent. + MC_AWAIT(sender.recvCorrection(chl, min)); + + // we now encode any OT message with index less that i + min. + for (u64 j = 0; j < min; ++j, ++i) + { + // in particular, the sender can retrieve many OT messages + // at a single index, in this case we chose to retrieve 3 + // but that is arbitrary. + auto choice0 = prng.get(); + auto choice1 = prng.get(); + auto choice2 = prng.get(); + + // these we hold the actual OT messages. + block + otMessage0, + otMessage1, + otMessage2; + + // now retrieve the messages + sender.encode(i, &choice0, &otMessage0); + sender.encode(i, &choice1, &otMessage1); + sender.encode(i, &choice2, &otMessage2); + } + } + + // This call is required to make sure the receiver did not cheat. + // All corrections must be received before this is called. + MC_AWAIT(sender.check(chl, ZeroBlock)); + } + else + { + // populate this with the messages that you want to send. + sendMessages.resize(numOTs, numChosenMsgs); + prng.get(sendMessages.data(), sendMessages.size()); + + // perform the OTs with the given messages. + MC_AWAIT(sender.sendChosen(sendMessages, prng, chl)); + + } + + MC_AWAIT(chl.flush()); + MC_END(); + }; + + + Timer time; + auto s = time.setTimePoint("start"); + + + task<> proto; + if (role == Role::Sender) + proto = sendRoutine(); + else + proto = recvRoutine(); + try + { + cp::sync_wait(proto); + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + } + + auto e = time.setTimePoint("finish"); + auto milli = std::chrono::duration_cast(e - s).count(); + + if (role == Role::Sender) + std::cout << tag << " n=" << totalOTs << " " << milli << " ms " << std::endl; +#endif + } + + + bool NChooseOne_Examples(const CLP& cmd) + { + bool flagSet = false; +#ifdef ENABLE_KKRT + flagSet |= runIf(NChooseOne_example, cmd, kkrt); +#endif + +#ifdef ENABLE_OOS + flagSet |= runIf(NChooseOne_example, cmd, oos); +#endif + + return flagSet; + } + +} diff --git a/frontend/ExampleNChooseOne.h b/frontend/ExampleNChooseOne.h index 7d886cc4..571657ed 100644 --- a/frontend/ExampleNChooseOne.h +++ b/frontend/ExampleNChooseOne.h @@ -1,230 +1,8 @@ #pragma once - - -#include "cryptoTools/Common/Matrix.h" -#include "libOTe/NChooseOne/Oos/OosNcoOtReceiver.h" -#include "libOTe/NChooseOne/Oos/OosNcoOtSender.h" -#include "libOTe/NChooseOne/Kkrt/KkrtNcoOtReceiver.h" -#include "libOTe/NChooseOne/Kkrt/KkrtNcoOtSender.h" -#include "cryptoTools/Common/Matrix.h" -#include "libOTe/Tools/Coproto.h" +#include "cryptoTools/Common/CLP.h" namespace osuCrypto { - - - auto chls = cp::LocalAsyncSocket::makePair(); - - template - void NChooseOne_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP&) - { -#ifdef COPROTO_ENABLE_BOOST - const u64 step = 1024; - - if (totalOTs == 0) - totalOTs = 1 << 20; - - bool randomOT = true; - u64 numOTs = (u64)totalOTs; - auto numChosenMsgs = 256; - - // get up the networking - auto chl = cp::asioConnect(ip, role == Role::Sender); - //auto chl = role == Role::Sender ? chls[0] : chls[1]; - - PRNG prng(ZeroBlock);// sysRandomSeed()); - - NcoOtSender sender; - NcoOtReceiver recver; - - // all Nco Ot extenders must have configure called first. This determines - // a variety of parameters such as how many base OTs are required. - bool maliciousSecure = false; - u64 statSecParam = 40; - u64 inputBitCount = 76; // the kkrt protocol default to 128 but oos can only do 76. - - // create a lambda function that performs the computation of a single receiver thread. - auto recvRoutine = [&]() -> task<> - { - MC_BEGIN(task<>,&, - i = u64{}, min = u64{}, - recvMsgs = std::vector{}, - choices = std::vector{} - ); - - recver.configure(maliciousSecure, statSecParam, inputBitCount); - //MC_AWAIT(sync(chl, Role::Receiver)); - - if (randomOT) - { - // once configure(...) and setBaseOts(...) are called, - // we can compute many batches of OTs. First we need to tell - // the instance how many OTs we want in this batch. This is done here. - MC_AWAIT(recver.init(numOTs, prng, chl)); - - // now we can iterate over the OTs and actually retrieve the desired - // messages. However, for efficiency we will do this in steps where - // we do some computation followed by sending off data. This is more - // efficient since data will be sent in the background :). - for (i = 0; i < numOTs; ) - { - // figure out how many OTs we want to do in this step. - min = std::min(numOTs - i, step); - - //// iterate over this step. - for (u64 j = 0; j < min; ++j, ++i) - { - // For the OT index by i, we need to pick which - // one of the N OT messages that we want. For this - // example we simply pick a random one. Note only the - // first log2(N) bits of choice is considered. - block choice = prng.get(); - - // this will hold the (random) OT message of our choice - block otMessage; - - // retrieve the desired message. - recver.encode(i, &choice, &otMessage); - - // do something cool with otMessage - //otMessage; - } - - // Note that all OTs in this region must be encode. If there are some - // that you don't actually care about, then you can skip them by calling - // - // recver.zeroEncode(i); - // - - // Now that we have gotten out the OT mMessages for this step, - // we are ready to send over network some information that - // allows the sender to also compute the OT mMessages. Since we just - // encoded "min" OT mMessages, we will tell the class to send the - // next min "correction" values. - MC_AWAIT(recver.sendCorrection(chl, min)); - } - - // once all numOTs have been encoded and had their correction values sent - // we must call check. This allows to sender to make sure we did not cheat. - // For semi-honest protocols, this can and will be skipped. - MC_AWAIT(recver.check(chl, prng.get())); - } - else - { - recvMsgs.resize(numOTs); - choices.resize(numOTs); - - // define which messages the receiver should learn. - for (i = 0; i < numOTs; ++i) - choices[i] = prng.get(); - - // the messages that were learned are written to recvMsgs. - MC_AWAIT(recver.receiveChosen(numChosenMsgs, recvMsgs, choices, prng, chl)); - } - - MC_AWAIT(chl.flush()); - MC_END(); - }; - - // create a lambda function that performs the computation of a single sender thread. - auto sendRoutine = [&]() - { - MC_BEGIN(task<>,&, - sendMessages = Matrix{}, - i = u64{}, min = u64{} - ); - - sender.configure(maliciousSecure, statSecParam, inputBitCount); - //MC_AWAIT(sync(chl, Role::Sender)); - - if (randomOT) - { - // Same explanation as above. - MC_AWAIT(sender.init(numOTs, prng, chl)); - - // Same explanation as above. - for (i = 0; i < numOTs; ) - { - // Same explanation as above. - min = std::min(numOTs - i, step); - - // unlike for the receiver, before we call encode to get - // some desired OT message, we must call recvCorrection(...). - // This receivers some information that the receiver had sent - // and allows the sender to compute any OT message that they desired. - // Note that the step size must match what the receiver used. - // If this is unknown you can use recvCorrection(chl) -> u64 - // which will tell you how many were sent. - MC_AWAIT(sender.recvCorrection(chl, min)); - - // we now encode any OT message with index less that i + min. - for (u64 j = 0; j < min; ++j, ++i) - { - // in particular, the sender can retrieve many OT messages - // at a single index, in this case we chose to retrieve 3 - // but that is arbitrary. - auto choice0 = prng.get(); - auto choice1 = prng.get(); - auto choice2 = prng.get(); - - // these we hold the actual OT messages. - block - otMessage0, - otMessage1, - otMessage2; - - // now retrieve the messages - sender.encode(i, &choice0, &otMessage0); - sender.encode(i, &choice1, &otMessage1); - sender.encode(i, &choice2, &otMessage2); - } - } - - // This call is required to make sure the receiver did not cheat. - // All corrections must be received before this is called. - MC_AWAIT(sender.check(chl, ZeroBlock)); - } - else - { - // populate this with the messages that you want to send. - sendMessages.resize(numOTs, numChosenMsgs); - prng.get(sendMessages.data(), sendMessages.size()); - - // perform the OTs with the given messages. - MC_AWAIT(sender.sendChosen(sendMessages, prng, chl)); - - } - - MC_AWAIT(chl.flush()); - MC_END(); - }; - - - Timer time; - auto s = time.setTimePoint("start"); - - - task<> proto; - if (role == Role::Sender) - proto = sendRoutine(); - else - proto = recvRoutine(); - try - { - cp::sync_wait(proto); - } - catch (std::exception& e) - { - std::cout << e.what() << std::endl; - } - - auto e = time.setTimePoint("finish"); - auto milli = std::chrono::duration_cast(e - s).count(); - - if (role == Role::Sender) - std::cout << tag << " n=" << totalOTs << " " << milli << " ms " << std::endl; -#endif - } - + bool NChooseOne_Examples(const CLP& cmd); } diff --git a/frontend/ExampleSilent.cpp b/frontend/ExampleSilent.cpp new file mode 100644 index 00000000..c2b17d02 --- /dev/null +++ b/frontend/ExampleSilent.cpp @@ -0,0 +1,189 @@ + +#include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" +#include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" +#include "util.h" +#include + +#include "cryptoTools/Network/IOService.h" +#include "coproto/Socket/AsioSocket.h" + +namespace osuCrypto +{ + + void Silent_example(Role role, u64 numOTs, u64 numThreads, std::string ip, std::string tag, const CLP& cmd) + { +#if defined(ENABLE_SILENTOT) && defined(COPROTO_ENABLE_BOOST) + + if (numOTs == 0) + numOTs = 1 << 20; + + // get up the networking + auto chl = cp::asioConnect(ip, role == Role::Sender); + + + PRNG prng(sysRandomSeed()); + + bool fakeBase = cmd.isSet("fakeBase"); + u64 trials = cmd.getOr("trials", 1); + auto malicious = cmd.isSet("mal") ? SilentSecType::Malicious : SilentSecType::SemiHonest; + + auto multType = (MultType)cmd.getOr("multType", (int)DefaultMultType); + + std::vector types; + if (cmd.isSet("base")) + types.push_back(SilentBaseType::Base); + else + types.push_back(SilentBaseType::BaseExtend); + + macoro::thread_pool threadPool; + auto work = threadPool.make_work(); + if (numThreads > 1) + threadPool.create_threads(numThreads); + + for (auto type : types) + { + for (u64 tt = 0; tt < trials; ++tt) + { + Timer timer; + auto start = timer.setTimePoint("start"); + if (role == Role::Sender) + { + SilentOtExtSender sender; + + // optionally request the LPN encoding matrix. + sender.mMultType = multType; + + // optionally configure the sender. default is semi honest security. + sender.configure(numOTs, 2, numThreads, malicious); + + if (fakeBase) + { + auto nn = sender.baseOtCount(); + BitVector bits(nn); + bits.randomize(prng); + std::vector> baseSendMsgs(bits.size()); + std::vector baseRecvMsgs(bits.size()); + + auto commonPrng = PRNG(ZeroBlock); + commonPrng.get(baseSendMsgs.data(), baseSendMsgs.size()); + for (u64 i = 0; i < bits.size(); ++i) + baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; + + sender.setBaseOts(baseRecvMsgs, bits); + } + else + { + // optional. You can request that the base ot are generated either + // using just base OTs (few rounds, more computation) or 128 base OTs and then extend those. + // The default is the latter, base + extension. + cp::sync_wait(sender.genSilentBaseOts(prng, chl, type == SilentBaseType::BaseExtend)); + } + + std::vector> messages(numOTs); + + // create the protocol object. + auto protocol = sender.silentSend(messages, prng, chl); + + // run the protocol + if (numThreads <= 1) + cp::sync_wait(protocol); + else + // launch the protocol on the thread pool. + cp::sync_wait(std::move(protocol) | macoro::start_on(threadPool)); + + // messages has been populated with random OT messages. + // See the header for other options. + } + else + { + + SilentOtExtReceiver recver; + + // optionally request the LPN encoding matrix. + recver.mMultType = multType; + + // configure the sender. optional for semi honest security... + recver.configure(numOTs, 2, numThreads, malicious); + + if (fakeBase) + { + auto nn = recver.baseOtCount(); + BitVector bits(nn); + bits.randomize(prng); + std::vector> baseSendMsgs(bits.size()); + std::vector baseRecvMsgs(bits.size()); + + auto commonPrng = PRNG(ZeroBlock); + commonPrng.get(baseSendMsgs.data(), baseSendMsgs.size()); + for (u64 i = 0; i < bits.size(); ++i) + baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; + + recver.setBaseOts(baseSendMsgs); + } + else + { + // optional. You can request that the base ot are generated either + // using just base OTs (few rounds, more computation) or 128 base OTs and then extend those. + // The default is the latter, base + extension. + cp::sync_wait(recver.genSilentBaseOts(prng, chl, type == SilentBaseType::BaseExtend)); + } + + std::vector messages(numOTs); + BitVector choices(numOTs); + + // create the protocol object. + auto protocol = recver.silentReceive(choices, messages, prng, chl); + + // run the protocol + if (numThreads <= 1) + cp::sync_wait(protocol); + else + // launch the protocol on the thread pool. + cp::sync_wait(std::move(protocol) | macoro::start_on(threadPool)); + + // choices, messages has been populated with random OT messages. + // messages[i] = sender.message[i][choices[i]] + // See the header for other options. + } + auto end = timer.setTimePoint("end"); + auto milli = std::chrono::duration_cast(end - start).count(); + + u64 com = chl.bytesReceived() + chl.bytesSent(); + + if (role == Role::Sender) + { + std::string typeStr = type == SilentBaseType::Base ? "b " : "be "; + lout << tag << + " n:" << Color::Green << std::setw(6) << std::setfill(' ') << numOTs << Color::Default << + " type: " << Color::Green << typeStr << Color::Default << + " || " << Color::Green << + std::setw(6) << std::setfill(' ') << milli << " ms " << + std::setw(6) << std::setfill(' ') << com << " bytes" << std::endl << Color::Default; + + if (cmd.getOr("v", 0) > 1) + lout << gTimer << std::endl; + } + + if (cmd.isSet("v")) + { + if (role == Role::Sender) + lout << " **** sender ****\n" << timer << std::endl; + + if (role == Role::Receiver) + lout << " **** receiver ****\n" << timer << std::endl; + } + } + + } + + cp::sync_wait(chl.flush()); + +#endif + } + bool Silent_Examples(const CLP& cmd) + { + return runIf(Silent_example, cmd, Silent); + } + + +} diff --git a/frontend/ExampleSilent.h b/frontend/ExampleSilent.h index e792ba73..abc4122f 100644 --- a/frontend/ExampleSilent.h +++ b/frontend/ExampleSilent.h @@ -1,186 +1,7 @@ #pragma once - -#include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" -#include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" -#include "util.h" -#include - -#include "cryptoTools/Network/IOService.h" -#include "coproto/Socket/AsioSocket.h" +#include "cryptoTools/Common/CLP.h" namespace osuCrypto { - - void Silent_example(Role role, u64 numOTs, u64 numThreads, std::string ip, std::string tag, CLP& cmd) - { -#if defined(ENABLE_SILENTOT) && defined(COPROTO_ENABLE_BOOST) - - if (numOTs == 0) - numOTs = 1 << 20; - - // get up the networking - auto chl = cp::asioConnect(ip, role == Role::Sender); - - - PRNG prng(sysRandomSeed()); - - bool fakeBase = cmd.isSet("fakeBase"); - u64 trials = cmd.getOr("trials", 1); - auto malicious = cmd.isSet("mal") ? SilentSecType::Malicious : SilentSecType::SemiHonest; - - auto multType = (MultType)cmd.getOr("multType", (int)DefaultMultType); - - std::vector types; - if (cmd.isSet("base")) - types.push_back(SilentBaseType::Base); - else - types.push_back(SilentBaseType::BaseExtend); - - macoro::thread_pool threadPool; - auto work = threadPool.make_work(); - if (numThreads > 1) - threadPool.create_threads(numThreads); - - for (auto type : types) - { - for (u64 tt = 0; tt < trials; ++tt) - { - Timer timer; - auto start = timer.setTimePoint("start"); - if (role == Role::Sender) - { - SilentOtExtSender sender; - - // optionally request the LPN encoding matrix. - sender.mMultType = multType; - - // optionally configure the sender. default is semi honest security. - sender.configure(numOTs, 2, numThreads, malicious); - - if (fakeBase) - { - auto nn = sender.baseOtCount(); - BitVector bits(nn); - bits.randomize(prng); - std::vector> baseSendMsgs(bits.size()); - std::vector baseRecvMsgs(bits.size()); - - auto commonPrng = PRNG(ZeroBlock); - commonPrng.get(baseSendMsgs.data(), baseSendMsgs.size()); - for (u64 i = 0; i < bits.size(); ++i) - baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; - - sender.setBaseOts(baseRecvMsgs, bits); - } - else - { - // optional. You can request that the base ot are generated either - // using just base OTs (few rounds, more computation) or 128 base OTs and then extend those. - // The default is the latter, base + extension. - cp::sync_wait(sender.genSilentBaseOts(prng, chl, type == SilentBaseType::BaseExtend)); - } - - std::vector> messages(numOTs); - - // create the protocol object. - auto protocol = sender.silentSend(messages, prng, chl); - - // run the protocol - if (numThreads <= 1) - cp::sync_wait(protocol); - else - // launch the protocol on the thread pool. - cp::sync_wait(std::move(protocol) | macoro::start_on(threadPool)); - - // messages has been populated with random OT messages. - // See the header for other options. - } - else - { - - SilentOtExtReceiver recver; - - // optionally request the LPN encoding matrix. - recver.mMultType = multType; - - // configure the sender. optional for semi honest security... - recver.configure(numOTs, 2, numThreads, malicious); - - if (fakeBase) - { - auto nn = recver.baseOtCount(); - BitVector bits(nn); - bits.randomize(prng); - std::vector> baseSendMsgs(bits.size()); - std::vector baseRecvMsgs(bits.size()); - - auto commonPrng = PRNG(ZeroBlock); - commonPrng.get(baseSendMsgs.data(), baseSendMsgs.size()); - for (u64 i = 0; i < bits.size(); ++i) - baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; - - recver.setBaseOts(baseSendMsgs); - } - else - { - // optional. You can request that the base ot are generated either - // using just base OTs (few rounds, more computation) or 128 base OTs and then extend those. - // The default is the latter, base + extension. - cp::sync_wait(recver.genSilentBaseOts(prng, chl, type == SilentBaseType::BaseExtend)); - } - - std::vector messages(numOTs); - BitVector choices(numOTs); - - // create the protocol object. - auto protocol = recver.silentReceive(choices, messages, prng, chl); - - // run the protocol - if (numThreads <= 1) - cp::sync_wait(protocol); - else - // launch the protocol on the thread pool. - cp::sync_wait(std::move(protocol) | macoro::start_on(threadPool)); - - // choices, messages has been populated with random OT messages. - // messages[i] = sender.message[i][choices[i]] - // See the header for other options. - } - auto end = timer.setTimePoint("end"); - auto milli = std::chrono::duration_cast(end - start).count(); - - u64 com = chl.bytesReceived() + chl.bytesSent(); - - if (role == Role::Sender) - { - std::string typeStr = type == SilentBaseType::Base ? "b " : "be "; - lout << tag << - " n:" << Color::Green << std::setw(6) << std::setfill(' ') << numOTs << Color::Default << - " type: " << Color::Green << typeStr << Color::Default << - " || " << Color::Green << - std::setw(6) << std::setfill(' ') << milli << " ms " << - std::setw(6) << std::setfill(' ') << com << " bytes" << std::endl << Color::Default; - - if (cmd.getOr("v", 0) > 1) - lout << gTimer << std::endl; - } - - if (cmd.isSet("v")) - { - if (role == Role::Sender) - lout << " **** sender ****\n" << timer << std::endl; - - if (role == Role::Receiver) - lout << " **** receiver ****\n" << timer << std::endl; - } - } - - } - - cp::sync_wait(chl.flush()); - -#endif - } - - + bool Silent_Examples(const CLP& cmd); } diff --git a/frontend/ExampleTwoChooseOne.cpp b/frontend/ExampleTwoChooseOne.cpp new file mode 100644 index 00000000..4dc5527b --- /dev/null +++ b/frontend/ExampleTwoChooseOne.cpp @@ -0,0 +1,356 @@ + +#include "libOTe/Base/BaseOT.h" +#include "libOTe/TwoChooseOne/Kos/KosOtExtReceiver.h" +#include "libOTe/TwoChooseOne/Kos/KosOtExtSender.h" +#include "libOTe/TwoChooseOne/KosDot/KosDotExtReceiver.h" +#include "libOTe/TwoChooseOne/KosDot/KosDotExtSender.h" +#include "libOTe/TwoChooseOne/Iknp/IknpOtExtReceiver.h" +#include "libOTe/TwoChooseOne/Iknp/IknpOtExtSender.h" + + +#include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" +#include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" +#include "util.h" +#include "coproto/Socket/AsioSocket.h" + +namespace osuCrypto +{ +#ifdef ENABLE_IKNP + void noHash(IknpOtExtSender& s, IknpOtExtReceiver& r) + { + s.mHash = false; + r.mHash = false; + } +#endif + + template + void noHash(Sender&, Receiver&) + { + throw std::runtime_error("This protocol does not support noHash"); + } + +#ifdef ENABLE_SOFTSPOKEN_OT + // soft spoken takes an extra parameter as input what determines + // the computation/communication trade-off. + template + using is_SoftSpoken = typename std::conditional< + //std::is_same::value || + //std::is_same::value || + //std::is_same::value || + //std::is_same::value || + //std::is_same::value || + //std::is_same::value || + //std::is_same::value || + //std::is_same::value + false + , + std::true_type, std::false_type>::type; +#else + template + using is_SoftSpoken = std::false_type; +#endif + + template + typename std::enable_if::value, T>::type + construct(CLP& cmd) + { + return T(cmd.getOr("f", 2)); + } + + template + typename std::enable_if::value, T>::type + construct(CLP& cmd) + { + return T{}; + } + + template + void TwoChooseOne_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& cmd) + { + +#ifdef COPROTO_ENABLE_BOOST + if (totalOTs == 0) + totalOTs = 1 << 20; + + bool randomOT = true; + + // get up the networking + auto chl = cp::asioConnect(ip, role == Role::Sender); + + PRNG prng(sysRandomSeed()); + + + OtExtSender sender; + OtExtRecver receiver; + + +#ifdef LIBOTE_HAS_BASE_OT + // Now compute the base OTs, we need to set them on the first pair of extenders. + // In real code you would only have a sender or reciever, not both. But we do + // here just showing the example. + if (role == Role::Receiver) + { + DefaultBaseOT base; + std::array, 128> baseMsg; + + // perform the base To, call sync_wait to block until they have completed. + cp::sync_wait(base.send(baseMsg, prng, chl)); + receiver.setBaseOts(baseMsg); + } + else + { + + DefaultBaseOT base; + BitVector bv(128); + std::array baseMsg; + bv.randomize(prng); + + // perform the base To, call sync_wait to block until they have completed. + cp::sync_wait(base.receive(bv, baseMsg, prng, chl)); + sender.setBaseOts(baseMsg, bv); + } + +#else + if (!cmd.isSet("fakeBase")) + std::cout << "warning, base ots are not enabled. Fake base OTs will be used. " << std::endl; + PRNG commonPRNG(oc::ZeroBlock); + std::array, 128> sendMsgs; + commonPRNG.get(sendMsgs.data(), sendMsgs.size()); + if (role == Role::Receiver) + { + receiver.setBaseOts(sendMsgs); + } + else + { + BitVector bv(128); + bv.randomize(commonPRNG); + std::array recvMsgs; + for (u64 i = 0; i < 128; ++i) + recvMsgs[i] = sendMsgs[i][bv[i]]; + sender.setBaseOts(recvMsgs, bv); + } +#endif + + if (cmd.isSet("noHash")) + noHash(sender, receiver); + + Timer timer, sendTimer, recvTimer; + sendTimer.setTimePoint("start"); + recvTimer.setTimePoint("start"); + auto s = timer.setTimePoint("start"); + + if (numThreads == 1) + { + if (role == Role::Receiver) + { + // construct the choices that we want. + BitVector choice(totalOTs); + // in this case pick random messages. + choice.randomize(prng); + + // construct a vector to stored the received messages. + AlignedUnVector rMsgs(totalOTs); + + try { + + if (randomOT) + { + // perform totalOTs random OTs, the results will be written to msgs. + cp::sync_wait(receiver.receive(choice, rMsgs, prng, chl)); + } + else + { + // perform totalOTs chosen message OTs, the results will be written to msgs. + cp::sync_wait(receiver.receiveChosen(choice, rMsgs, prng, chl)); + } + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + chl.close(); + } + } + else + { + // construct a vector to stored the random send messages. + AlignedUnVector> sMsgs(totalOTs); + + + // if delta OT is used, then the user can call the following + // to set the desired XOR difference between the zero messages + // and the one messages. + // + // senders[i].setDelta(some 128 bit delta); + // + try + { + if (randomOT) + { + // perform the OTs and write the random OTs to msgs. + cp::sync_wait(sender.send(sMsgs, prng, chl)); + } + else + { + // Populate msgs with something useful... + prng.get(sMsgs.data(), sMsgs.size()); + + // perform the OTs. The receiver will learn one + // of the messages stored in msgs. + cp::sync_wait(sender.sendChosen(sMsgs, prng, chl)); + } + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + chl.close(); + } + } + + // make sure all messages have been sent. + cp::sync_wait(chl.flush()); + } + else + { + + // for multi threading, we only show example for random OTs. + // We first need to construct the inputs + // that each thread will use. Note that the actual protocol + // is not thread safe so everything needs to be independent. + std::vector> tasks(numThreads); + std::vector threadPrngs(numThreads); + std::vector threadChls(numThreads); + + macoro::thread_pool::work work; + macoro::thread_pool threadPool(numThreads, work); + + if (role == Role::Receiver) + { + std::vector receivers(numThreads); + std::vector threadChoices(numThreads); + std::vector> threadMsgs(numThreads); + + for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) + { + u64 beginIndex = oc::roundUpTo(totalOTs * threadIndex / numThreads, 128); + u64 endIndex = oc::roundUpTo((totalOTs + 1) * threadIndex / numThreads, 128); + + threadChoices[threadIndex].resize(endIndex - beginIndex); + threadChoices[threadIndex].randomize(prng); + + threadMsgs[threadIndex].resize(endIndex - beginIndex); + + // create a copy of the receiver so that each can run + // independently. A single receiver is not thread safe. + receivers[threadIndex] = receiver.splitBase(); + + // create a PRNG for this thread. + threadPrngs[threadIndex].SetSeed(prng.get()); + + // create a socket for this thread. This is done by calling fork(). + threadChls[threadIndex] = chl.fork(); + + // start the receive protocol on the thread pool + tasks[threadIndex] = + receivers[threadIndex].receive( + threadChoices[threadIndex], + threadMsgs[threadIndex], + threadPrngs[threadIndex], + threadChls[threadIndex]) + | macoro::start_on(threadPool); + } + + // block this thread until the receive operations + // have completed. + for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) + cp::sync_wait(tasks[threadIndex]); + } + else + { + std::vector senders(numThreads); + std::vector>> threadMsgs(numThreads); + + for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) + { + u64 beginIndex = oc::roundUpTo(totalOTs * threadIndex / numThreads, 128); + u64 endIndex = oc::roundUpTo((totalOTs + 1) * threadIndex / numThreads, 128); + + threadMsgs[threadIndex].resize(endIndex - beginIndex); + + // create a copy of the receiver so that each can run + // independently. A single receiver is not thread safe. + senders[threadIndex] = sender.splitBase(); + + // create a PRNG for this thread. + threadPrngs[threadIndex].SetSeed(prng.get()); + + // create a socket for this thread. This is done by calling fork(). + threadChls[threadIndex] = chl.fork(); + + // start the send protocol on the thread pool + tasks[threadIndex] = + senders[threadIndex].send( + threadMsgs[threadIndex], + threadPrngs[threadIndex], + threadChls[threadIndex]) + | macoro::start_on(threadPool); + } + + // block this thread until the receive operations + // have completed. + for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) + cp::sync_wait(tasks[threadIndex]); + } + + work.reset(); + } + + + + auto e = timer.setTimePoint("finish"); + auto milli = std::chrono::duration_cast(e - s).count(); + + auto com = 0;// (chls[0].getTotalDataRecv() + chls[0].getTotalDataSent())* numThreads; + + if (role == Role::Sender) + lout << tag << " n=" << Color::Green << totalOTs << " " << milli << " ms " << com << " bytes" << std::endl << Color::Default; + + + if (cmd.isSet("v") && role == Role::Sender) + { + if (role == Role::Sender) + lout << " **** sender ****\n" << sendTimer << std::endl; + + if (role == Role::Receiver) + lout << " **** receiver ****\n" << recvTimer << std::endl; + } + +#else + throw std::runtime_error("This example requires coproto to enable boost support. Please build libOTe with `-DCOPROTO_ENABLE_BOOST=ON`. " LOCATION); +#endif + } + + + + + bool TwoChooseOne_Examples(const CLP& cmd) + { + bool flagSet = false; +#ifdef ENABLE_IKNP + flagSet |= runIf(TwoChooseOne_example, cmd, iknp); +#endif + +#ifdef ENABLE_KOS + flagSet |= runIf(TwoChooseOne_example, cmd, kos); +#endif + +#ifdef ENABLE_DELTA_KOS + flagSet |= runIf(TwoChooseOne_example, cmd, dkos); +#endif + +#ifdef ENABLE_SOFTSPOKEN_OT + flagSet |= runIf(TwoChooseOne_example, SoftSpokenShOtReceiver<>>, cmd, sshonest); + flagSet |= runIf(TwoChooseOne_example, cmd, smalicious); +#endif + return flagSet; + } + } diff --git a/frontend/ExampleTwoChooseOne.h b/frontend/ExampleTwoChooseOne.h index d532c8f3..c90ad94d 100644 --- a/frontend/ExampleTwoChooseOne.h +++ b/frontend/ExampleTwoChooseOne.h @@ -1,317 +1,8 @@ #pragma once - -#include "libOTe/Base/BaseOT.h" -#include "libOTe/TwoChooseOne/Kos/KosOtExtReceiver.h" -#include "libOTe/TwoChooseOne/Kos/KosOtExtSender.h" -#include "libOTe/TwoChooseOne/KosDot/KosDotExtReceiver.h" -#include "libOTe/TwoChooseOne/KosDot/KosDotExtSender.h" -#include "libOTe/TwoChooseOne/Iknp/IknpOtExtReceiver.h" -#include "libOTe/TwoChooseOne/Iknp/IknpOtExtSender.h" - - -#include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" -#include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" - -//#include "libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalLeakyDotExt.h" -//#include "libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.h" -//#include "libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenShOtExt.h" -//#include "libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenShDotExt.h" +#include "cryptoTools/Common/CLP.h" namespace osuCrypto { -#ifdef ENABLE_IKNP - void noHash(IknpOtExtSender& s, IknpOtExtReceiver& r) - { - s.mHash = false; - r.mHash = false; - } -#endif - - template - void noHash(Sender&, Receiver&) - { - throw std::runtime_error("This protocol does not support noHash"); - } - -#ifdef ENABLE_SOFTSPOKEN_OT - // soft spoken takes an extra parameter as input what determines - // the computation/communication trade-off. - template - using is_SoftSpoken = typename std::conditional< - //std::is_same::value || - //std::is_same::value || - //std::is_same::value || - //std::is_same::value || - //std::is_same::value || - //std::is_same::value || - //std::is_same::value || - //std::is_same::value - false - , - std::true_type, std::false_type>::type; -#else - template - using is_SoftSpoken = std::false_type; -#endif - - template - typename std::enable_if::value,T>::type - construct(CLP& cmd) - { - return T( cmd.getOr("f", 2) ); - } - - template - typename std::enable_if::value, T>::type - construct(CLP& cmd) - { - return T{}; - } - - template - void TwoChooseOne_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& cmd) - { - -#ifdef COPROTO_ENABLE_BOOST - if (totalOTs == 0) - totalOTs = 1 << 20; - - bool randomOT = true; - - // get up the networking - auto chl = cp::asioConnect(ip, role == Role::Sender); - - PRNG prng(sysRandomSeed()); - - - OtExtSender sender; - OtExtRecver receiver; - - -#ifdef LIBOTE_HAS_BASE_OT - // Now compute the base OTs, we need to set them on the first pair of extenders. - // In real code you would only have a sender or reciever, not both. But we do - // here just showing the example. - if (role == Role::Receiver) - { - DefaultBaseOT base; - std::array, 128> baseMsg; - - // perform the base To, call sync_wait to block until they have completed. - cp::sync_wait(base.send(baseMsg, prng, chl)); - receiver.setBaseOts(baseMsg); - } - else - { - - DefaultBaseOT base; - BitVector bv(128); - std::array baseMsg; - bv.randomize(prng); - - // perform the base To, call sync_wait to block until they have completed. - cp::sync_wait(base.receive(bv, baseMsg, prng, chl)); - sender.setBaseOts(baseMsg, bv); - } - -#else - if (!cmd.isSet("fakeBase")) - std::cout << "warning, base ots are not enabled. Fake base OTs will be used. " << std::endl; - PRNG commonPRNG(oc::ZeroBlock); - std::array, 128> sendMsgs; - commonPRNG.get(sendMsgs.data(), sendMsgs.size()); - if (role == Role::Receiver) - { - receiver.setBaseOts(sendMsgs); - } - else - { - BitVector bv(128); - bv.randomize(commonPRNG); - std::array recvMsgs; - for (u64 i = 0; i < 128; ++i) - recvMsgs[i] = sendMsgs[i][bv[i]]; - sender.setBaseOts(recvMsgs, bv); - } -#endif - - if (cmd.isSet("noHash")) - noHash(sender, receiver); - - Timer timer, sendTimer, recvTimer; - sendTimer.setTimePoint("start"); - recvTimer.setTimePoint("start"); - auto s = timer.setTimePoint("start"); - - if (numThreads == 1) - { - if (role == Role::Receiver) - { - // construct the choices that we want. - BitVector choice(totalOTs); - // in this case pick random messages. - choice.randomize(prng); - - // construct a vector to stored the received messages. - std::vector rMsgs(totalOTs); - - if (randomOT) - { - // perform totalOTs random OTs, the results will be written to msgs. - cp::sync_wait(receiver.receive(choice, rMsgs, prng, chl)); - } - else - { - // perform totalOTs chosen message OTs, the results will be written to msgs. - cp::sync_wait(receiver.receiveChosen(choice, rMsgs, prng, chl)); - } - } - else - { - // construct a vector to stored the random send messages. - std::vector> sMsgs(totalOTs); - - - // if delta OT is used, then the user can call the following - // to set the desired XOR difference between the zero messages - // and the one messages. - // - // senders[i].setDelta(some 128 bit delta); - // - - if (randomOT) - { - // perform the OTs and write the random OTs to msgs. - cp::sync_wait(sender.send(sMsgs, prng, chl)); - } - else - { - // Populate msgs with something useful... - prng.get(sMsgs.data(), sMsgs.size()); - - // perform the OTs. The receiver will learn one - // of the messages stored in msgs. - cp::sync_wait(sender.sendChosen(sMsgs, prng, chl)); - } - } - - } - else - { - - // for multi threading, we only show example for random OTs. - // We first need to construct the inputs - // that each thread will use. Note that the actual protocol - // is not thread safe so everything needs to be independent. - std::vector> tasks(numThreads); - std::vector threadPrngs(numThreads); - std::vector threadChls(numThreads); - - macoro::thread_pool::work work; - macoro::thread_pool threadPool(numThreads, work); - - if (role == Role::Receiver) - { - std::vector receivers(numThreads); - std::vector threadChoices(numThreads); - std::vector> threadMsgs(numThreads); - - for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) - { - u64 beginIndex = oc::roundUpTo(totalOTs * threadIndex / numThreads, 128); - u64 endIndex = oc::roundUpTo((totalOTs + 1) * threadIndex / numThreads, 128); - - threadChoices[threadIndex].resize(endIndex - beginIndex); - threadChoices[threadIndex].randomize(prng); - - threadMsgs[threadIndex].resize(endIndex - beginIndex); - - // create a copy of the receiver so that each can run - // independently. A single receiver is not thread safe. - receivers[threadIndex] = receiver.splitBase(); - - // create a PRNG for this thread. - threadPrngs[threadIndex].SetSeed(prng.get()); - - // create a socket for this thread. This is done by calling fork(). - threadChls[threadIndex] = chl.fork(); - - // start the receive protocol on the thread pool - tasks[threadIndex] = - receivers[threadIndex].receive( - threadChoices[threadIndex], - threadMsgs[threadIndex], - threadPrngs[threadIndex], - threadChls[threadIndex]) - | macoro::start_on(threadPool); - } - - // block this thread until the receive operations - // have completed. - for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) - cp::sync_wait(tasks[threadIndex]); - } - else - { - std::vector senders(numThreads); - std::vector>> threadMsgs(numThreads); - - for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) - { - u64 beginIndex = oc::roundUpTo(totalOTs * threadIndex / numThreads, 128); - u64 endIndex = oc::roundUpTo((totalOTs + 1) * threadIndex / numThreads, 128); - - threadMsgs[threadIndex].resize(endIndex - beginIndex); - - // create a copy of the receiver so that each can run - // independently. A single receiver is not thread safe. - senders[threadIndex] = sender.splitBase(); - - // create a PRNG for this thread. - threadPrngs[threadIndex].SetSeed(prng.get()); - - // create a socket for this thread. This is done by calling fork(). - threadChls[threadIndex] = chl.fork(); - - // start the send protocol on the thread pool - tasks[threadIndex] = - senders[threadIndex].send( - threadMsgs[threadIndex], - threadPrngs[threadIndex], - threadChls[threadIndex]) - | macoro::start_on(threadPool); - } - - // block this thread until the receive operations - // have completed. - for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) - cp::sync_wait(tasks[threadIndex]); - } - - work.reset(); - } - - auto e = timer.setTimePoint("finish"); - auto milli = std::chrono::duration_cast(e - s).count(); - - auto com = 0;// (chls[0].getTotalDataRecv() + chls[0].getTotalDataSent())* numThreads; - - if (role == Role::Sender) - lout << tag << " n=" << Color::Green << totalOTs << " " << milli << " ms " << com << " bytes" << std::endl << Color::Default; - - - if (cmd.isSet("v") && role == Role::Sender) - { - if (role == Role::Sender) - lout << " **** sender ****\n" << sendTimer << std::endl; - - if (role == Role::Receiver) - lout << " **** receiver ****\n" << recvTimer << std::endl; - } - -#else - throw std::runtime_error("This example requires coproto to enable boost support. Please build libOTe with `-DCOPROTO_ENABLE_BOOST=ON`. " LOCATION); -#endif - } + bool TwoChooseOne_Examples(const CLP& cmd); } diff --git a/frontend/ExampleVole.cpp b/frontend/ExampleVole.cpp new file mode 100644 index 00000000..2ef71894 --- /dev/null +++ b/frontend/ExampleVole.cpp @@ -0,0 +1,111 @@ + +#include "libOTe/Vole/Silent/SilentVoleReceiver.h" +#include "libOTe/Vole/Silent/SilentVoleSender.h" +#include "util.h" +#include "coproto/Socket/AsioSocket.h" + +namespace osuCrypto +{ + + + template + void Vole_example(Role role, int numVole, int numThreads, std::string ip, std::string tag, const CLP& cmd) + { +#if defined(ENABLE_SILENT_VOLE) && defined(COPROTO_ENABLE_BOOST) + + if (numVole == 0) + numVole = 1 << 20; + + // get up the networking + auto chl = cp::asioConnect(ip, role == Role::Sender); + + // get a random number generator seeded from the system + PRNG prng(sysRandomSeed()); + + auto mulType = (MultType)cmd.getOr("multType", (int)DefaultMultType); + + u64 milli; + Timer timer; + + gTimer.setTimePoint("begin"); + if (role == Role::Receiver) + { + // construct a vector to stored the received messages. + // A = B + C * delta + AlignedUnVector C(numVole); + AlignedUnVector A(numVole); + gTimer.setTimePoint("recver.msg.alloc"); + + SilentVoleReceiver receiver; + receiver.mMultType = mulType; + receiver.configure(numVole); + gTimer.setTimePoint("recver.config"); + + // block until both parties are ready (optional). + cp::sync_wait(sync(chl, role)); + auto b = timer.setTimePoint("start"); + receiver.setTimePoint("start"); + gTimer.setTimePoint("recver.genBase"); + + // perform numVole random OTs, the results will be written to msgs. + cp::sync_wait(receiver.silentReceive(C, A, prng, chl)); + + // record the time. + receiver.setTimePoint("finish"); + auto e = timer.setTimePoint("finish"); + milli = std::chrono::duration_cast(e - b).count(); + } + else + { + gTimer.setTimePoint("sender.thrd.begin"); + + // A = B + C * delta + AlignedUnVector B(numVole); + block delta = prng.get(); + + gTimer.setTimePoint("sender.msg.alloc"); + + SilentVoleSender sender; + sender.mMultType = mulType; + sender.configure(numVole); + gTimer.setTimePoint("sender.config"); + timer.setTimePoint("start"); + + // block until both parties are ready (optional). + cp::sync_wait(sync(chl, role)); + auto b = sender.setTimePoint("start"); + gTimer.setTimePoint("sender.genBase"); + + // perform the OTs and write the random OTs to msgs. + cp::sync_wait(sender.silentSend(delta, B, prng, chl)); + + sender.setTimePoint("finish"); + auto e = timer.setTimePoint("finish"); + milli = std::chrono::duration_cast(e - b).count(); + } + if (role == Role::Sender) + { + + lout << tag << + " n:" << Color::Green << std::setw(6) << std::setfill(' ') << numVole << Color::Default << + " || " << Color::Green << + std::setw(6) << std::setfill(' ') << milli << " ms " << + //std::setw(6) << std::setfill(' ') << com << " bytes" << + std::endl << Color::Default; + + if (cmd.getOr("v", 0) > 1) + lout << gTimer << std::endl; + + } + + // make sure all messages are sent. + cp::sync_wait(chl.flush()); +#endif + } + bool Vole_Examples(const CLP& cmd) + { + return + runIf(Vole_example, cmd, vole); + } + +} diff --git a/frontend/ExampleVole.h b/frontend/ExampleVole.h index c8891879..da6e58a5 100644 --- a/frontend/ExampleVole.h +++ b/frontend/ExampleVole.h @@ -1,142 +1,10 @@ #pragma once - -#include "libOTe/Vole/Silent/SilentVoleReceiver.h" -#include "libOTe/Vole/Silent/SilentVoleSender.h" +#include "cryptoTools/Common/CLP.h" namespace osuCrypto { - - //template - void Vole_example(Role role, int numOTs, int numThreads, std::string ip, std::string tag, CLP& cmd) - { -#if defined(ENABLE_SILENT_VOLE) && defined(COPROTO_ENABLE_BOOST) - - if (numOTs == 0) - numOTs = 1 << 20; - using OtExtSender = SilentVoleSender; - using OtExtRecver = SilentVoleReceiver; - - // get up the networking - auto chl = cp::asioConnect(ip, role == Role::Sender); - - // get a random number generator seeded from the system - PRNG prng(sysRandomSeed()); - - auto mulType = (MultType)cmd.getOr("multType", (int)DefaultMultType); - bool fakeBase = cmd.isSet("fakeBase"); - - u64 milli; - Timer timer; - - gTimer.setTimePoint("begin"); - if (role == Role::Receiver) - { - // construct a vector to stored the received messages. - std::unique_ptr backing0(new block[numOTs]); - std::unique_ptr backing1(new block[numOTs]); - span choice(backing0.get(), numOTs); - span msgs(backing1.get(), numOTs); - gTimer.setTimePoint("recver.msg.alloc"); - - OtExtRecver receiver; - receiver.mMultType = mulType; - receiver.configure(numOTs); - gTimer.setTimePoint("recver.config"); - - // generate base OTs - if (fakeBase) - { - auto nn = receiver.baseOtCount(); - std::vector> baseSendMsgs(nn); - PRNG pp(oc::ZeroBlock); - pp.get(baseSendMsgs.data(), baseSendMsgs.size()); - receiver.setBaseOts(baseSendMsgs); - } - else - { - cp::sync_wait(receiver.genSilentBaseOts(prng, chl)); - } - - // block until both parties are ready (optional). - cp::sync_wait(sync(chl, role)); - auto b = timer.setTimePoint("start"); - receiver.setTimePoint("start"); - gTimer.setTimePoint("recver.genBase"); - - // perform numOTs random OTs, the results will be written to msgs. - cp::sync_wait(receiver.silentReceive(choice, msgs, prng, chl)); - - // record the time. - receiver.setTimePoint("finish"); - auto e = timer.setTimePoint("finish"); - milli = std::chrono::duration_cast(e - b).count(); - } - else - { - gTimer.setTimePoint("sender.thrd.begin"); - - - std::unique_ptr backing(new block[numOTs]); - span msgs(backing.get(), numOTs); - - gTimer.setTimePoint("sender.msg.alloc"); - - OtExtSender sender; - sender.mMultType = mulType; - sender.configure(numOTs); - gTimer.setTimePoint("sender.config"); - timer.setTimePoint("start"); - - // generate base OTs - if (fakeBase) - { - auto nn = sender.baseOtCount(); - BitVector bits(nn); bits.randomize(prng); - std::vector> baseSendMsgs(nn); - std::vector baseRecvMsgs(nn); - PRNG pp(oc::ZeroBlock); - pp.get(baseSendMsgs.data(), baseSendMsgs.size()); - for (u64 i = 0; i < nn; ++i) - baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; - sender.setBaseOts(baseRecvMsgs, bits); - } - else - { - cp::sync_wait(sender.genSilentBaseOts(prng, chl)); - } - - // block until both parties are ready (optional). - cp::sync_wait(sync(chl, role)); - auto b = sender.setTimePoint("start"); - gTimer.setTimePoint("sender.genBase"); - - // perform the OTs and write the random OTs to msgs. - block delta = prng.get(); - cp::sync_wait(sender.silentSend(delta, msgs, prng, chl)); - - sender.setTimePoint("finish"); - auto e = timer.setTimePoint("finish"); - milli = std::chrono::duration_cast(e - b).count(); - } - if (role == Role::Sender) - { - - lout << tag << - " n:" << Color::Green << std::setw(6) << std::setfill(' ') << numOTs << Color::Default << - " || " << Color::Green << - std::setw(6) << std::setfill(' ') << milli << " ms " << - //std::setw(6) << std::setfill(' ') << com << " bytes" << - std::endl << Color::Default; - - if (cmd.getOr("v", 0) > 1) - lout << gTimer << std::endl; - - } - - cp::sync_wait(chl.flush()); -#endif - } + bool Vole_Examples(const CLP& cmd); } diff --git a/frontend/benchmark.h b/frontend/benchmark.h index 63984a5e..25043f5e 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -12,6 +12,7 @@ #include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" #include "libOTe/Vole/Silent/SilentVoleSender.h" #include "libOTe/Vole/Silent/SilentVoleReceiver.h" +#include "libOTe/Tools/CoeffCtx.h" namespace osuCrypto { @@ -171,74 +172,12 @@ namespace osuCrypto std::cout << verbose << std::endl; } - inline void encodeBench(CLP& cmd) - { -#ifdef ENABLE_INSECURE_SILVER - u64 trials = cmd.getOr("t", 10); - - // the message length of the code. - // The noise vector will have size n=2*m. - // the user can use - // -m X - // to state that exactly X rows should be used or - // -mm X - // to state that 2^X rows should be used. - u64 m = cmd.getOr("m", 1ull << cmd.getOr("mm", 10)); - - // the weight of the code, must be 5 or 11. - u64 w = cmd.getOr("w", 5); - - // verbose flag. - bool v = cmd.isSet("v"); - - - SilverCode code; - if (w == 11) - code = SilverCode::Weight11; - else if (w == 5) - code = SilverCode::Weight5; - else - { - std::cout << "invalid weight" << std::endl; - throw RTE_LOC; - } - - PRNG prng(ZeroBlock); - SilverEncoder encoder; - encoder.init(m, code); - - - std::vector x(encoder.cols()); - Timer timer, verbose; - - if (v) - encoder.setTimer(verbose); - - timer.setTimePoint("_____________________"); - for (u64 i = 0; i < trials; ++i) - { - encoder.dualEncode(x); - timer.setTimePoint("encode"); - } - - std::cout << timer << std::endl; - - if (v) - std::cout << verbose << std::endl; -#else - std::cout << "disabled, ENABLE_INSECURE_SILVER not defined " << std::endl; -#endif - } - - - inline void transpose(const CLP& cmd) { #ifdef ENABLE_AVX u64 trials = cmd.getOr("trials", 1ull << 18); { - AlignedArray data; Timer timer; @@ -376,17 +315,14 @@ namespace osuCrypto { std::cout << e.what() << std::endl; } +#else + std::cout << "ENABLE_SILENTOT = false" << std::endl; #endif } - - - - inline void VoleBench2(const CLP& cmd) { -#ifdef ENABLE_SILENTOT - +#ifdef ENABLE_SILENT_VOLE try { @@ -412,6 +348,8 @@ namespace osuCrypto baseSend[i] = prng.get(); baseRecv[i] = baseSend[i][baseChoice[i]]; } + +#ifdef ENABLE_SOFTSPOKEN_OT sender.mOtExtRecver.emplace(); sender.mOtExtSender.emplace(); recver.mOtExtRecver.emplace(); @@ -420,6 +358,7 @@ namespace osuCrypto recver.mOtExtRecver->setBaseOts(baseSend); sender.mOtExtSender->setBaseOts(baseRecv, baseChoice); recver.mOtExtSender->setBaseOts(baseRecv, baseChoice); +#endif // ENABLE_SOFTSPOKEN_OT PRNG prng0(ZeroBlock), prng1(ZeroBlock); block delta = prng0.get(); @@ -478,6 +417,8 @@ namespace osuCrypto { std::cout << e.what() << std::endl; } +#else + std::cout << "ENABLE_Silent_VOLE = false" << std::endl; #endif } } \ No newline at end of file diff --git a/frontend/main.cpp b/frontend/main.cpp index b3c64c07..f98584f1 100644 --- a/frontend/main.cpp +++ b/frontend/main.cpp @@ -11,9 +11,6 @@ using namespace osuCrypto; #include #include -#include -#include -#include #include #include @@ -32,30 +29,12 @@ using namespace osuCrypto; #include "cryptoTools/Crypto/RandomOracle.h" #include "libOTe/Tools/EACode/EAChecker.h" -static const std::vector -unitTestTag{ "u", "unitTest" }, -kos{ "k", "kos" }, -dkos{ "d", "dkos" }, -ssdelta{ "ssd", "ssdelta" }, -sshonest{ "ss", "sshonest" }, -smleakydelta{ "smld", "smleakydelta" }, -smalicious{ "sm", "smalicious" }, -kkrt{ "kk", "kkrt" }, -iknp{ "i", "iknp" }, -diknp{ "diknp" }, -oos{ "o", "oos" }, -moellerpopf{ "p", "moellerpopf" }, -ristrettopopf{ "r", "ristrettopopf" }, -mr{ "mr" }, -mrb{ "mrb" }, -Silent{ "s", "Silent" }, -vole{ "vole" }, -akn{ "a", "akn" }, -np{ "np" }, -simple{ "simplest" }, -simpleasm{ "simplest-asm" }; +#include "libOTe/TwoChooseOne/Iknp/IknpOtExtSender.h" +#include "libOTe/TwoChooseOne/Iknp/IknpOtExtReceiver.h" #ifdef ENABLE_IKNP +using namespace oc; + void minimal() { // Setup networking. See cryptoTools\frontend_cryptoTools\Tutorials\Network.cpp @@ -113,9 +92,7 @@ int main(int argc, char** argv) // various benchmarks if (cmd.isSet("bench")) { - if (cmd.isSet("silver")) - encodeBench(cmd); - else if (cmd.isSet("QC")) + if (cmd.isSet("QC")) QCCodeBench(cmd); else if (cmd.isSet("silent")) SilentOtBench(cmd); @@ -128,14 +105,6 @@ int main(int argc, char** argv) return 0; } - - // minimum distance checker for EA codes. - if (cmd.isSet("ea")) - { - EAChecker(cmd); - return 0; - } - // unit tests. if (cmd.isSet(unitTestTag)) { @@ -158,96 +127,11 @@ int main(int argc, char** argv) // run various examples. - - -#ifdef ENABLE_SIMPLESTOT - flagSet |= runIf(baseOT_example, cmd, simple); -#endif - -#ifdef ENABLE_SIMPLESTOT_ASM - flagSet |= runIf(baseOT_example, cmd, simpleasm); -#endif - -#ifdef ENABLE_MRR_TWIST -#ifdef ENABLE_SSE - flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) { - DomainSepEKEPopf factory; - const char* domain = "EKE POPF OT example"; - factory.Update(domain, std::strlen(domain)); - baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwist(factory)); - }, cmd, moellerpopf, { "eke" }); -#endif - - flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) { - DomainSepMRPopf factory; - const char* domain = "MR POPF OT example"; - factory.Update(domain, std::strlen(domain)); - baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwistMR(factory)); - }, cmd, moellerpopf, { "mrPopf" }); - - flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) { - DomainSepFeistelPopf factory; - const char* domain = "Feistel POPF OT example"; - factory.Update(domain, std::strlen(domain)); - baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwistFeistel(factory)); - }, cmd, moellerpopf, { "feistel" }); - - flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) { - DomainSepFeistelMulPopf factory; - const char* domain = "Feistel With Multiplication POPF OT example"; - factory.Update(domain, std::strlen(domain)); - baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwistMul(factory)); -}, cmd, moellerpopf, { "feistelMul" }); -#endif - -#ifdef ENABLE_MRR - flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) { - DomainSepFeistelRistPopf factory; - const char* domain = "Feistel POPF OT example (Risretto)"; - factory.Update(domain, std::strlen(domain)); - baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoy(factory)); - }, cmd, ristrettopopf, { "feistel" }); - - flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) { - DomainSepFeistelMulRistPopf factory; - const char* domain = "Feistel With Multiplication POPF OT example (Risretto)"; - factory.Update(domain, std::strlen(domain)); - baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyMul(factory)); - }, cmd, ristrettopopf, { "feistelMul" }); -#endif - -#ifdef ENABLE_MR - flagSet |= runIf(baseOT_example, cmd, mr); -#endif - -#ifdef ENABLE_IKNP - flagSet |= runIf(TwoChooseOne_example, cmd, iknp); -#endif - -#ifdef ENABLE_KOS - flagSet |= runIf(TwoChooseOne_example, cmd, kos); -#endif - -#ifdef ENABLE_DELTA_KOS - flagSet |= runIf(TwoChooseOne_example, cmd, dkos); -#endif - -#ifdef ENABLE_SOFTSPOKEN_OT - flagSet |= runIf(TwoChooseOne_example, SoftSpokenShOtReceiver<>>, cmd, sshonest); - flagSet |= runIf(TwoChooseOne_example, cmd, smalicious); -#endif - -#ifdef ENABLE_KKRT - flagSet |= runIf(NChooseOne_example, cmd, kkrt); -#endif - -#ifdef ENABLE_OOS - flagSet |= runIf(NChooseOne_example, cmd, oos); -#endif - - flagSet |= runIf(Silent_example, cmd, Silent); - flagSet |= runIf(Vole_example, cmd, vole); - + flagSet |= baseOT_examples(cmd); + flagSet |= TwoChooseOne_Examples(cmd); + flagSet |= NChooseOne_Examples(cmd); + flagSet |= Silent_Examples(cmd); + flagSet |= Vole_Examples(cmd); if (cmd.isSet("messagePassing")) { @@ -256,8 +140,6 @@ int main(int argc, char** argv) } - - if (flagSet == false) { diff --git a/frontend/util.cpp b/frontend/util.cpp index ac7a95ce..fa6bc91b 100644 --- a/frontend/util.cpp +++ b/frontend/util.cpp @@ -7,6 +7,7 @@ #include #include +#include "util.h" namespace osuCrypto { @@ -67,6 +68,7 @@ namespace osuCrypto else { u8 dummy[1]; + dummy[0] = 0; chl.asyncSend(dummy, 1); chl.recv(dummy, 1); chl.asyncSend(dummy, 1); @@ -79,6 +81,7 @@ namespace osuCrypto { u8 dummy[1]; + dummy[0] = 0; chl.asyncSend(dummy, 1); diff --git a/frontend/util.h b/frontend/util.h index 599a922c..27fb08fe 100644 --- a/frontend/util.h +++ b/frontend/util.h @@ -10,11 +10,34 @@ #include #include #include "libOTe/Tools/Coproto.h" - +#include namespace osuCrypto { + + static const std::vector + unitTestTag{ "u", "unitTest" }, + kos{ "k", "kos" }, + dkos{ "d", "dkos" }, + ssdelta{ "ssd", "ssdelta" }, + sshonest{ "ss", "sshonest" }, + smleakydelta{ "smld", "smleakydelta" }, + smalicious{ "sm", "smalicious" }, + kkrt{ "kk", "kkrt" }, + iknp{ "i", "iknp" }, + diknp{ "diknp" }, + oos{ "o", "oos" }, + moellerpopf{ "p", "moellerpopf" }, + ristrettopopf{ "r", "ristrettopopf" }, + mr{ "mr" }, + mrb{ "mrb" }, + Silent{ "s", "Silent" }, + vole{ "vole" }, + akn{ "a", "akn" }, + np{ "np" }, + simple{ "simplest" }, + simpleasm{ "simplest-asm" }; enum class Role { Sender, @@ -36,7 +59,7 @@ namespace osuCrypto //using ProtocolFunc = std::function; template - inline bool runIf(ProtocolFunc protocol, CLP & cmd, std::vector tag, + inline bool runIf(ProtocolFunc protocol, const CLP & cmd, std::vector tag, std::vector tag2 = std::vector()) { auto n = cmd.isSet("nn") diff --git a/libOTe/Tools/Pprf/RegularPprf.h b/libOTe/Tools/Pprf/RegularPprf.h index ebede7f9..439150af 100644 --- a/libOTe/Tools/Pprf/RegularPprf.h +++ b/libOTe/Tools/Pprf/RegularPprf.h @@ -1,4 +1,7 @@ #pragma once +#include "libOTe/config.h" + +#ifdef ENABLE_PPRF #include "cryptoTools/Common/Defines.h" #include "cryptoTools/Common/BitVector.h" #include "cryptoTools/Common/Matrix.h" @@ -1031,4 +1034,6 @@ namespace osuCrypto } } }; -} \ No newline at end of file +} + +#endif \ No newline at end of file diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.h b/libOTe/Vole/Silent/SilentVoleReceiver.h index a7fcd311..1818f1b2 100644 --- a/libOTe/Vole/Silent/SilentVoleReceiver.h +++ b/libOTe/Vole/Silent/SilentVoleReceiver.h @@ -25,6 +25,8 @@ #include #include #include "libOTe/Tools/QuasiCyclicCode.h" +#include "libOTe/TwoChooseOne/Silent/SilentOtExtUtil.h" + namespace osuCrypto { @@ -145,6 +147,8 @@ namespace osuCrypto // generate the needed OTs. task<> genSilentBaseOts(PRNG& prng, Socket& chl) { +#ifdef LIBOTE_HAS_BASE_OT + using BaseOT = DefaultBaseOT; @@ -246,7 +250,11 @@ namespace osuCrypto setSilentBaseOts(msg, baseAs); setTimePoint("SilentVoleReceiver.genSilent.done"); - MC_END(); + MC_END(); +#else + throw std::runtime_error("LIBOTE_HAS_BASE_OT = false, must enable relic, sodium or simplest ot asm." LOCATION); +#endif + }; // configure the silent OT extension. This sets diff --git a/libOTe/Vole/Silent/SilentVoleSender.h b/libOTe/Vole/Silent/SilentVoleSender.h index 682e2b5b..a60625a7 100644 --- a/libOTe/Vole/Silent/SilentVoleSender.h +++ b/libOTe/Vole/Silent/SilentVoleSender.h @@ -123,6 +123,7 @@ namespace osuCrypto // generate the needed OTs. task<> genSilentBaseOts(PRNG& prng, Socket& chl, F delta) { +#ifdef LIBOTE_HAS_BASE_OT using BaseOT = DefaultBaseOT; MC_BEGIN(task<>, this, delta, &prng, &chl, @@ -195,6 +196,9 @@ namespace osuCrypto setSilentBaseOts(msg, b); setTimePoint("SilentVoleSender.genSilent.done"); MC_END(); +#else + throw std::runtime_error("LIBOTE_HAS_BASE_OT = false, must enable relic, sodium or simplest ot asm." LOCATION); +#endif } // configure the silent OT extension. This sets diff --git a/libOTe/config.h.in b/libOTe/config.h.in index f9c5bb36..40b1bec8 100644 --- a/libOTe/config.h.in +++ b/libOTe/config.h.in @@ -52,6 +52,8 @@ // build the library with silent vole enabled #cmakedefine ENABLE_SILENT_VOLE @ENABLE_SILENT_VOLE@ +#cmakedefine ENABLE_PPRF @ENABLE_PPRF@ + // build the library with silver codes. #cmakedefine ENABLE_INSECURE_SILVER @ENABLE_INSECURE_SILVER@ diff --git a/libOTe_Tests/Pprf_Tests.cpp b/libOTe_Tests/Pprf_Tests.cpp index fd3622c9..00c49971 100644 --- a/libOTe_Tests/Pprf_Tests.cpp +++ b/libOTe_Tests/Pprf_Tests.cpp @@ -1,6 +1,9 @@ #include "Pprf_Tests.h" #include "libOTe/Tools/Pprf/RegularPprf.h" +#include "cryptoTools/Common/TestCollection.h" + +#ifdef ENABLE_PPRF #include "cryptoTools/Common/Log.h" #include "Common.h" #include @@ -447,3 +450,22 @@ void Tools_Pprf_callback_test(const oc::CLP& cmd) throw UnitTestSkipped("ENABLE_SILENTOT not defined."); #endif } +#else + + +namespace { + void throwDisabled() + { + throw oc::UnitTestSkipped( + "ENABLE_PPRF not defined. " + ); + } +} + +void Tools_Pprf_expandOne_test(const oc::CLP& cmd) { throwDisabled(); } +void Tools_Pprf_inter_test(const oc::CLP& cmd) { throwDisabled(); } +void Tools_Pprf_ByLeafIndex_test(const oc::CLP& cmd) { throwDisabled(); } +void Tools_Pprf_ByTreeIndex_test(const oc::CLP& cmd) { throwDisabled(); } +void Tools_Pprf_callback_test(const oc::CLP& cmd) { throwDisabled(); } + +#endif diff --git a/libOTe_Tests/Vole_Tests.cpp b/libOTe_Tests/Vole_Tests.cpp index 26dafae6..52ced20e 100644 --- a/libOTe_Tests/Vole_Tests.cpp +++ b/libOTe_Tests/Vole_Tests.cpp @@ -18,7 +18,7 @@ using namespace oc; #include "libOTe/Tools/CoeffCtx.h" using namespace tests_libOTe; - +#ifdef ENABLE_SILENT_VOLE template void Vole_Noisy_test_impl(u64 n) { @@ -356,15 +356,26 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) timer.setTimePoint("done"); } } -// -//namespace { -// void throwDisabled() -// { -// throw UnitTestSkipped( -// "ENABLE_SILENT_VOLE not defined. " -// ); -// } -//} +#else + + +namespace { + void throwDisabled() + { + throw UnitTestSkipped( + "ENABLE_SILENT_VOLE not defined. " + ); + } +} +void Vole_Noisy_test(const oc::CLP& cmd) { throwDisabled(); } +void Vole_Silent_QuasiCyclic_test(const oc::CLP& cmd) { throwDisabled(); } +void Vole_Silent_paramSweep_test(const oc::CLP& cmd) { throwDisabled(); } +void Vole_Silent_baseOT_test(const oc::CLP& cmd) { throwDisabled(); } +void Vole_Silent_mal_test(const oc::CLP& cmd) { throwDisabled(); } +void Vole_Silent_Rounds_test(const oc::CLP& cmd) { throwDisabled(); } + + +#endif // // //void Vole_Silent_QuasiCyclic_test(const oc::CLP& cmd) { throwDisabled(); } From 36cd083faf3cf084dde93300499cfa64916005ea Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Mon, 22 Jan 2024 15:14:21 -0800 Subject: [PATCH 20/23] ENABLE_SSE Ctx --- libOTe/Tools/CoeffCtx.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/libOTe/Tools/CoeffCtx.h b/libOTe/Tools/CoeffCtx.h index 7b8312ac..ddf64eed 100644 --- a/libOTe/Tools/CoeffCtx.h +++ b/libOTe/Tools/CoeffCtx.h @@ -357,6 +357,8 @@ namespace osuCrypto { { // multiplication y modulo mod block y(0, 4234123421); + +#ifdef ENABLE_SSE static const constexpr std::uint64_t mod = 0b10000111; const __m128i modulus = _mm_loadl_epi64((const __m128i*) & (mod)); @@ -368,6 +370,9 @@ namespace osuCrypto { /* reduce w.r.t. high half of mul256_high */ auto tmp = _mm_clmulepi64_si128(xy2, modulus, 0x00); ret = _mm_xor_si128(xy1, tmp); +#else + ret = x.gf128Mul(y); +#endif } }; From 65bfb08efe5299df3bc7a169838adc5a6baf0ec0 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Mon, 22 Jan 2024 15:30:04 -0800 Subject: [PATCH 21/23] missing include --- libOTe/Tools/QuasiCyclicCode.h | 2 +- libOTe/Vole/Silent/SilentVoleReceiver.h | 1 + libOTe/Vole/Silent/SilentVoleSender.h | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/libOTe/Tools/QuasiCyclicCode.h b/libOTe/Tools/QuasiCyclicCode.h index d9de8859..bcf95183 100644 --- a/libOTe/Tools/QuasiCyclicCode.h +++ b/libOTe/Tools/QuasiCyclicCode.h @@ -1,5 +1,5 @@ #pragma once -// © 2022 Visaß. +// © 2022 Visa. // Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.h b/libOTe/Vole/Silent/SilentVoleReceiver.h index 1818f1b2..3f0310c8 100644 --- a/libOTe/Vole/Silent/SilentVoleReceiver.h +++ b/libOTe/Vole/Silent/SilentVoleReceiver.h @@ -26,6 +26,7 @@ #include #include "libOTe/Tools/QuasiCyclicCode.h" #include "libOTe/TwoChooseOne/Silent/SilentOtExtUtil.h" +#include namespace osuCrypto { diff --git a/libOTe/Vole/Silent/SilentVoleSender.h b/libOTe/Vole/Silent/SilentVoleSender.h index a60625a7..5fe9aaed 100644 --- a/libOTe/Vole/Silent/SilentVoleSender.h +++ b/libOTe/Vole/Silent/SilentVoleSender.h @@ -24,7 +24,7 @@ #include #include #include - +#include namespace osuCrypto { From 9390125807d605d3d1205b0422cd33f8232beaf0 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Mon, 22 Jan 2024 15:33:59 -0800 Subject: [PATCH 22/23] mac compile --- libOTe_Tests/Pprf_Tests.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libOTe_Tests/Pprf_Tests.cpp b/libOTe_Tests/Pprf_Tests.cpp index 00c49971..68eaf4e7 100644 --- a/libOTe_Tests/Pprf_Tests.cpp +++ b/libOTe_Tests/Pprf_Tests.cpp @@ -178,7 +178,7 @@ void Tools_Pprf_test_impl( auto threads = 1; PRNG prng(CCBlock); - using Vec = typename Ctx::Vec; + using Vec = typename Ctx::template Vec; auto sockets = cp::LocalAsyncSocket::makePair(); From 17e21f2a90afc4e3926e6c71404c291f104be877 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Mon, 22 Jan 2024 16:14:32 -0800 Subject: [PATCH 23/23] changed base ot to use fewer rounds --- CMakePresets.json | 5 +++++ libOTe/Vole/Silent/SilentVoleReceiver.h | 11 ++++++++++- libOTe/Vole/Silent/SilentVoleSender.h | 11 +++++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/CMakePresets.json b/CMakePresets.json index 5b7eb95d..6fe20218 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -16,6 +16,11 @@ "ENABLE_BOOST": true, "ENABLE_BITPOLYMUL": false, "ENABLE_CIRCUITS": true, + "ENABLE_SIMPLESTOT": true, + "ENABLE_MRR": true, + "ENABLE_MR": true, + "ENABLE_SIMPLESTOT": true, + "ENABLE_RELIC": true, "LIBOTE_STD_VER": "17", "CMAKE_PREFIX_PATH": "${sourceDir}/../out/install", "CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}" diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.h b/libOTe/Vole/Silent/SilentVoleReceiver.h index 3f0310c8..13024928 100644 --- a/libOTe/Vole/Silent/SilentVoleReceiver.h +++ b/libOTe/Vole/Silent/SilentVoleReceiver.h @@ -150,8 +150,17 @@ namespace osuCrypto { #ifdef LIBOTE_HAS_BASE_OT +#if defined ENABLE_MRR_TWIST && defined ENABLE_SSE + using BaseOT = McRosRoyTwist; +#elif defined ENABLE_MR + using BaseOT = MasnyRindal; +#elif defined ENABLE_MRR + using BaseOT = McRosRoy; +#elif defined ENABLE_NP_KYBER + using BaseOT = MasnyRindalKyber; +#else using BaseOT = DefaultBaseOT; - +#endif MC_BEGIN(task<>, this, &prng, &chl, choice = BitVector{}, diff --git a/libOTe/Vole/Silent/SilentVoleSender.h b/libOTe/Vole/Silent/SilentVoleSender.h index 5fe9aaed..81fcce8f 100644 --- a/libOTe/Vole/Silent/SilentVoleSender.h +++ b/libOTe/Vole/Silent/SilentVoleSender.h @@ -124,7 +124,18 @@ namespace osuCrypto task<> genSilentBaseOts(PRNG& prng, Socket& chl, F delta) { #ifdef LIBOTE_HAS_BASE_OT + +#if defined ENABLE_MRR_TWIST && defined ENABLE_SSE + using BaseOT = McRosRoyTwist; +#elif defined ENABLE_MR + using BaseOT = MasnyRindal; +#elif defined ENABLE_MRR + using BaseOT = McRosRoy; +#elif defined ENABLE_NP_KYBER + using BaseOT = MasnyRindalKyber; +#else using BaseOT = DefaultBaseOT; +#endif MC_BEGIN(task<>, this, delta, &prng, &chl, msg = AlignedUnVector>(silentBaseOtCount()),