From cf537295c47a3924c13030a9b796cee9d6ebeace Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Tue, 2 Jan 2024 16:58:58 -0800 Subject: [PATCH] refactor Pprf --- CMakePresets.json | 29 +- cryptoTools | 2 +- libOTe/Tools/SilentPprf.cpp | 1263 +++++++---------- libOTe/Tools/SilentPprf.h | 200 +-- .../Silent/SilentOtExtReceiver.cpp | 2 +- .../TwoChooseOne/Silent/SilentOtExtSender.cpp | 2 +- libOTe/Vole/Silent/SilentVoleReceiver.cpp | 2 +- libOTe/Vole/Silent/SilentVoleSender.cpp | 2 +- libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp | 6 +- libOTe_Tests/SilentOT_Tests.cpp | 288 ++-- libOTe_Tests/SilentOT_Tests.h | 1 + libOTe_Tests/UnitTests.cpp | 4 +- libOTe_Tests/Vole_Tests.cpp | 16 +- 13 files changed, 764 insertions(+), 1053 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index 91d4e171..3f921bdd 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -11,7 +11,8 @@ "CMAKE_BUILD_TYPE": "Debug", "FETCH_AUTO": true, "ENABLE_ALL_OT": true, - "ENABLE_SSE": false, + "ENABLE_SSE": true, + "ENABLE_AVX": true, "ENABLE_BITPOLYMUL": false, "ENABLE_CIRCUITS": true, "LIBOTE_STD_VER": "17", @@ -19,8 +20,14 @@ "CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}" }, "vendor": { - "microsoft.com/VisualStudioSettings/CMake/1.0": { "hostOS": [ "Linux" ] }, - "microsoft.com/VisualStudioRemoteSettings/CMake/1.0": { "sourceDir": "$env{HOME}/.vs/$ms{projectDirName}" } + "microsoft.com/VisualStudioSettings/CMake/1.0": { + "hostOS": [ + "Linux" + ] + }, + "microsoft.com/VisualStudioRemoteSettings/CMake/1.0": { + "sourceDir": "$env{HOME}/.vs/$ms{projectDirName}" + } } }, { @@ -57,7 +64,13 @@ "CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}", "CMAKE_PREFIX_PATH": "${sourceDir}/../out/install/${presetName}" }, - "vendor": { "microsoft.com/VisualStudioSettings/CMake/1.0": { "hostOS": [ "Windows" ] } } + "vendor": { + "microsoft.com/VisualStudioSettings/CMake/1.0": { + "hostOS": [ + "Windows" + ] + } + } }, { "name": "x64-Release", @@ -85,7 +98,13 @@ "ENABLE_ASAN": false, "CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}" }, - "vendor": { "microsoft.com/VisualStudioSettings/CMake/1.0": { "hostOS": [ "Windows" ] } } + "vendor": { + "microsoft.com/VisualStudioSettings/CMake/1.0": { + "hostOS": [ + "Windows" + ] + } + } } ] } \ No newline at end of file diff --git a/cryptoTools b/cryptoTools index ddf93782..5f90354b 160000 --- a/cryptoTools +++ b/cryptoTools @@ -1 +1 @@ -Subproject commit ddf937825eca17700abcb83474f40114cbe9fa3e +Subproject commit 5f90354b499adddbcf6861a3b4463e0724e5f719 diff --git a/libOTe/Tools/SilentPprf.cpp b/libOTe/Tools/SilentPprf.cpp index b63dcee6..28bd2850 100644 --- a/libOTe/Tools/SilentPprf.cpp +++ b/libOTe/Tools/SilentPprf.cpp @@ -50,54 +50,7 @@ namespace osuCrypto PprfOutputFormat oFormat, std::function> lvl)>& callback) { - - if (oFormat == PprfOutputFormat::InterleavedTransposed) - { - // not having an even (8) number of trees is not supported. - if (totalTrees % 8) - throw RTE_LOC; - if (lvl.size() % 16) - throw RTE_LOC; - - // - //auto rowsPer = 16; - //auto step = lvl.size() - - //auto sectionSize = - - if (lvl.size() < 16) - throw RTE_LOC; - - - auto setIdx = tIdx / 8; - auto blocksPerSet = lvl.size() * 8 / 128; - - - - auto numSets = totalTrees / 8; - auto begin = setIdx; - auto step = numSets; - - if (oFormat == PprfOutputFormat::InterleavedTransposed) - { - auto end = std::min(begin + step * blocksPerSet, output.cols()); - - for (u64 i = begin, k = 0; i < end; i += step, ++k) - { - auto& io = *(std::array*)(&lvl[k * 16]); - transpose128(io.data()); - for (u64 j = 0; j < 128; ++j) - output(j, i) = io[j]; - } - } - else - { - // no op - } - - - } - else if (oFormat == PprfOutputFormat::Plain) + if (oFormat == PprfOutputFormat::ByLeafIndex) { auto curSize = std::min(totalTrees - tIdx, 8); @@ -130,7 +83,7 @@ namespace osuCrypto } } - else if (oFormat == PprfOutputFormat::BlockTransposed) + else if (oFormat == PprfOutputFormat::ByTreeIndex) { auto curSize = std::min(totalTrees - tIdx, 8); @@ -160,10 +113,6 @@ namespace osuCrypto } } - else if (oFormat == PprfOutputFormat::Interleaved) - { - // no op - } else if (oFormat == PprfOutputFormat::Callback) callback(tIdx, lvl); else @@ -172,8 +121,6 @@ namespace osuCrypto u64 interleavedPoint(u64 point, u64 treeIdx, u64 totalTrees, u64 domain, PprfOutputFormat format) { - - switch (format) { case osuCrypto::PprfOutputFormat::Interleaved: @@ -189,23 +136,6 @@ namespace osuCrypto return (forest * domain + point) * 8 + subTree; } break; - case osuCrypto::PprfOutputFormat::InterleavedTransposed: - { - auto numSets = totalTrees / 8; - - auto setIdx = treeIdx / 8; - auto subIdx = treeIdx % 8; - - auto sectionIdx = point / 16; - auto posIdx = point % 16; - - - auto setOffset = setIdx * 128; - auto subOffset = subIdx + 8 * posIdx; - auto secOffset = sectionIdx * numSets * 128; - - return setOffset + subOffset + secOffset; - } default: throw RTE_LOC; break; @@ -240,8 +170,8 @@ namespace osuCrypto switch (format) { - case PprfOutputFormat::Plain: - case PprfOutputFormat::BlockTransposed: + case PprfOutputFormat::ByLeafIndex: + case PprfOutputFormat::ByTreeIndex: memset(points.data(), 0, points.size() * sizeof(u64)); for (u64 j = 0; j < mPntCount; ++j) @@ -250,7 +180,6 @@ namespace osuCrypto } break; - case PprfOutputFormat::InterleavedTransposed: case PprfOutputFormat::Interleaved: case PprfOutputFormat::Callback: @@ -259,7 +188,7 @@ namespace osuCrypto if (points.size() % 8) throw RTE_LOC; - getPoints(points, PprfOutputFormat::Plain); + getPoints(points, PprfOutputFormat::ByLeafIndex); interleavedPoints(points, mDomain, format); break; @@ -282,8 +211,8 @@ namespace osuCrypto u64 idx; switch (format) { - case osuCrypto::PprfOutputFormat::Plain: - case osuCrypto::PprfOutputFormat::BlockTransposed: + case osuCrypto::PprfOutputFormat::ByLeafIndex: + case osuCrypto::PprfOutputFormat::ByTreeIndex: do { for (u64 j = 0; j < mDepth; ++j) mBaseChoices(i, j) = prng.getBit(); @@ -292,9 +221,13 @@ namespace osuCrypto break; case osuCrypto::PprfOutputFormat::Interleaved: - case osuCrypto::PprfOutputFormat::InterleavedTransposed: case osuCrypto::PprfOutputFormat::Callback: + if (modulus > mPntCount * mDomain) + throw std::runtime_error("modulus too big. " LOCATION); + if (modulus < mPntCount * mDomain / 2) + throw std::runtime_error("modulus too small. " LOCATION); + // make sure that at least the first element of this tree // is within the modulus. idx = interleavedPoint(0, i, mPntCount, mDomain, format); @@ -336,138 +269,143 @@ namespace osuCrypto mBaseChoices.resize(roundUpTo(mPntCount, 8), mDepth); for (u64 i = 0; i < mPntCount; ++i) { + for (u64 j = 0; j < mDepth; ++j) + mBaseChoices(i, j) = choices[mDepth * i + j]; + switch (format) { - case osuCrypto::PprfOutputFormat::Plain: - case osuCrypto::PprfOutputFormat::BlockTransposed: - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = choices[mDepth * i + j]; - break; - - // Not sure what ordering would be good for Interleaved or - // InterleavedTransposed. + case osuCrypto::PprfOutputFormat::ByLeafIndex: + case osuCrypto::PprfOutputFormat::ByTreeIndex: + if (getActivePath(mBaseChoices[i]) >= mDomain) + throw RTE_LOC; + break; + case osuCrypto::PprfOutputFormat::Interleaved: + case osuCrypto::PprfOutputFormat::Callback: + { + auto idx = getActivePath(mBaseChoices[i]); + auto idx2 = interleavedPoint(idx, i, mPntCount, mDomain, format); + if(idx2 > mPntCount * mDomain) + throw std::runtime_error("the base ot choice bits index outside of the domain. see sampleChoiceBits(...). " LOCATION); + break; + } default: throw RTE_LOC; break; } - - if (getActivePath(mBaseChoices[i]) >= mDomain) - throw RTE_LOC; } } - //task<> SilentMultiPprfSender::expand( - // Socket& chls, - // block value, - // PRNG& prng, - // MatrixView output, - // PprfOutputFormat oFormat, - // u64 numThreads) - //{ - // return expand(chls, { &value, 1 }, prng, output, oFormat, numThreads); - //} - - - // Returns the i'th level of the current 8 trees. The - // children of node j on level i are located at 2*j and - // 2*j+1 on level i+1. - span> SilentMultiPprfSender::Expander::getLevel(u64 i, u64 g) + namespace { - - if (oFormat == PprfOutputFormat::Interleaved && i == pprf.mDepth) - { - auto b = (AlignedArray*)output.data(); - auto forest = g / 8; - assert(g % 8 == 0); - b += forest * pprf.mDomain; - return span>(b, pprf.mDomain); - } - - return mLevels[i]; - }; - - SilentMultiPprfSender::Expander::Expander(SilentMultiPprfSender& p, block seed, u64 treeIdx_, - PprfOutputFormat of, MatrixViewo, bool activeChildXorDelta, Socket&& s) - : pprf(p) - , chl(std::move(s)) - , mActiveChildXorDelta(activeChildXorDelta) - { - treeIdx = treeIdx_; - assert((treeIdx & 7) == 0); - output = o; - oFormat = of; // A public PRF/PRG that we will use for deriving the GGM tree. - aes[0].setKey(toBlock(3242342)); - aes[1].setKey(toBlock(8993849)); - prng.SetSeed(seed); - dd = pprf.mDepth + (oFormat == PprfOutputFormat::Interleaved ? 0 : 1); - - + const std::array gAes = []() { + std::array aes; + aes[0].setKey(toBlock(3242342)); + aes[1].setKey(toBlock(8993849)); + return aes; + }(); } - task<> SilentMultiPprfSender::Expander::run() + void SilentMultiPprfSender::expandOne( + block aesSeed, + u64 treeIdx, + bool programActivePath, + span>> levels, + span, 2>> encSums, + span> lastOts) { - MC_BEGIN(task<>, this); + // The number of real trees for this iteration. + auto min = std::min(8, mPntCount - treeIdx); + // the first level should be size 1, the root of the tree. + // we will populate it with random seeds using aesSeed in counter mode + // based on the tree index. + assert(levels[0].size() == 1); + mAesFixedKey.ecbEncCounterMode(aesSeed ^ block(treeIdx), levels[0][0]); -#ifdef DEBUG_PRINT_PPRF - chl.asyncSendCopy(mValue); -#endif - // pprf.setTimePoint("SilentMultiPprfSender.begin " + std::to_string(treeIdx)); - { - tree = pprf.mTreeAlloc.get(); - assert(tree.size() >= 1ull << (dd)); - assert((u64)tree.data() % 32 == 0); - mLevels.resize(dd); - mLevels[0] = tree.subspan(0, 1); - auto rem = tree.subspan(mLevels[0].size()); - for (auto i : rng(1ull, dd)) - { - while ((u64)rem.data() % 32) - rem = rem.subspan(1); + assert(encSums.size() == mDepth - programActivePath); + assert(encSums.size() < 24); - mLevels[i] = rem.subspan(0, mLevels[i - 1].size() * 2); - rem = rem.subspan(mLevels[i].size()); - } - } - // pprf.setTimePoint("SilentMultiPprfSender.alloc " + std::to_string(treeIdx)); + // space for our sums of each level. Should always be less then + // 24 levels... If not increase the limit or make it a vector. + std::array, 2>, 24> sums; + memset(&sums, 0, sizeof(sums)); - // This thread will process 8 trees at a time. It will interlace - // the sets of trees are processed with the other threads. + // For each level perform the following. + for (u64 d = 0; d < mDepth; ++d) { - // The number of real trees for this iteration. - min = std::min(8, pprf.mPntCount - treeIdx); - //gTimer.setTimePoint("send.start" + std::to_string(treeIdx)); + // The previous level of the GGM tree. + auto level0 = levels[d]; - // Populate the zeroth level of the GGM tree with random seeds. - prng.get(getLevel(0, treeIdx)); + // The next level of theGGM tree that we are populating. + auto level1 = levels[d + 1]; - // Allocate space for our sums of each level. - sums[0].resize(pprf.mDepth); - sums[1].resize(pprf.mDepth); + // The total number of parents in this level. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); - // For each level perform the following. - for (u64 d = 0; d < pprf.mDepth; ++d) + // use the optimized approach for intern nodes of the tree + if (d + 1 < mDepth && 0) { - // The previous level of the GGM tree. - auto level0 = getLevel(d, treeIdx); - - // The next level of theGGM tree that we are populating. - auto level1 = getLevel(d + 1, treeIdx); - - // The total number of children in this level. - auto width = static_cast(level1.size()); - // For each child, populate the child by expanding the parent. - for (u64 childIdx = 0; childIdx < width; ) + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) { - // Index of the parent in the previous level. - auto parentIdx = childIdx >> 1; + // The value of the parent. + auto& parent = level0.data()[parentIdx]; + + auto& child0 = level1.data()[childIdx]; + auto& child1 = level1.data()[childIdx + 1]; + mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); + + child0[0] = child1[0] ^ parent[0]; + child0[1] = child1[1] ^ parent[1]; + child0[2] = child1[2] ^ parent[2]; + child0[3] = child1[3] ^ parent[3]; + child0[4] = child1[4] ^ parent[4]; + child0[5] = child1[5] ^ parent[5]; + child0[6] = child1[6] ^ parent[6]; + child0[7] = child1[7] ^ parent[7]; + + // Update the running sums for this level. We keep + // a left and right totals for each level. + auto& sum = sums[d]; + sum[0][0] = sum[0][0] ^ child0[0]; + sum[0][1] = sum[0][1] ^ child0[1]; + sum[0][2] = sum[0][2] ^ child0[2]; + sum[0][3] = sum[0][3] ^ child0[3]; + sum[0][4] = sum[0][4] ^ child0[4]; + sum[0][5] = sum[0][5] ^ child0[5]; + sum[0][6] = sum[0][6] ^ child0[6]; + sum[0][7] = sum[0][7] ^ child0[7]; + + child1[0] = child1[0] + parent[0]; + child1[1] = child1[1] + parent[1]; + child1[2] = child1[2] + parent[2]; + child1[3] = child1[3] + parent[3]; + child1[4] = child1[4] + parent[4]; + child1[5] = child1[5] + parent[5]; + child1[6] = child1[6] + parent[6]; + child1[7] = child1[7] + parent[7]; + + sum[1][0] = sum[1][0] ^ child1[0]; + sum[1][1] = sum[1][1] ^ child1[1]; + sum[1][2] = sum[1][2] ^ child1[2]; + sum[1][3] = sum[1][3] ^ child1[3]; + sum[1][4] = sum[1][4] ^ child1[4]; + sum[1][5] = sum[1][5] ^ child1[5]; + sum[1][6] = sum[1][6] ^ child1[6]; + sum[1][7] = sum[1][7] ^ child1[7]; + } + } + else + { + // for the leaf nodes we need to hash both children. + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) + { // The value of the parent. - auto& parent = level0[parentIdx]; + auto& parent = level0.data()[parentIdx]; // The bit that indicates if we are on the left child (0) // or on the right child (1). @@ -477,7 +415,7 @@ namespace osuCrypto auto& child = level1[childIdx]; // The sum that this child node belongs to. - auto& sum = sums[keep][d]; + auto& sum = sums[d][keep]; // Each parent is expanded into the left and right children // using a different AES fixed-key. Therefore our OWF is: @@ -485,7 +423,7 @@ namespace osuCrypto // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); // // where each half defines one of the children. - aes[keep].hashBlocks<8>(parent.data(), child.data()); + gAes[keep].hashBlocks<8>(parent.data(), child.data()); // Update the running sums for this level. We keep // a left and right totals for each level. @@ -500,242 +438,226 @@ namespace osuCrypto } } } + } - -#ifdef DEBUG_PRINT_PPRF - // If we are debugging, then send over the full tree - // to make sure its correct on the other side. - chl.asyncSendCopy(tree); -#endif - - // For all but the last level, mask the sums with the - // OT strings and send them over. - for (u64 d = 0; d < pprf.mDepth - mActiveChildXorDelta; ++d) + // For all but the last level, mask the sums with the + // OT strings and send them over. + for (u64 d = 0; d < mDepth - programActivePath; ++d) + { + for (u64 j = 0; j < min; ++j) { - for (u64 j = 0; j < min; ++j) - { -#ifdef DEBUG_PRINT_PPRF - if (mPrint) - { - std::cout << "c[" << treeIdx + j << "][" << d << "][0] " << sums[0][d][j] << " " << mBaseOTs[treeIdx + j][d][0] << std::endl;; - std::cout << "c[" << treeIdx + j << "][" << d << "][1] " << sums[1][d][j] << " " << mBaseOTs[treeIdx + j][d][1] << std::endl;; - } -#endif - sums[0][d][j] = sums[0][d][j] ^ pprf.mBaseOTs[treeIdx + j][d][0]; - sums[1][d][j] = sums[1][d][j] ^ pprf.mBaseOTs[treeIdx + j][d][1]; - } + encSums[d][0][j] = sums[d][0][j] ^ mBaseOTs[treeIdx + j][d][0]; + encSums[d][1][j] = sums[d][1][j] ^ mBaseOTs[treeIdx + j][d][1]; } - // pprf.setTimePoint("SilentMultiPprfSender.expand " + std::to_string(treeIdx)); + } - if (mActiveChildXorDelta) + if (programActivePath) + { + // For the last level, we are going to do something special. + // The other party is currently missing both leaf children of + // the active parent. Since this is the last level, we want + // the inactive child to just be the normal value but the + // active child should be the correct value XOR the delta. + // This will be done by sending the sums and the sums plus + // delta and ensure that they can only decrypt the correct ones. + auto d = mDepth - 1; + assert(lastOts.size() == min); + for (u64 j = 0; j < min; ++j) { - // For the last level, we are going to do something special. - // The other party is currently missing both leaf children of - // the active parent. Since this is the last level, we want - // the inactive child to just be the normal value but the - // active child should be the correct value XOR the delta. - // This will be done by sending the sums and the sums plus - // delta and ensure that they can only decrypt the correct ones. - d = pprf.mDepth - 1; - //std::vector>& lastOts = lastOts; - lastOts.resize(min); - for (u64 j = 0; j < min; ++j) - { - // Construct the sums where we will allow the delta (mValue) - // to either be on the left child or right child depending - // on which has the active path. - lastOts[j][0] = sums[0][d][j]; - lastOts[j][1] = sums[1][d][j] ^ pprf.mValue[treeIdx + j]; - lastOts[j][2] = sums[1][d][j]; - lastOts[j][3] = sums[0][d][j] ^ pprf.mValue[treeIdx + j]; - - // We are going to expand the 128 bit OT string - // into a 256 bit OT string using AES. - std::array masks, maskIn; - maskIn[0] = pprf.mBaseOTs[treeIdx + j][d][0]; - maskIn[1] = pprf.mBaseOTs[treeIdx + j][d][0] ^ AllOneBlock; - maskIn[2] = pprf.mBaseOTs[treeIdx + j][d][1]; - maskIn[3] = pprf.mBaseOTs[treeIdx + j][d][1] ^ AllOneBlock; - mAesFixedKey.hashBlocks<4>(maskIn.data(), masks.data()); - -#ifdef DEBUG_PRINT_PPRF - if (mPrint) { - std::cout << "c[" << treeIdx + j << "][" << d << "][0] " << sums[0][d][j] << " " << mBaseOTs[treeIdx + j][d][0] << std::endl;; - std::cout << "c[" << treeIdx + j << "][" << d << "][1] " << sums[1][d][j] << " " << mBaseOTs[treeIdx + j][d][1] << std::endl;; - } -#endif - - // Add the OT masks to the sums and send them over. - lastOts[j][0] = lastOts[j][0] ^ masks[0]; - lastOts[j][1] = lastOts[j][1] ^ masks[1]; - lastOts[j][2] = lastOts[j][2] ^ masks[2]; - lastOts[j][3] = lastOts[j][3] ^ masks[3]; - } - - // pprf.setTimePoint("SilentMultiPprfSender.last " + std::to_string(treeIdx)); - - // Resize the sums to that they dont include - // the unmasked sums on the last level! - sums[0].resize(pprf.mDepth - 1); - sums[1].resize(pprf.mDepth - 1); - } - - // Send the sums to the other party. - //sendOne(treeGrp); - //chl.asyncSend(std::move(sums[0])); - //chl.asyncSend(std::move(sums[1])); - - MC_AWAIT(chl.send(std::move(sums[0]))); - MC_AWAIT(chl.send(std::move(sums[1]))); - - if (mActiveChildXorDelta) - MC_AWAIT(chl.send(std::move(lastOts))); - - - //// send the special OT messages for the last level. - //chl.asyncSend(std::move(lastOts)); - //gTimer.setTimePoint("send.expand_send"); - - // copy the last level to the output. If desired, this is - // where the transpose is performed. - auto lvl = getLevel(pprf.mDepth, treeIdx); - - // s is a checksum that is used for malicious security. - copyOut(lvl, output, pprf.mPntCount, treeIdx, oFormat, pprf.mOutputFn); - - // pprf.setTimePoint("SilentMultiPprfSender.copyOut " + std::to_string(treeIdx)); - + // Construct the sums where we will allow the delta (mValue) + // to either be on the left child or right child depending + // on which has the active path. + lastOts[j][0] = sums[d][0][j]; + lastOts[j][1] = sums[d][1][j] ^ mValue[treeIdx + j]; + lastOts[j][2] = sums[d][1][j]; + lastOts[j][3] = sums[d][0][j] ^ mValue[treeIdx + j]; + + // We are going to expand the 128 bit OT string + // into a 256 bit OT string using AES. + std::array masks, maskIn; + maskIn[0] = mBaseOTs[treeIdx + j][d][0]; + maskIn[1] = mBaseOTs[treeIdx + j][d][0] ^ AllOneBlock; + maskIn[2] = mBaseOTs[treeIdx + j][d][1]; + maskIn[3] = mBaseOTs[treeIdx + j][d][1] ^ AllOneBlock; + mAesFixedKey.hashBlocks<4>(maskIn.data(), masks.data()); + + // Add the OT masks to the sums and send them over. + lastOts[j][0] = lastOts[j][0] ^ masks[0]; + lastOts[j][1] = lastOts[j][1] ^ masks[1]; + lastOts[j][2] = lastOts[j][2] ^ masks[2]; + lastOts[j][3] = lastOts[j][3] ^ masks[3]; } - - //uPtr_ = {}; - //tree = {}; - pprf.mTreeAlloc.del(tree); - // pprf.setTimePoint("SilentMultiPprfSender.delete " + std::to_string(treeIdx)); - - MC_END(); } + } + void allocateExpandBuffer( + u64 depth, + u64 activeChildXorDelta, + std::vector& buff, + span< std::array, 2>>& sums, + span< std::array>& last) + { - //task<> expand( - // Socket& chl, - // span value, - // PRNG& prng, - // MatrixView output, - // PprfOutputFormat oFormat, - // bool activeChildXorDelta, - // u64 numThreads); + using SumType = std::array, 2>; + using LastType = std::array; + u64 numSums = depth - activeChildXorDelta; + u64 numLast = activeChildXorDelta * 8; + u64 numBlocks = numSums * 16 + numLast * 4; + buff.resize(numBlocks); + sums = span((SumType*)buff.data(), numSums); + last = span((LastType*)(sums.data() + sums.size()), numLast); + + void* sEnd = sums.data() + sums.size(); + void* lEnd = last.data() + last.size(); + void* end = buff.data() + buff.size(); + if (sEnd > end || lEnd > end) + throw RTE_LOC; + } + void allocateExpandTree( + u64 dpeth, + TreeAllocator& alloc, + span>& tree, + std::vector>>& levels) + { + tree = alloc.get(); + assert((u64)tree.data() % 32 == 0); + levels[0] = tree.subspan(0, 1); + auto rem = tree.subspan(2); + for (auto i : rng(1ull, dpeth)) + { + levels[i] = rem.subspan(0, levels[i - 1].size() * 2); + assert((u64)levels[i].data() % 32 == 0); + rem = rem.subspan(levels[i].size()); + } + } - task<> SilentMultiPprfSender::expand( - Socket& chl, - span value, - PRNG& prng, - MatrixView output, + void validateExpandFormat( PprfOutputFormat oFormat, - bool activeChildXorDelta, - u64 numThreads) + MatrixView output, + u64 domain, + u64 pntCount + ) { - if (activeChildXorDelta) - setValue(value); - setTimePoint("SilentMultiPprfSender.start"); - //gTimer.setTimePoint("send.enter"); - - - - if (oFormat == PprfOutputFormat::Plain) - { - if (output.rows() != mDomain) - throw RTE_LOC; - if (output.cols() != mPntCount) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::BlockTransposed) + if (oFormat == PprfOutputFormat::ByLeafIndex) { - if (output.cols() != mDomain) + if (output.rows() != domain) throw RTE_LOC; - if (output.rows() != mPntCount) + if (output.cols() != pntCount) throw RTE_LOC; } - else if (oFormat == PprfOutputFormat::InterleavedTransposed) + else if (oFormat == PprfOutputFormat::ByTreeIndex) { - if (output.rows() != 128) + if (output.cols() != domain) throw RTE_LOC; - //if (output.cols() > (mDomain * mPntCount + 127) / 128) - // throw RTE_LOC; - - if (mPntCount & 7) + if (output.rows() != pntCount) throw RTE_LOC; } - else if - (oFormat == PprfOutputFormat::Interleaved) + else if (oFormat == PprfOutputFormat::Interleaved) { if (output.cols() != 1) throw RTE_LOC; - if (mDomain & 1) + if (domain & 1) throw RTE_LOC; auto rows = output.rows(); - if (rows > (mDomain * mPntCount) || - rows / 128 != (mDomain * mPntCount) / 128) + if (rows > (domain * pntCount) || + rows / 128 != (domain * pntCount) / 128) throw RTE_LOC; - if (mPntCount & 7) + if (pntCount & 7) throw RTE_LOC; } else if (oFormat == PprfOutputFormat::Callback) { - if (mDomain & 1) + if (domain & 1) throw RTE_LOC; - if (mPntCount & 7) + if (pntCount & 7) throw RTE_LOC; } else { throw RTE_LOC; } + } + + + task<> SilentMultiPprfSender::expand( + Socket& chl, + span value, + block seed, + MatrixView output, + PprfOutputFormat oFormat, + bool activeChildXorDelta, + u64 numThreads) + { + if (activeChildXorDelta) + setValue(value); + + setTimePoint("SilentMultiPprfSender.start"); + validateExpandFormat(oFormat, output, mDomain, mPntCount); - MC_BEGIN(task<>, this, numThreads, oFormat, output, &prng, &chl, activeChildXorDelta, + MC_BEGIN(task<>, this, numThreads, oFormat, output, seed, &chl, activeChildXorDelta, i = u64{}, - dd = u64{} + mTreeAllocDepth = u64{}, + tree = span>{}, + levels = std::vector>>{}, + buff = std::vector{}, + sums = span, 2>>{}, + last = span>{} ); + //if (oFormat == PprfOutputFormat::Callback && numThreads > 1) + // throw RTE_LOC; - if (oFormat == PprfOutputFormat::Callback && numThreads > 1) - throw RTE_LOC; - - dd = mDepth + (oFormat == PprfOutputFormat::Interleaved ? 0 : 1); - mTreeAlloc.reserve(numThreads, (1ull << dd) + (32 * dd)); + mTreeAllocDepth = mDepth + (oFormat != PprfOutputFormat::Interleaved); + mTreeAlloc.reserve(numThreads, (1ull << mTreeAllocDepth) + 2); setTimePoint("SilentMultiPprfSender.reserve"); - mExps.clear(); mExps.reserve(divCeil(mPntCount, 8)); + levels.resize(mDepth + 1); + allocateExpandTree(mTreeAllocDepth, mTreeAlloc, tree, levels); + for (i = 0; i < mPntCount; i += 8) { - mExps.emplace_back(*this, prng.get(), i, oFormat, output, activeChildXorDelta, chl.fork()); - mExps.back().mFuture = macoro::make_eager(mExps.back().run()); - //MC_AWAIT(mExps.back().run()); - } + // for interleaved format, the last level of the tree + // is simply the output. + if (oFormat == PprfOutputFormat::Interleaved) + { + auto b = (AlignedArray*)output.data(); + auto forest = i / 8; + b += forest * mDomain; - for (i = 0; i < mExps.size(); ++i) - MC_AWAIT(mExps[i].mFuture); + levels.back() = span>(b, mDomain); + } + + // allocate the send buffer and partition it. + allocateExpandBuffer(mDepth, activeChildXorDelta, buff, sums, last); + + // exapnd the tree + expandOne(seed, i, activeChildXorDelta, levels, sums, last); - mExps.clear(); - setTimePoint("SilentMultiPprfSender.join"); + MC_AWAIT(chl.send(std::move(buff))); + + // if we aren't interleaved, we need to copy the + // last layer to the output. + if (oFormat != PprfOutputFormat::Interleaved) + copyOut(levels.back(), output, mPntCount, i, oFormat, mOutputFn); + } mBaseOTs = {}; - //mTreeAlloc.clear(); + mTreeAlloc.del(tree); + mTreeAlloc.clear(); + setTimePoint("SilentMultiPprfSender.de-alloc"); MC_END(); - - } void SilentMultiPprfSender::setValue(span value) { - mValue.resize(mPntCount); if (value.size() == 1) @@ -759,213 +681,115 @@ namespace osuCrypto mPntCount = 0; } + void SilentMultiPprfReceiver::expandOne( - // Returns the i'th level of the current 8 trees. The - // children of node j on level i are located at 2*j and - // 2*j+1 on level i+1. - span> SilentMultiPprfReceiver::Expander::getLevel(u64 i, u64 g, bool f) - { - //auto size = (1ull << i); -#ifdef DEBUG_PRINT_PPRF - //auto offset = (size - 1); - //auto b = (f ? ftree.begin() : tree.begin()) + offset; -#else - if (oFormat == PprfOutputFormat::Interleaved && i == pprf.mDepth) - { - auto b = (AlignedArray*)output.data(); - auto forest = g / 8; - assert(g % 8 == 0); - b += forest * pprf.mDomain; - auto zone = span>(b, pprf.mDomain); - return zone; - } - - //assert(tree.size()); - //auto b = tree.begin() + offset; - - return mLevels[i]; -#endif - //return span>(b,e); - }; - - - SilentMultiPprfReceiver::Expander::Expander(SilentMultiPprfReceiver& p, Socket&& s, PprfOutputFormat of, MatrixView o, bool activeChildXorDelta, u64 ti) - : pprf(p) - , chl(std::move(s)) - , mActiveChildXorDelta(activeChildXorDelta) - , oFormat(of) - , output(o) - , treeIdx(ti) - //, threadIdx(tIdx) - { - assert((treeIdx & 7) == 0); - // A public PRF/PRG that we will use for deriving the GGM tree. - aes[0].setKey(toBlock(3242342)); - aes[1].setKey(toBlock(8993849)); - - - theirSums[0].resize(p.mDepth - mActiveChildXorDelta); - theirSums[1].resize(p.mDepth - mActiveChildXorDelta); - - dd = p.mDepth + (oFormat == PprfOutputFormat::Interleaved ? 0 : 1); - - } - - task<> SilentMultiPprfReceiver::Expander::run() + u64 treeIdx, + bool programActivePath, + span>> levels, + span, 2>> theirSums, + span> lastOts) { + // This thread will process 8 trees at a time. - MC_BEGIN(task<>, this); - - + // special case for the first level. + auto l1 = levels[1]; + for (u64 i = 0; i < 8; ++i) { - tree = pprf.mTreeAlloc.get(); - assert(tree.size() >= 1ull << (dd)); - mLevels.resize(dd); - mLevels[0] = tree.subspan(0, 1); - auto rem = tree.subspan(1); - for (auto i : rng(1ull, dd)) - { - while ((u64)rem.data() % 32) - rem = rem.subspan(1); - mLevels[i] = rem.subspan(0, mLevels[i - 1].size() * 2); - rem = rem.subspan(mLevels[i].size()); - } + // For the non-active path, set the child of the root node + // as the OT message XOR'ed with the correction sum. + int notAi = mBaseChoices[i + treeIdx][0]; + l1[notAi][i] = mBaseOTs[i + treeIdx][0] ^ theirSums[0][notAi][i]; + l1[notAi ^ 1][i] = ZeroBlock; } + // space for our sums of each level. + std::array, 2> mySums; -#ifdef DEBUG_PRINT_PPRF - // This will be the full tree and is sent by the receiver to help debug. - std::vector> ftree(1ull << (mDepth + 1)); - - // The delta value on the active path. - //block deltaValue; - chl.recv(mDebugValue); -#endif - - - -#ifdef DEBUG_PRINT_PPRF - // prints out the contents of the d'th level. - auto printLevel = [&](u64 d) + // For all other levels, expand the GGM tree and add in + // the correction along the active path. + for (u64 d = 1; d < mDepth; ++d) { + // The already constructed level. Only missing the + // GGM tree node value along the active path. + auto level0 = levels[d]; + + // The next level that we want to construct. + auto level1 = levels[d + 1]; - auto level0 = getLevel(d); - auto flevel0 = getLevel(d, true); + // Zero out the previous sums. + memset(mySums.data(), 0, sizeof(mySums)); - std::cout - << "---------------------\nlevel " << d - << "\n---------------------" << std::endl; + // We will iterate over each node on this level and + // expand it into it's two children. Note that the + // active node will also be expanded. Later we will just + // overwrite whatever the value was. This is an optimization. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); - std::array sums{ ZeroBlock ,ZeroBlock }; - for (i64 i = 0; i < level0.size(); ++i) + // for internal nodes we the optimized approach. + if (d + 1 < mDepth && 0) { - for (u64 j = 0; j < 8; ++j) + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) { + // The value of the parent. + auto parent = level0[parentIdx]; - if (neq(level0[i][j], flevel0[i][j])) - std::cout << Color::Red; - - std::cout << "p[" << i << "][" << j << "] " - << level0[i][j] << " " << flevel0[i][j] << std::endl << Color::Default; - - if (i == 0 && j == 0) - sums[i & 1] = sums[i & 1] ^ flevel0[i][j]; - } - } - - std::cout << "sums[0] = " << sums[0] << " " << sums[1] << std::endl; - }; -#endif - - - // The number of real trees for this iteration. - lastOts.resize(8); - - // This thread will process 8 trees at a time. It will interlace - // the sets of trees are processed with the other threads. - { -#ifdef DEBUG_PRINT_PPRF - chl.recv(ftree); - auto l1f = getLevel(1, true); -#endif - - //timer.setTimePoint("recv.start" + std::to_string(treeIdx)); - // Receive their full set of sums for these 8 trees. - MC_AWAIT(chl.recv(theirSums[0])); - MC_AWAIT(chl.recv(theirSums[1])); - - if (mActiveChildXorDelta) - MC_AWAIT(chl.recv(lastOts)); - // pprf.setTimePoint("SilentMultiPprfReceiver.recv " + std::to_string(treeIdx)); - - tree = pprf.mTreeAlloc.get(); - assert(tree.size() >= 1ull << (dd)); - assert((u64)tree.data() % 32 == 0); - - // pprf.setTimePoint("SilentMultiPprfReceiver.alloc " + std::to_string(treeIdx)); - - auto l1 = getLevel(1, treeIdx); - - for (u64 i = 0; i < 8; ++i) - { - // For the non-active path, set the child of the root node - // as the OT message XOR'ed with the correction sum. - int notAi = pprf.mBaseChoices[i + treeIdx][0]; - l1[notAi][i] = pprf.mBaseOTs[i + treeIdx][0] ^ theirSums[notAi][0][i]; - l1[notAi ^ 1][i] = ZeroBlock; - -#ifdef DEBUG_PRINT_PPRF - if (neq(l1[notAi][i], l1f[notAi][i])) { - std::cout << "l1[" << notAi << "][" << i << "] " << l1[notAi][i] << " = " - << (mBaseOTs[i + treeIdx][0]) << " ^ " - << theirSums[notAi][0][i] << " vs " << l1f[notAi][i] << std::endl; + auto& child0 = level1.data()[childIdx]; + auto& child1 = level1.data()[childIdx + 1]; + mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); + + child0[0] = child1[0] ^ parent[0]; + child0[1] = child1[1] ^ parent[1]; + child0[2] = child1[2] ^ parent[2]; + child0[3] = child1[3] ^ parent[3]; + child0[4] = child1[4] ^ parent[4]; + child0[5] = child1[5] ^ parent[5]; + child0[6] = child1[6] ^ parent[6]; + child0[7] = child1[7] ^ parent[7]; + + // Update the running sums for this level. We keep + // a left and right totals for each level. Note that + // we are actually XOR in the incorrect value of the + // children of the active parent (assuming !DEBUG_PRINT_PPRF). + // This is ok since we will later XOR off these incorrect values. + mySums[0][0] = mySums[0][0] ^ child0[0]; + mySums[0][1] = mySums[0][1] ^ child0[1]; + mySums[0][2] = mySums[0][2] ^ child0[2]; + mySums[0][3] = mySums[0][3] ^ child0[3]; + mySums[0][4] = mySums[0][4] ^ child0[4]; + mySums[0][5] = mySums[0][5] ^ child0[5]; + mySums[0][6] = mySums[0][6] ^ child0[6]; + mySums[0][7] = mySums[0][7] ^ child0[7]; + + child1[0] = child1[0] + parent[0]; + child1[1] = child1[1] + parent[1]; + child1[2] = child1[2] + parent[2]; + child1[3] = child1[3] + parent[3]; + child1[4] = child1[4] + parent[4]; + child1[5] = child1[5] + parent[5]; + child1[6] = child1[6] + parent[6]; + child1[7] = child1[7] + parent[7]; + + mySums[1][0] = mySums[1][0] ^ child1[0]; + mySums[1][1] = mySums[1][1] ^ child1[1]; + mySums[1][2] = mySums[1][2] ^ child1[2]; + mySums[1][3] = mySums[1][3] ^ child1[3]; + mySums[1][4] = mySums[1][4] ^ child1[4]; + mySums[1][5] = mySums[1][5] ^ child1[5]; + mySums[1][6] = mySums[1][6] ^ child1[6]; + mySums[1][7] = mySums[1][7] ^ child1[7]; } -#endif } - -#ifdef DEBUG_PRINT_PPRF - if (mPrint) - printLevel(1); -#endif - - // For all other levels, expand the GGM tree and add in - // the correction along the active path. - for (u64 d = 1; d < pprf.mDepth; ++d) + else { - // The already constructed level. Only missing the - // GGM tree node value along the active path. - auto level0 = getLevel(d, treeIdx); - - // The next level that we want to construct. - auto level1 = getLevel(d + 1, treeIdx); - - // Zero out the previous sums. - memset(mySums[0].data(), 0, mySums[0].size() * sizeof(block)); - memset(mySums[1].data(), 0, mySums[1].size() * sizeof(block)); - - // We will iterate over each node on this level and - // expand it into it's two children. Note that the - // active node will also be expanded. Later we will just - // overwrite whatever the value was. This is an optimization. - auto width = static_cast(level1.size()); - for (u64 childIdx = 0; childIdx < width; ) + // for leaf nodes both children should be hashed. + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) { - - // Index of the parent in the previous level. - auto parentIdx = childIdx >> 1; - // The value of the parent. auto parent = level0[parentIdx]; for (u64 keep = 0; keep < 2; ++keep, ++childIdx) { - - //// The bit that indicates if we are on the left child (0) - //// or on the right child (1). - //u8 keep = childIdx & 1; - - // The child that we will write in this iteration. auto& child = level1[childIdx]; @@ -975,16 +799,8 @@ namespace osuCrypto // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); // // where each half defines one of the children. - aes[keep].hashBlocks<8>(parent.data(), child.data()); - + gAes[keep].hashBlocks<8>(parent.data(), child.data()); - -#ifdef DEBUG_PRINT_PPRF - // For debugging, set the active path to zero. - for (u64 i = 0; i < 8; ++i) - if (eq(parent[i], ZeroBlock)) - child[i] = ZeroBlock; -#endif // Update the running sums for this level. We keep // a left and right totals for each level. Note that // we are actually XOR in the incorrect value of the @@ -1001,268 +817,179 @@ namespace osuCrypto sum[7] = sum[7] ^ child[7]; } } + } - // For everything but the last level we have to - // 1) fix our sums so they dont include the incorrect - // values that are the children of the active parent - // 2) Update the non-active child of the active parent. - if (!mActiveChildXorDelta || d != pprf.mDepth - 1) - { - - for (u64 i = 0; i < 8; ++i) - { - // the index of the leaf node that is active. - auto leafIdx = pprf.mPoints[i + treeIdx]; - - // The index of the active child node. - auto activeChildIdx = leafIdx >> (pprf.mDepth - 1 - d); - - // The index of the active child node sibling. - auto inactiveChildIdx = activeChildIdx ^ 1; - - // The indicator as to the left or right child is inactive - auto notAi = inactiveChildIdx & 1; -#ifdef DEBUG_PRINT_PPRF - auto prev = level1[inactiveChildIdx][i]; -#endif - - auto& inactiveChild = level1[inactiveChildIdx][i]; - - - // correct the sum value by XORing off the incorrect - auto correctSum = - inactiveChild ^ - theirSums[notAi][d][i]; - - inactiveChild = - correctSum ^ - mySums[notAi][i] ^ - pprf.mBaseOTs[i + treeIdx][d]; - -#ifdef DEBUG_PRINT_PPRF - if (mPrint) - std::cout << "up[" << i << "] = level1[" << inactiveChildIdx << "][" << i << "] " - << prev << " -> " << level1[inactiveChildIdx][i] << " " << activeChildIdx << " " << inactiveChildIdx << " ~~ " - << mBaseOTs[i + treeIdx][d] << " " << theirSums[notAi][d][i] << " @ " << (i + treeIdx) << " " << d << std::endl; - - auto fLevel1 = getLevel(d + 1, true); - if (neq(fLevel1[inactiveChildIdx][i], inactiveChild)) - throw RTE_LOC; -#endif - } - } -#ifdef DEBUG_PRINT_PPRF - if (mPrint) - printLevel(d + 1); -#endif - - } - - // pprf.setTimePoint("SilentMultiPprfReceiver.expand " + std::to_string(treeIdx)); - - //timer.setTimePoint("recv.expanded"); - - - // copy the last level to the output. If desired, this is - // where the transpose is performed. - auto lvl = getLevel(pprf.mDepth, treeIdx); - - if (mActiveChildXorDelta) + // For everything but the last level we have to + // 1) fix our sums so they dont include the incorrect + // values that are the children of the active parent + // 2) Update the non-active child of the active parent. + if (!programActivePath || d != mDepth - 1) { - // Now processes the last level. This one is special - // because we must XOR in the correction value as - // before but we must also fixed the child value for - // the active child. To do this, we will receive 4 - // values. Two for each case (left active or right active). - //timer.setTimePoint("recv.recvLast"); - - auto level = getLevel(pprf.mDepth, treeIdx); - auto d = pprf.mDepth - 1; - for (u64 j = 0; j < 8; ++j) + for (u64 i = 0; i < 8; ++i) { - // The index of the child on the active path. - auto activeChildIdx = pprf.mPoints[j + treeIdx]; + // the index of the leaf node that is active. + auto leafIdx = mPoints[i + treeIdx]; + + // The index of the active child node. + auto activeChildIdx = leafIdx >> (mDepth - 1 - d); - // The index of the other (inactive) child. + // The index of the active child node sibling. auto inactiveChildIdx = activeChildIdx ^ 1; // The indicator as to the left or right child is inactive auto notAi = inactiveChildIdx & 1; - std::array masks, maskIn; - - // We are going to expand the 128 bit OT string - // into a 256 bit OT string using AES. - maskIn[0] = pprf.mBaseOTs[j + treeIdx][d]; - maskIn[1] = pprf.mBaseOTs[j + treeIdx][d] ^ AllOneBlock; - mAesFixedKey.hashBlocks<2>(maskIn.data(), masks.data()); - - // now get the chosen message OT strings by XORing - // the expended (random) OT strings with the lastOts values. - auto& ot0 = lastOts[j][2 * notAi + 0]; - auto& ot1 = lastOts[j][2 * notAi + 1]; - ot0 = ot0 ^ masks[0]; - ot1 = ot1 ^ masks[1]; + auto& inactiveChild = level1[inactiveChildIdx][i]; -#ifdef DEBUG_PRINT_PPRF - auto prev = level[inactiveChildIdx][j]; -#endif + // correct the sum value by XORing off the incorrect + auto correctSum = + inactiveChild ^ + theirSums[d][notAi][i]; - auto& inactiveChild = level[inactiveChildIdx][j]; - auto& activeChild = level[activeChildIdx][j]; - - // Fix the sums we computed previously to not include the - // incorrect child values. - auto inactiveSum = mySums[notAi][j] ^ inactiveChild; - auto activeSum = mySums[notAi ^ 1][j] ^ activeChild; - - // Update the inactive and active child to have to correct - // value by XORing their full sum with out partial sum, which - // gives us exactly the value we are missing. - inactiveChild = ot0 ^ inactiveSum; - activeChild = ot1 ^ activeSum; - -#ifdef DEBUG_PRINT_PPRF - auto fLevel1 = getLevel(d + 1, true); - if (neq(fLevel1[inactiveChildIdx][j], inactiveChild)) - throw RTE_LOC; - if (neq(fLevel1[activeChildIdx][j], activeChild ^ mDebugValue)) - throw RTE_LOC; - - if (mPrint) - std::cout << "up[" << d << "] = level1[" << (inactiveChildIdx / mPntCount) << "][" << (inactiveChildIdx % mPntCount) << " " - << prev << " -> " << level[inactiveChildIdx][j] << " ~~ " - << mBaseOTs[j + treeIdx][d] << " " << ot0 << " @ " << (j + treeIdx) << " " << d << std::endl; -#endif - } - // pprf.setTimePoint("SilentMultiPprfReceiver.last " + std::to_string(treeIdx)); - - //timer.setTimePoint("recv.expandLast"); - } - else - { - for (auto j : rng(std::min(8, pprf.mPntCount - treeIdx))) - { + inactiveChild = + correctSum ^ + mySums[notAi][i] ^ + mBaseOTs[i + treeIdx][d]; - // The index of the child on the active path. - auto activeChildIdx = pprf.mPoints[j + treeIdx]; - lvl[activeChildIdx][j] = ZeroBlock; } } + } - // s is a checksum that is used for malicious security. - copyOut(lvl, output, pprf.mPntCount, treeIdx, oFormat, pprf.mOutputFn); - - // pprf.setTimePoint("SilentMultiPprfReceiver.copy " + std::to_string(treeIdx)); - - //uPtr_ = {}; - //tree = {}; - pprf.mTreeAlloc.del(tree); - - // pprf.setTimePoint("SilentMultiPprfReceiver.delete " + std::to_string(treeIdx)); + // last level. + auto level = levels[mDepth]; + if (programActivePath) + { + // Now processes the last level. This one is special + // because we must XOR in the correction value as + // before but we must also fixed the child value for + // the active child. To do this, we will receive 4 + // values. Two for each case (left active or right active). + //timer.setTimePoint("recv.recvLast"); + auto d = mDepth - 1; + for (u64 j = 0; j < 8; ++j) + { + // The index of the child on the active path. + auto activeChildIdx = mPoints[j + treeIdx]; + + // The index of the other (inactive) child. + auto inactiveChildIdx = activeChildIdx ^ 1; + + // The indicator as to the left or right child is inactive + auto notAi = inactiveChildIdx & 1; + + std::array masks, maskIn; + + // We are going to expand the 128 bit OT string + // into a 256 bit OT string using AES. + maskIn[0] = mBaseOTs[j + treeIdx][d]; + maskIn[1] = mBaseOTs[j + treeIdx][d] ^ AllOneBlock; + mAesFixedKey.hashBlocks<2>(maskIn.data(), masks.data()); + + // now get the chosen message OT strings by XORing + // the expended (random) OT strings with the lastOts values. + auto& ot0 = lastOts[j][2 * notAi + 0]; + auto& ot1 = lastOts[j][2 * notAi + 1]; + ot0 = ot0 ^ masks[0]; + ot1 = ot1 ^ masks[1]; + + auto& inactiveChild = level[inactiveChildIdx][j]; + auto& activeChild = level[activeChildIdx][j]; + + // Fix the sums we computed previously to not include the + // incorrect child values. + auto inactiveSum = mySums[notAi][j] ^ inactiveChild; + auto activeSum = mySums[notAi ^ 1][j] ^ activeChild; + + // Update the inactive and active child to have to correct + // value by XORing their full sum with out partial sum, which + // gives us exactly the value we are missing. + inactiveChild = ot0 ^ inactiveSum; + activeChild = ot1 ^ activeSum; } + // pprf.setTimePoint("SilentMultiPprfReceiver.last " + std::to_string(treeIdx)); - MC_END(); + //timer.setTimePoint("recv.expandLast"); + } + else + { + for (auto j : rng(std::min(8, mPntCount - treeIdx))) + { + // The index of the child on the active path. + auto activeChildIdx = mPoints[j + treeIdx]; + level[activeChildIdx][j] = ZeroBlock; } + } + } - - - task<> SilentMultiPprfReceiver::expand(Socket& chl, PRNG& prng, MatrixView output, + task<> SilentMultiPprfReceiver::expand( + Socket& chl, + MatrixView output, PprfOutputFormat oFormat, bool activeChildXorDelta, - u64 numThreads) + u64 _) { + validateExpandFormat(oFormat, output, mDomain, mPntCount); - setTimePoint("SilentMultiPprfReceiver.start"); - - //lout << " d " << mDomain << " p " << mPntCount << " do " << mDepth << std::endl; - - if (oFormat == PprfOutputFormat::Plain) - { - if (output.rows() != mDomain) - throw RTE_LOC; + MC_BEGIN(task<>, this, oFormat, output, &chl, activeChildXorDelta, + i = u64{}, + mTreeAllocDepth = u64{}, + tree = span>{}, + levels = std::vector>>{}, + buff = std::vector{}, + sums = span, 2>>{}, + last = span>{} + ); - if (output.cols() != mPntCount) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::BlockTransposed) - { - if (output.cols() != mDomain) - throw RTE_LOC; + setTimePoint("SilentMultiPprfReceiver.start"); + mPoints.resize(roundUpTo(mPntCount, 8)); + getPoints(mPoints, PprfOutputFormat::ByLeafIndex); - if (output.rows() != mPntCount) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::InterleavedTransposed) - { - if (output.rows() != 128) - throw RTE_LOC; + mTreeAllocDepth = mDepth + (oFormat != PprfOutputFormat::Interleaved); + mTreeAlloc.reserve(1, (1ull << mTreeAllocDepth) + 2); + setTimePoint("SilentMultiPprfSender.reserve"); - //if (output.cols() > (mDomain * mPntCount + 127) / 128) - // throw RTE_LOC; + levels.resize(mDepth + 1); + allocateExpandTree(mTreeAllocDepth, mTreeAlloc, tree, levels); - if (mPntCount & 7) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::Interleaved) - { - if (output.cols() != 1) - throw RTE_LOC; - if (mDomain & 1) - throw RTE_LOC; - auto rows = output.rows(); - if (rows > (mDomain * mPntCount) || - rows / 128 != (mDomain * mPntCount) / 128) - throw RTE_LOC; - if (mPntCount & 7) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::Callback) - { - if (mDomain & 1) - throw RTE_LOC; - if (mPntCount & 7) - throw RTE_LOC; - } - else + for (i = 0; i < mPntCount; i += 8) { - throw RTE_LOC; - } - - mPoints.resize(roundUpTo(mPntCount, 8)); - getPoints(mPoints, PprfOutputFormat::Plain); - + // for interleaved format, the last level of the tree + // is simply the output. + if (oFormat == PprfOutputFormat::Interleaved) + { + auto b = (AlignedArray*)output.data(); + auto forest = i / 8; + b += forest * mDomain; - MC_BEGIN(task<>, this, numThreads, oFormat, output, &chl, activeChildXorDelta, - i = u64{}, - dd = u64{} - ); + levels.back() = span>(b, mDomain); + } + // allocate the send buffer and partition it. + allocateExpandBuffer(mDepth, activeChildXorDelta, buff, sums, last); - dd = mDepth + (oFormat == PprfOutputFormat::Interleaved ? 0 : 1); - mTreeAlloc.reserve(numThreads, (1ull << (dd)) + (32 * dd)); - setTimePoint("SilentMultiPprfReceiver.reserve"); + MC_AWAIT(chl.recv(buff)); - mExps.clear(); mExps.reserve(divCeil(mPntCount, 8)); - for (i = 0; i < mPntCount; i += 8) - { - mExps.emplace_back(*this, chl.fork(), oFormat, output, activeChildXorDelta, i); - mExps.back().mFuture = macoro::make_eager(mExps.back().run()); + // exapnd the tree + expandOne(i, activeChildXorDelta, levels, sums, last); - //MC_AWAIT(mExps.back().run()); + // if we aren't interleaved, we need to copy the + // last layer to the output. + if (oFormat != PprfOutputFormat::Interleaved) + copyOut(levels.back(), output, mPntCount, i, oFormat, mOutputFn); } - for (i = 0; i < mExps.size(); ++i) - MC_AWAIT(mExps[i].mFuture); setTimePoint("SilentMultiPprfReceiver.join"); mBaseOTs = {}; + mTreeAlloc.del(tree); + mTreeAlloc.clear(); + setTimePoint("SilentMultiPprfReceiver.de-alloc"); MC_END(); } - } +} #endif diff --git a/libOTe/Tools/SilentPprf.h b/libOTe/Tools/SilentPprf.h index 014c05c2..b5ffbd4d 100644 --- a/libOTe/Tools/SilentPprf.h +++ b/libOTe/Tools/SilentPprf.h @@ -24,14 +24,35 @@ namespace osuCrypto { + // the various formats that the output of the + // Pprf can be generated. enum class PprfOutputFormat { - Plain, // One column per tree, one row per leaf - BlockTransposed, // One row per tree, one column per leaf - Interleaved, - InterleavedTransposed, // Bit transposed - Callback // call the user's callback - + // The i'th row holds the i'th leaf for all trees. + // The j'th tree is in the j'th column. + ByLeafIndex, + + // The i'th row holds the i'th tree. + // The j'th leaf is in the j'th column. + ByTreeIndex, + + // The native output mode. The output will be + // a single row with all leaf values. + // Every 8 trees are mixed together where the + // i'th leaf for each of the 8 tree will be next + // to each other. For example, let tij be the j'th + // leaf of the i'th tree. If we have m leaves, then + // + // t00 t10 ... t70 t01 t11 ... t71 ... t0m t1m ... t7m + // t80 t90 ... t_{15,0} t81 t91 ... t_{15,1} ... t8m t9m ... t_{15,m} + // ... + // + // These are all flattened into a single row. + Interleaved, + + // call the user's callback. The leaves will be in + // Interleaved format. + Callback }; enum class OTType @@ -106,12 +127,27 @@ namespace osuCrypto void clear() { assert(mNumTrees == mFreeTrees.size()); - mTrees.clear(); + mTrees = {}; mFreeTrees = {}; - mTreeSize = {}; + mTreeSize = 0; + mNumTrees = 0; } }; + + void allocateExpandBuffer( + u64 depth, + u64 activeChildXorDelta, + std::vector& buff, + span< std::array, 2>>& sums, + span< std::array>& last); + + void allocateExpandTree( + u64 dpeth, + TreeAllocator& alloc, + span>& tree, + std::vector>>& levels); + class SilentMultiPprfSender : public TimerAdapter { public: @@ -157,35 +193,17 @@ namespace osuCrypto void setBase(span> baseMessages); - - // expand the whole PPRF and store the result in output - //task<> expand(Socket& chl, block value, PRNG& prng, span output, PprfOutputFormat oFormat, u64 numThreads) - //{ - // MatrixView o(output.data(), output.size(), 1); - // return expand(chl, value, prng, o, oFormat, numThreads); - //} - - - //task<> expand( - // Socket& chl, - // block value, - // PRNG& prng, - // MatrixView output, - // PprfOutputFormat oFormat, - // bool activeChildXorDelta, - // u64 numThreads); - - task<> expand(Socket& chls, span value, PRNG& prng, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) + task<> expand(Socket& chls, span value, block seed, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) { MatrixView o(output.data(), output.size(), 1); - return expand(chls, value, prng, o, oFormat, activeChildXorDelta, numThreads); + return expand(chls, value, seed, o, oFormat, activeChildXorDelta, numThreads); } task<> expand( Socket& chl, span value, - PRNG& prng, + block seed, MatrixView output, PprfOutputFormat oFormat, bool activeChildXorDelta, @@ -195,50 +213,14 @@ namespace osuCrypto void clear(); - struct Expander - { - SilentMultiPprfSender& pprf; - Socket chl; - std::array aes; - PRNG prng; - u64 dd, treeIdx, min, d; - bool mActiveChildXorDelta = true; - - macoro::eager_task mFuture; - std::vector>> mLevels; - - //std::unique_ptr uPtr_; - - // tree will hold the full GGM tree. Note that there are 8 - // indepenendent trees that are being processed together. - // The trees are flattenned to that the children of j are - // located at 2*j and 2*j+1. - span> tree; - - // sums will hold the left and right GGM tree sums - // for each level. For example sums[0][i][5] will - // hold the sum of the left children for level i of - // the 5th tree. - std::array>, 2> sums; - std::vector> lastOts; - - PprfOutputFormat oFormat; - - MatrixView output; - - // The number of real trees for this iteration. - // Returns the i'th level of the current 8 trees. The - // children of node j on level i are located at 2*j and - // 2*j+1 on level i+1. - span> getLevel(u64 i, u64 g); - - Expander(SilentMultiPprfSender& p, block seed, u64 treeIdx, - PprfOutputFormat of, MatrixViewo, bool activeChildXorDelta, Socket&& s); - - task<> run(); - }; - - std::vector mExps; + void expandOne( + block aesSeed, + u64 treeIdx, + bool activeChildXorDelta, + span>> levels, + span, 2>> sums, + span> lastOts + ); }; @@ -259,7 +241,6 @@ namespace osuCrypto SilentMultiPprfReceiver() = default; SilentMultiPprfReceiver(const SilentMultiPprfReceiver&) = delete; SilentMultiPprfReceiver(SilentMultiPprfReceiver&&) = delete; - //SilentMultiPprfReceiver(u64 domainSize, u64 pointCount); void configure(u64 domainSize, u64 pointCount) { @@ -270,11 +251,10 @@ namespace osuCrypto mBaseOTs.resize(0, 0); } - - // For output format Plain or BlockTransposed, the choice bits it + // For output format ByLeafIndex or ByTreeIndex, the choice bits it // samples are in blocks of mDepth, with mPntCount blocks total (one for - // each punctured point). For Plain these blocks encode the punctured - // leaf index in big endian, while for BlockTransposed they are in + // each punctured point). For ByLeafIndex these blocks encode the punctured + // leaf index in big endian, while for ByTreeIndex they are in // little endian. BitVector sampleChoiceBits(u64 modulus, PprfOutputFormat format, PRNG& prng); @@ -293,7 +273,6 @@ namespace osuCrypto return mBaseOTs.size(); } - void setBase(span baseMessages); std::vector getPoints(PprfOutputFormat format) @@ -304,16 +283,16 @@ namespace osuCrypto } void getPoints(span points, PprfOutputFormat format); - task<> expand(Socket& chl, PRNG& prng, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) + task<> expand(Socket& chl, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) { MatrixView o(output.data(), output.size(), 1); - return expand(chl, prng, o, oFormat, activeChildXorDelta, numThreads); + return expand(chl, o, oFormat, activeChildXorDelta, numThreads); } // activeChildXorDelta says whether the sender is trying to program the // active child to be its correct value XOR delta. If it is not, the // active child will just take a random value. - task<> expand(Socket& chl, PRNG& prng, MatrixView output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads); + task<> expand(Socket& chl, MatrixView output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads); void clear() { @@ -324,57 +303,12 @@ namespace osuCrypto mPntCount = 0; } - - - struct Expander - { - SilentMultiPprfReceiver& pprf; - Socket chl; - - bool mActiveChildXorDelta = false; - std::array aes; - - PprfOutputFormat oFormat; - MatrixView output; - - macoro::eager_task mFuture; - - std::vector>> mLevels; - - // mySums will hold the left and right GGM tree sums - // for each level. For example mySums[5][0] will - // hold the sum of the left children for the 5th tree. This - // sum will be "missing" the children of the active parent. - // The sender will give of one of the full somes so we can - // compute the missing inactive child. - std::array, 2> mySums; - - // A buffer for receiving the sums from the other party. - // These will be masked by the OT strings. - std::array>, 2> theirSums; - - u64 dd, treeIdx; - // tree will hold the full GGM tree. Not that there are 8 - // indepenendent trees that are being processed together. - // The trees are flattenned to that the children of j are - // located at 2*j and 2*j+1. - //std::unique_ptr uPtr_; - span> tree; - - std::vector> lastOts; - - - // Returns the i'th level of the current 8 trees. The - // children of node j on level i are located at 2*j and - // 2*j+1 on level i+1. - span> getLevel(u64 i, u64 g, bool f = false); - - - Expander(SilentMultiPprfReceiver& p, Socket&& s, PprfOutputFormat of, MatrixView o, bool activeChildXorDelta, u64 treeIdx); - task<> run(); - }; - - std::vector mExps; + void expandOne( + u64 treeIdx, + bool programActivePath, + span>> levels, + span, 2>> encSums, + span> lastOts); }; } #endif diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp index 7109b177..3107f0db 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp @@ -507,7 +507,7 @@ namespace osuCrypto } - MC_AWAIT(mGen.expand(chl, prng, mA.subspan(0, main), PprfOutputFormat::Interleaved, true, mNumThreads)); + MC_AWAIT(mGen.expand(chl, mA.subspan(0, main), PprfOutputFormat::Interleaved, true, mNumThreads)); setTimePoint("recver.expand.pprf_transpose"); gTimer.setTimePoint("recver.expand.pprf_transpose"); diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp index 3af5c6cf..7fb8e32d 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp @@ -454,7 +454,7 @@ namespace osuCrypto } - MC_AWAIT(mGen.expand(chl, { &mDelta,1 }, prng, mB.subspan(0, main), PprfOutputFormat::Interleaved, true, mNumThreads)); + MC_AWAIT(mGen.expand(chl, { &mDelta,1 }, prng.get(), mB.subspan(0, main), PprfOutputFormat::Interleaved, true, mNumThreads)); if (mMalType == SilentSecType::Malicious) diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.cpp b/libOTe/Vole/Silent/SilentVoleReceiver.cpp index 17380ada..8d1972b8 100644 --- a/libOTe/Vole/Silent/SilentVoleReceiver.cpp +++ b/libOTe/Vole/Silent/SilentVoleReceiver.cpp @@ -477,7 +477,7 @@ using BaseOT = DefaultBaseOT; if (mTimer) mGen.setTimer(*mTimer); // expand the seeds into mA - MC_AWAIT(mGen.expand(chl, prng, mA.subspan(0, mNumPartitions * mSizePer), PprfOutputFormat::Interleaved, true, mNumThreads)); + MC_AWAIT(mGen.expand(chl, mA.subspan(0, mNumPartitions * mSizePer), PprfOutputFormat::Interleaved, true, mNumThreads)); setTimePoint("SilentVoleReceiver.expand.pprf_transpose"); diff --git a/libOTe/Vole/Silent/SilentVoleSender.cpp b/libOTe/Vole/Silent/SilentVoleSender.cpp index 825828f8..ef107a62 100644 --- a/libOTe/Vole/Silent/SilentVoleSender.cpp +++ b/libOTe/Vole/Silent/SilentVoleSender.cpp @@ -362,7 +362,7 @@ namespace osuCrypto // output of the PPRF at the correct locations. noiseShares = span(mNoiseDeltaShares.data(), mNumPartitions); mbb = mB.subspan(0, mNumPartitions * mSizePer); - MC_AWAIT(mGen.expand(chl, noiseShares, prng, mbb, + MC_AWAIT(mGen.expand(chl, noiseShares, prng.get(), mbb, PprfOutputFormat::Interleaved, true, mNumThreads)); setTimePoint("SilentVoleSender.expand.pprf_transpose"); diff --git a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp index 1525c531..f8ad9f5a 100644 --- a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp +++ b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp @@ -379,7 +379,7 @@ namespace osuCrypto throw RTE_LOC; mPprf->setBase(baseMessages); - mPprf->setChoiceBits(PprfOutputFormat::BlockTransposed, choices); + mPprf->setChoiceBits(PprfOutputFormat::ByTreeIndex, choices); } void SmallFieldVoleSender::setBaseOts(span> msgs) @@ -413,7 +413,7 @@ namespace osuCrypto std::fill(mSeeds.begin(), mSeeds.end(), block(0, 0)); seedView = MatrixView(mSeeds.data(), mNumVoles, fieldSize()); - MC_AWAIT(mPprf->expand(chl, span(), prng, seedView, PprfOutputFormat::BlockTransposed, false, 1)); + MC_AWAIT(mPprf->expand(chl, span(), prng.get(), seedView, PprfOutputFormat::ByTreeIndex, false, 1)); // Prove consistency if (mMalicious) @@ -466,7 +466,7 @@ namespace osuCrypto seedsFull.resize(mNumVoles, fieldSize()); - MC_AWAIT(mPprf->expand(chl, prng, seedsFull, PprfOutputFormat::BlockTransposed, false, 1)); + MC_AWAIT(mPprf->expand(chl, seedsFull, PprfOutputFormat::ByTreeIndex, false, 1)); // Check consistency if (mMalicious) diff --git a/libOTe_Tests/SilentOT_Tests.cpp b/libOTe_Tests/SilentOT_Tests.cpp index b74df507..e1eb82e1 100644 --- a/libOTe_Tests/SilentOT_Tests.cpp +++ b/libOTe_Tests/SilentOT_Tests.cpp @@ -185,7 +185,7 @@ void Tools_quasiCyclic_test(const oc::CLP& cmd) QuasiCyclicCode code; u64 nn = 1 << 10; - u64 t = 10; + u64 t = 1; auto scaler = 2; //auto secParam = 128; @@ -239,7 +239,7 @@ void Tools_quasiCyclic_test(const oc::CLP& cmd) } code.dualEncode(A); - + for (u64 i : rng(mP)) { @@ -254,7 +254,7 @@ void Tools_quasiCyclic_test(const oc::CLP& cmd) { - + mP = nextPrime(50); n = mP * scaler; code.init(mP); @@ -460,7 +460,7 @@ void OtExt_Silent_random_Test(const CLP& cmd) { #ifdef ENABLE_SILENTOT - + auto sockets = cp::LocalAsyncSocket::makePair(); u64 n = cmd.getOr("n", 10000); @@ -498,7 +498,7 @@ void OtExt_Silent_correlated_Test(const CLP& cmd) { #ifdef ENABLE_SILENTOT - + auto sockets = cp::LocalAsyncSocket::makePair(); u64 n = cmd.getOr("n", 10000); @@ -541,7 +541,7 @@ void OtExt_Silent_inplace_Test(const CLP& cmd) #ifdef ENABLE_SILENTOT - + auto sockets = cp::LocalAsyncSocket::makePair(); u64 n = cmd.getOr("n", 10000); @@ -598,7 +598,7 @@ void OtExt_Silent_paramSweep_Test(const oc::CLP& cmd) { #ifdef ENABLE_SILENTOT - + auto sockets = cp::LocalAsyncSocket::makePair(); std::vector nn = cmd.getManyOr("n", @@ -641,7 +641,7 @@ void OtExt_Silent_QuasiCyclic_Test(const oc::CLP& cmd) #if defined(ENABLE_SILENTOT) && defined(ENABLE_BITPOLYMUL) - + auto sockets = cp::LocalAsyncSocket::makePair(); std::vector nn = cmd.getManyOr("n", @@ -769,7 +769,7 @@ void OtExt_Silent_baseOT_Test(const oc::CLP& cmd) #ifdef ENABLE_SILENTOT - + auto sockets = cp::LocalAsyncSocket::makePair(); u64 n = 123;// @@ -809,7 +809,7 @@ void OtExt_Silent_mal_Test(const oc::CLP& cmd) #ifdef ENABLE_SILENTOT - + auto sockets = cp::LocalAsyncSocket::makePair(); u64 n = 12093;// @@ -841,115 +841,141 @@ void OtExt_Silent_mal_Test(const oc::CLP& cmd) #endif } -void Tools_Pprf_test(const CLP& cmd) + +void Tools_Pprf_expandOne_test(const oc::CLP& cmd) { #if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - u64 depth = cmd.getOr("d", 3);; - u64 domain = 1ull << depth; - auto threads = cmd.getOr("t", 3ull); - u64 numPoints = cmd.getOr("s", 8); + u64 depth = cmd.getOr("d", 4);; + u64 domain = (1ull << depth) * 0.75; + auto pntCount = 8ull; + PRNG prng(CCBlock); - PRNG prng(ZeroBlock); - - //IOService ios; - //Session s0(ios, "localhost:1212", SessionMode::Server); - //Session s1(ios, "localhost:1212", SessionMode::Client); - //auto sockets[0] = s0.addChannel(); - //auto sockets[1] = s1.addChannel(); - - auto sockets = cp::LocalAsyncSocket::makePair(); - - - auto format = PprfOutputFormat::Plain; + auto format = PprfOutputFormat::Interleaved; SilentMultiPprfSender sender; SilentMultiPprfReceiver recver; - sender.configure(domain, numPoints); - recver.configure(domain, numPoints); + sender.configure(domain, pntCount); + recver.configure(domain, pntCount); + + block value = prng.get(); + sender.setValue({ &value, 1 }); auto numOTs = sender.baseOtCount(); std::vector> sendOTs(numOTs); std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain, format, prng); + BitVector recvBits = recver.sampleChoiceBits(domain * pntCount, format, prng); - prng.get(sendOTs.data(), sendOTs.size()); - //sendOTs[cmd.getOr("i",0)] = prng.get(); - //recvBits[16] = 1; + prng.get(sendOTs.data(), sendOTs.size()); for (u64 i = 0; i < numOTs; ++i) { - //recvBits[i] = 0; recvOTs[i] = sendOTs[i][recvBits[i]]; } sender.setBase(sendOTs); recver.setBase(recvOTs); - Matrix sOut(domain, numPoints); - Matrix rOut(domain, numPoints); - std::vector points(numPoints); - recver.getPoints(points, format); + std::vector points(8); + recver.getPoints(points, PprfOutputFormat::ByLeafIndex); - auto p0 = sender.expand(sockets[0], {&CCBlock,1}, prng, sOut, format, true, threads); - auto p1 = recver.expand(sockets[1], prng, rOut, format, true, threads); + block seed = CCBlock; + bool program = true; - eval(p0, p1); + auto sTree = span>{}; + auto sLevels = std::vector>>{}; + auto rTree = span>{}; + auto rLevels = std::vector>>{}; + //auto rBuff = std::vector{}; + auto sBuff = std::vector{}; + auto sSums = span, 2>>{}; + auto sLast = span>{}; - bool failed = false; + TreeAllocator mTreeAlloc; + sLevels.resize(depth + 1); + rLevels.resize(depth + 1); - for (u64 j = 0; j < numPoints; ++j) - { + mTreeAlloc.reserve(2, (1ull << (depth + 1)) + 2); + allocateExpandTree(depth + 1, mTreeAlloc, sTree, sLevels); + allocateExpandTree(depth + 1, mTreeAlloc, rTree, rLevels); - for (u64 i = 0; i < domain; ++i) + + allocateExpandBuffer(depth, program, sBuff, sSums, sLast); + //allocateExpandBuffer(depth, program, rbuff, sSums, sLast); + + recver.mPoints.resize(roundUpTo(recver.mPntCount, 8)); + recver.getPoints(recver.mPoints, PprfOutputFormat::ByLeafIndex); + + sender.expandOne(seed, 0, program, sLevels, sSums, sLast); + recver.expandOne(0, program, rLevels, sSums, sLast); + + bool failed = false; + for (u64 i = 0; i < pntCount; ++i) + { + // the index of the leaf node that is active. + auto leafIdx = points[i]; + //std::cout << leafIdx << std::endl; + for (u64 d = 1; d <= depth; ++d) { + //u64 width = std::min(domain, 1ull << d); + auto width = divCeil(domain, 1ull << (depth - d)); - auto exp = sOut(i, j); - if (points[j] == i) - exp = exp ^ CCBlock; + // The index of the active child node. + auto activeChildIdx = leafIdx >> (depth - d); - if (neq(exp, rOut(i, j))) + // The index of the active child node sibling. + + for (u64 j = 0; j < width; ++j) { - failed = true; + //std::cout + // << " " << sLevels[d][j][i].get()[0] + // << " " << rLevels[d][j][i].get()[0] + // << ", "; - if (cmd.isSet("v")) - std::cout << Color::Red; + if (j == activeChildIdx) + { + //std::cout << "*"; + continue; + } + + if (sLevels[d][j][i] != rLevels[d][j][i]) + { + //std::cout << " < "; + failed = true; + } + } + //std::cout << std::endl; } - if (cmd.isSet("v")) - std::cout << "r[" << j << "][" << i << "] " << exp << " " << rOut(i, j) << std::endl << Color::Default; } - } if (failed) throw RTE_LOC; - #else throw UnitTestSkipped("ENABLE_SILENTOT not defined."); #endif -} + } -void Tools_Pprf_trans_test(const CLP& cmd) +void Tools_Pprf_test(const CLP& cmd) { #if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - //u64 depth = 6; - //u64 domain = 13;// (1ull << depth) - 7; - //u64 numPoints = 40; - - u64 domain = cmd.getOr("d", 334); + u64 depth = cmd.getOr("d", 3);; + u64 domain = 1ull << depth; auto threads = cmd.getOr("t", 3ull); - u64 numPoints = cmd.getOr("s", 5) * 8; - //bool mal = cmd.isSet("mal"); + u64 numPoints = cmd.getOr("s", 8); PRNG prng(ZeroBlock); - - auto sockets = cp::LocalAsyncSocket::makePair(); - + //IOService ios; + //Session s0(ios, "localhost:1212", SessionMode::Server); + //Session s1(ios, "localhost:1212", SessionMode::Client); + //auto sockets[0] = s0.addChannel(); + //auto sockets[1] = s1.addChannel(); + auto sockets = cp::LocalAsyncSocket::makePair(); - auto format = PprfOutputFormat::InterleavedTransposed; + auto format = PprfOutputFormat::ByLeafIndex; SilentMultiPprfSender sender; SilentMultiPprfReceiver recver; @@ -959,11 +985,12 @@ void Tools_Pprf_trans_test(const CLP& cmd) auto numOTs = sender.baseOtCount(); std::vector> sendOTs(numOTs); std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain * numPoints, format, prng); - //recvBits.randomize(prng); + BitVector recvBits = recver.sampleChoiceBits(domain, format, prng); - //recvBits[16] = 1; prng.get(sendOTs.data(), sendOTs.size()); + //sendOTs[cmd.getOr("i",0)] = prng.get(); + + //recvBits[16] = 1; for (u64 i = 0; i < numOTs; ++i) { //recvBits[i] = 0; @@ -972,57 +999,39 @@ void Tools_Pprf_trans_test(const CLP& cmd) sender.setBase(sendOTs); recver.setBase(recvOTs); - auto cols = (numPoints * domain + 127) / 128; - Matrix sOut(128, cols); - Matrix rOut(128, cols); - + Matrix sOut(domain, numPoints); + Matrix rOut(domain, numPoints); std::vector points(numPoints); recver.getPoints(points, format); - - - - auto p0 = sender.expand(sockets[0], { &AllOneBlock,1 }, prng, sOut, format, true, threads); - auto p1 = recver.expand(sockets[1], prng, rOut, format, true, threads); - + auto p0 = sender.expand(sockets[0], { &CCBlock,1 }, prng.get(), sOut, format, true, threads); + auto p1 = recver.expand(sockets[1], rOut, format, true, threads); eval(p0, p1); - bool failed = false; - Matrix out(128, cols); - Matrix outT(numPoints * domain, 1); - - if (cmd.getOr("v", 0) > 1) - std::cout << sender.mDomain << " " << sender.mPntCount << - " " << sOut.rows() << " " << sOut.cols() << std::endl; + bool failed = false; - for (u64 i = 0; i < cols; ++i) - { - for (u64 j = 0; j < 128; ++j) - { - out(j, i) = (sOut(j, i) ^ rOut(j, i)); - //if (cmd.isSet("v")) - // std::cout << "r[" << i << "][" << j << "] " << out(j,i) << " ~ " << rOut(j, i) << std::endl << Color::Default; - } - } - transpose(MatrixView(out), MatrixView(outT)); - for (u64 i = 0; i < outT.rows(); ++i) + for (u64 j = 0; j < numPoints; ++j) { - auto f = std::find(points.begin(), points.end(), i) != points.end(); + for (u64 i = 0; i < domain; ++i) + { - auto exp = f ? AllOneBlock : ZeroBlock; + auto exp = sOut(i, j); + if (points[j] == i) + exp = exp ^ CCBlock; - if (neq(outT(i), exp)) - { - failed = true; + if (neq(exp, rOut(i, j))) + { + failed = true; - if (cmd.getOr("v", 0) > 1) - std::cout << Color::Red; + if (cmd.isSet("v")) + std::cout << Color::Red; + } + if (cmd.isSet("v")) + std::cout << "r[" << j << "][" << i << "] " << exp << " " << rOut(i, j) << std::endl << Color::Default; } - if (cmd.getOr("v", 0) > 1) - std::cout << i << " " << outT(i) << " " << exp << std::endl << Color::Default; } if (failed) @@ -1033,7 +1042,6 @@ void Tools_Pprf_trans_test(const CLP& cmd) #endif } - void Tools_Pprf_inter_test(const CLP& cmd) { #if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) @@ -1082,10 +1090,24 @@ void Tools_Pprf_inter_test(const CLP& cmd) recver.getPoints(points, format); - auto p0 = sender.expand(sockets[0], { &AllOneBlock,1 }, prng, sOut2, format, true, threads); - auto p1 = recver.expand(sockets[1], prng, rOut2, format, true, threads); + auto p0 = sender.expand(sockets[0], { &AllOneBlock,1 }, prng.get(), sOut2, format, true, threads); + auto p1 = recver.expand(sockets[1], rOut2, format, true, threads); - eval(p0, p1); + try + { + + eval(p0, p1); + } + catch (std::exception& e) + { + sockets[0].close(); + sockets[1].close(); + macoro::sync_wait(macoro::when_all_ready( + sockets[0].flush(), + sockets[1].flush() + )); + throw; + } for (u64 i = 0; i < rOut2.rows(); ++i) { sOut2(i) = (sOut2(i) ^ rOut2(i)); @@ -1137,7 +1159,7 @@ void Tools_Pprf_blockTrans_test(const oc::CLP& cmd) auto sockets = cp::LocalAsyncSocket::makePair(); - auto format = PprfOutputFormat::BlockTransposed; + auto format = PprfOutputFormat::ByTreeIndex; SilentMultiPprfSender sender; SilentMultiPprfReceiver recver; @@ -1167,8 +1189,8 @@ void Tools_Pprf_blockTrans_test(const oc::CLP& cmd) recver.getPoints(points, format); cp::sync_wait(cp::when_all_ready( - sender.expand(sockets[0], span{}, prng, sOut, format, false, threads), - recver.expand(sockets[1], prng, rOut, format, false, threads) + sender.expand(sockets[0], span{}, prng.get(), sOut, format, false, threads), + recver.expand(sockets[1], rOut, format, false, threads) )); bool failed = false; @@ -1194,7 +1216,7 @@ void Tools_Pprf_blockTrans_test(const oc::CLP& cmd) } else { - if (ss != rr || rr == ZeroBlock) + if (ss != rr || rr == ZeroBlock) { failed = true; @@ -1256,23 +1278,23 @@ void Tools_Pprf_callback_test(const oc::CLP& cmd) recver.getPoints(points, format); sender.mOutputFn = [&](u64 treeIdx, span> data) - { - span d = sOut2; - d = d.subspan(treeIdx * data.size()); - d = d.subspan(0, std::min(d.size(), data.size() * 8)); - memcpy(d.data(), data.data(), d.size_bytes()); - }; + { + span d = sOut2; + d = d.subspan(treeIdx * data.size()); + d = d.subspan(0, std::min(d.size(), data.size() * 8)); + memcpy(d.data(), data.data(), d.size_bytes()); + }; recver.mOutputFn = [&](u64 treeIdx, span> data) - { - span d = rOut2; - d = d.subspan(treeIdx * data.size()); - d = d.subspan(0, std::min(d.size(), data.size() * 8)); - memcpy(d.data(), data.data(), d.size_bytes()); - }; + { + span d = rOut2; + d = d.subspan(treeIdx * data.size()); + d = d.subspan(0, std::min(d.size(), data.size() * 8)); + memcpy(d.data(), data.data(), d.size_bytes()); + }; - auto p0 = sender.expand(sockets[0], { &AllOneBlock,1 }, prng, span{}, format, true, threads); - auto p1 = recver.expand(sockets[1], prng, span{}, format, true, threads); + auto p0 = sender.expand(sockets[0], { &AllOneBlock,1 }, prng.get(), span{}, format, true, threads); + auto p1 = recver.expand(sockets[1], span{}, format, true, threads); eval(p0, p1); for (u64 i = 0; i < rOut2.rows(); ++i) @@ -1297,7 +1319,7 @@ void Tools_Pprf_callback_test(const oc::CLP& cmd) } if (cmd.getOr("v", 0) > 1) std::cout << i << " " << sOut2(i) << " " << exp << std::endl << Color::Default; - } + } if (failed) throw RTE_LOC; @@ -1305,4 +1327,4 @@ void Tools_Pprf_callback_test(const oc::CLP& cmd) #else throw UnitTestSkipped("ENABLE_SILENTOT not defined."); #endif -} + } diff --git a/libOTe_Tests/SilentOT_Tests.h b/libOTe_Tests/SilentOT_Tests.h index 83b5f2e5..d93e09be 100644 --- a/libOTe_Tests/SilentOT_Tests.h +++ b/libOTe_Tests/SilentOT_Tests.h @@ -9,6 +9,7 @@ #include +void Tools_Pprf_expandOne_test(const oc::CLP& cmd); void Tools_Pprf_test(const oc::CLP& cmd); void Tools_Pprf_trans_test(const oc::CLP& cmd); void Tools_Pprf_inter_test(const oc::CLP& cmd); diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 8279aed5..1c1e72a0 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -56,9 +56,9 @@ namespace tests_libOTe tc.add("EACode_encode_basic_test ", EACode_encode_basic_test); tc.add("ExConvCode_encode_basic_test ", ExConvCode_encode_basic_test); - + + tc.add("Tools_Pprf_expandOne_test ", Tools_Pprf_expandOne_test); tc.add("Tools_Pprf_test ", Tools_Pprf_test); - tc.add("Tools_Pprf_trans_test ", Tools_Pprf_trans_test); tc.add("Tools_Pprf_inter_test ", Tools_Pprf_inter_test); tc.add("Tools_Pprf_blockTrans_test ", Tools_Pprf_blockTrans_test); tc.add("Tools_Pprf_callback_test ", Tools_Pprf_callback_test); diff --git a/libOTe_Tests/Vole_Tests.cpp b/libOTe_Tests/Vole_Tests.cpp index 284f2668..b287607d 100644 --- a/libOTe_Tests/Vole_Tests.cpp +++ b/libOTe_Tests/Vole_Tests.cpp @@ -395,16 +395,18 @@ inline u64 eval( macoro::task<>& t1, macoro::task<>& t0, cp::BufferingSocket& s1, cp::BufferingSocket& s0) { + std::cout << "begin " << std::endl; auto e = macoro::make_eager(macoro::when_all_ready(std::move(t0), std::move(t1))); u64 rounds = 0; - { auto b1 = s1.getOutbound(); if (b1) { s0.processInbound(*b1); ++rounds; + + std::cout << "round " << rounds << std::endl; } } @@ -429,6 +431,7 @@ inline u64 eval( ++rounds; + std::cout << "round " << rounds << std::endl; ++idx; } @@ -476,11 +479,16 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) std::vector c(n), z0(n), z1(n); auto p0 = recv.silentReceive(c, z0, prng, chls[0]); auto p1 = send.silentSend(x, z1, prng, chls[1]); - +#if (defined ENABLE_MRR_TWIST && defined ENABLE_SSE) || \ + defined ENABLE_MR || defined ENABLE_MRR + u64 expRound = 3; +#else + u64 expRound = 5; +#endif auto rounds = eval(p0, p1, chls[1], chls[0]); - if (rounds != 3) - throw std::runtime_error(std::to_string(rounds) + "!=3. " +COPROTO_LOCATION); + if (rounds != expRound) + throw std::runtime_error(std::to_string(rounds) + "!="+std::to_string(expRound)+". " +COPROTO_LOCATION); for (u64 i = 0; i < n; ++i)