From a7619c9609611590aa05c641032a014d67b691a6 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Mon, 29 Apr 2024 18:07:36 -0700 Subject: [PATCH] ci debug asan --- .../TwoChooseOne/Iknp/IknpOtExtReceiver.cpp | 384 +++++++++--------- libOTe/TwoChooseOne/Iknp/IknpOtExtSender.cpp | 94 +++-- 2 files changed, 230 insertions(+), 248 deletions(-) diff --git a/libOTe/TwoChooseOne/Iknp/IknpOtExtReceiver.cpp b/libOTe/TwoChooseOne/Iknp/IknpOtExtReceiver.cpp index 5ef5803..c222168 100644 --- a/libOTe/TwoChooseOne/Iknp/IknpOtExtReceiver.cpp +++ b/libOTe/TwoChooseOne/Iknp/IknpOtExtReceiver.cpp @@ -11,215 +11,199 @@ namespace osuCrypto { - void IknpOtExtReceiver::setBaseOts(span> baseOTs) - { - if (baseOTs.size() != gOtExtBaseOtCount) - throw std::runtime_error(LOCATION); - - for (u64 j = 0; j < 2; ++j) - { - block buff[gOtExtBaseOtCount]; - for (u64 i = 0; i < gOtExtBaseOtCount; i++) - buff[i] = baseOTs[i][j]; - - mGens[j].setKeys(buff); - } - - mHasBase = true; - } - - - IknpOtExtReceiver IknpOtExtReceiver::splitBase() - { - std::array, gOtExtBaseOtCount> baseRecvOts; - - if (!hasBaseOts()) - throw std::runtime_error("base OTs have not been set. " LOCATION); - - for (u64 j = 0; j < 2; ++j) - { - block buff[gOtExtBaseOtCount]; - mGens[j].ecbEncCounterMode(mPrngIdx, buff); - for (u64 i = 0; i < gOtExtBaseOtCount; ++i) - { - baseRecvOts[i][j] = buff[i]; - } - } - ++mPrngIdx; - - return IknpOtExtReceiver(baseRecvOts); - } - - - std::unique_ptr IknpOtExtReceiver::split() - { - return std::make_unique(splitBase()); - } - - - - task<> IknpOtExtReceiver::receive( - const BitVector& choices, - span messages, - PRNG& prng, - Socket& chl) - { - auto numOtExt = u64{}; - auto numSuperBlocks = u64{}; - auto numBlocks = u64{}; - auto superBlkIdx = u64{}; - auto step = u64{}; - auto choices2 = BitVector{}; - auto choiceBlocks = span{}; - // this will be used as temporary buffers of 128 columns, - // each containing 1024 bits. Once transposed, they will be copies - // into the T1, T0 buffers for long term storage - auto t0 = AlignedUnVector{ 128 }; - auto mIter = (block*)nullptr; - auto uIter = (block*)nullptr; - auto tIter = (block*)nullptr; - auto cIter = (block*)nullptr; - auto uEnd = (block*)nullptr; - auto uBuff = AlignedUnVector{}; - - if (choices.size() != messages.size()) - throw RTE_LOC; - - if (hasBaseOts() == false) - co_await genBaseOts(prng, chl); - - // we are going to process OTs in blocks of 128 * superBlkSize messages. - numOtExt = roundUpTo(choices.size(), 128); - numSuperBlocks = (numOtExt / 128); - numBlocks = numSuperBlocks; - - choices2.resize(numBlocks * 128); - choices2 = choices; - choices2.resize(numBlocks * 128); - - choiceBlocks = choices2.getSpan(); - - // the index of the OT that has been completed. - //u64 doneIdx = 0; - - mIter = messages.data(); - - step = std::min(numSuperBlocks, (u64)commStepSize); - uBuff.resize(step * 128); - - // get an array of blocks that we will fill. - uIter = (block*)uBuff.data(); - uEnd = uIter + uBuff.size(); - - // NOTE: We do not transpose a bit-matrix of size numCol * numCol. - // Instead we break it down into smaller chunks. We do 128 columns - // times 8 * 128 rows at a time, where 8 = superBlkSize. This is done for - // performance reasons. The reason for 8 is that most CPUs have 8 AES vector - // lanes, and so its more efficient to encrypt (aka prng) 8 blocks at a time. - // So that's what we do. - for (superBlkIdx = 0; superBlkIdx < numSuperBlocks; ++superBlkIdx) - { - - // this will store the next 128 rows of the matrix u - - tIter = (block*)t0.data(); - cIter = choiceBlocks.data() + superBlkIdx; - - mGens[0].ecbEncCounterMode(mPrngIdx, tIter); - mGens[1].ecbEncCounterMode(mPrngIdx, uIter); - ++mPrngIdx; - - for (u64 colIdx = 0; colIdx < 128 / 8; ++colIdx) - { - uIter[0] = uIter[0] ^ cIter[0]; - uIter[1] = uIter[1] ^ cIter[0]; - uIter[2] = uIter[2] ^ cIter[0]; - uIter[3] = uIter[3] ^ cIter[0]; - uIter[4] = uIter[4] ^ cIter[0]; - uIter[5] = uIter[5] ^ cIter[0]; - uIter[6] = uIter[6] ^ cIter[0]; - uIter[7] = uIter[7] ^ cIter[0]; - - uIter[0] = uIter[0] ^ tIter[0]; - uIter[1] = uIter[1] ^ tIter[1]; - uIter[2] = uIter[2] ^ tIter[2]; - uIter[3] = uIter[3] ^ tIter[3]; - uIter[4] = uIter[4] ^ tIter[4]; - uIter[5] = uIter[5] ^ tIter[5]; - uIter[6] = uIter[6] ^ tIter[6]; - uIter[7] = uIter[7] ^ tIter[7]; - - uIter += 8; - tIter += 8; - } - - if (uIter == uEnd) - { - // send over u buffer - co_await chl.send(std::move(uBuff)); - - u64 step = std::min(numSuperBlocks - superBlkIdx - 1, (u64)commStepSize); - - if (step) - { - uBuff.resize(step * 128); - uIter = (block*)uBuff.data(); - uEnd = uIter + uBuff.size(); - } - } - - // transpose our 128 columns of 1024 bits. We will have 1024 rows, - // each 128 bits wide. - transpose128(t0.data()); - - - auto mEnd = mIter + std::min(128, messages.data() + messages.size() - mIter); - - - tIter = t0.data(); - - memcpy(mIter, tIter, (mEnd - mIter) * sizeof(block)); - mIter = mEnd; + void IknpOtExtReceiver::setBaseOts(span> baseOTs) + { + if (baseOTs.size() != gOtExtBaseOtCount) + throw std::runtime_error(LOCATION); + + for (u64 j = 0; j < 2; ++j) + { + block buff[gOtExtBaseOtCount]; + for (u64 i = 0; i < gOtExtBaseOtCount; i++) + buff[i] = baseOTs[i][j]; + + mGens[j].setKeys(buff); + } + + mHasBase = true; + } + + + IknpOtExtReceiver IknpOtExtReceiver::splitBase() + { + std::array, gOtExtBaseOtCount> baseRecvOts; + + if (!hasBaseOts()) + throw std::runtime_error("base OTs have not been set. " LOCATION); + + for (u64 j = 0; j < 2; ++j) + { + block buff[gOtExtBaseOtCount]; + mGens[j].ecbEncCounterMode(mPrngIdx, buff); + for (u64 i = 0; i < gOtExtBaseOtCount; ++i) + { + baseRecvOts[i][j] = buff[i]; + } + } + ++mPrngIdx; + + return IknpOtExtReceiver(baseRecvOts); + } + + + std::unique_ptr IknpOtExtReceiver::split() + { + return std::make_unique(splitBase()); + } + + + + task<> IknpOtExtReceiver::receive( + const BitVector& choices, + span messages, + PRNG& prng, + Socket& chl) + { + + // this will be used as temporary buffers of 128 columns, + // each containing 1024 bits. Once transposed, they will be copies + // into the T1, T0 buffers for long term storage + { + auto t0 = AlignedUnVector{ 128 }; + { + + if (choices.size() != messages.size()) + throw RTE_LOC; + + if (hasBaseOts() == false) + co_await genBaseOts(prng, chl); + + // we are going to process OTs in blocks of 128 * superBlkSize messages. + auto uBuff = AlignedUnVector{}; + auto numOtExt = roundUpTo(choices.size(), 128); + auto numBlocks = (numOtExt / 128); + + block* cIter = choices.blocks(); + block* mIter = messages.data(); + block* uIter = 0; + block* uEnd = 0; + + // NOTE: We do not transpose a bit-matrix of size numCol * numCol. + // Instead we break it down into smaller chunks. We do 128 columns + // times 8 * 128 rows at a time, where 8 = superBlkSize. This is done for + // performance reasons. The reason for 8 is that most CPUs have 8 AES vector + // lanes, and so its more efficient to encrypt (aka prng) 8 blocks at a time. + // So that's what we do. + for (auto i = 0; i < numBlocks; ++i) + { + if (uIter == uEnd) + { + u64 step = std::min(numBlocks - i, (u64)commStepSize); + uBuff.resize(step * 128); + uIter = uBuff.data(); + uEnd = uIter + uBuff.size(); + } + + // this will store the next 128 rows of the matrix u + auto tIter = t0.data(); + mGens[0].ecbEncCounterMode(mPrngIdx, tIter); + mGens[1].ecbEncCounterMode(mPrngIdx, uIter); + ++mPrngIdx; + + for (u64 colIdx = 0; colIdx < 128 / 8; ++colIdx) + { + uIter[0] = uIter[0] ^ cIter[0]; + uIter[1] = uIter[1] ^ cIter[0]; + uIter[2] = uIter[2] ^ cIter[0]; + uIter[3] = uIter[3] ^ cIter[0]; + uIter[4] = uIter[4] ^ cIter[0]; + uIter[5] = uIter[5] ^ cIter[0]; + uIter[6] = uIter[6] ^ cIter[0]; + uIter[7] = uIter[7] ^ cIter[0]; + + uIter[0] = uIter[0] ^ tIter[0]; + uIter[1] = uIter[1] ^ tIter[1]; + uIter[2] = uIter[2] ^ tIter[2]; + uIter[3] = uIter[3] ^ tIter[3]; + uIter[4] = uIter[4] ^ tIter[4]; + uIter[5] = uIter[5] ^ tIter[5]; + uIter[6] = uIter[6] ^ tIter[6]; + uIter[7] = uIter[7] ^ tIter[7]; + + uIter += 8; + tIter += 8; + } + ++cIter; + + if (uIter == uEnd) + { + // send over u buffer + co_await chl.send(std::move(uBuff)); + } + + // transpose our 128 columns of 1024 bits. We will have 1024 rows, + // each 128 bits wide. + assert(t0.size() == 128); + transpose128(t0.data()); + + + tIter = t0.data(); + + auto size = std::min(128, messages.data() + messages.size() - mIter); + auto mEnd = mIter + size; + assert(mEnd <= messages.data() + messages.size()); + assert(tIter + size <= t0.data() + t0.size()); + + memcpy(mIter, tIter, size * sizeof(block)); + mIter = mEnd; #ifdef IKNP_DEBUG - ... fix this - u64 doneIdx = mStart - messages.data(); - block* msgIter = messages.data() + doneIdx; - chl.send(msgIter, sizeof(block) * 128 * superBlkSize); - cIter = choiceBlocks.data() + superBlkSize * superBlkIdx; - chl.send(cIter, sizeof(block) * superBlkSize); + ... fix this; + u64 doneIdx = mStart - messages.data(); + block* msgIter = messages.data() + doneIdx; + chl.send(msgIter, sizeof(block) * 128 * superBlkSize); + cIter = choiceBlocks.data() + superBlkSize * superBlkIdx; + chl.send(cIter, sizeof(block) * superBlkSize); #endif - } + } + + assert(cIter == choices.blocks() + choices.sizeBlocks()); + assert(uIter == uEnd); + assert(uBuff.size() == 0); + assert(mIter == messages.data() + messages.size()); - if (mHash) - { + if (mHash) + { #ifdef IKNP_SHA_HASH - RandomOracle sha; - u8 hashBuff[20]; - u64 doneIdx = (0); - - u64 bb = (messages.size() + 127) / 128; - for (u64 blockIdx = 0; blockIdx < bb; ++blockIdx) - { - u64 stop = std::min(messages.size(), doneIdx + 128); - - for (u64 i = 0; doneIdx < stop; ++doneIdx, ++i) - { - // hash it - sha.Reset(); - sha.Update((u8*)&messages[doneIdx], sizeof(block)); - sha.Final(hashBuff); - messages[doneIdx] = *(block*)hashBuff; - } - } + RandomOracle sha; + u8 hashBuff[20]; + u64 doneIdx = (0); + + u64 bb = (messages.size() + 127) / 128; + for (u64 blockIdx = 0; blockIdx < bb; ++blockIdx) + { + u64 stop = std::min(messages.size(), doneIdx + 128); + + for (u64 i = 0; doneIdx < stop; ++doneIdx, ++i) + { + // hash it + sha.Reset(); + sha.Update((u8*)&messages[doneIdx], sizeof(block)); + sha.Final(hashBuff); + messages[doneIdx] = *(block*)hashBuff; + } + } #else - mAesFixedKey.hashBlocks(messages.data(), messages.size(), messages.data()); + mAesFixedKey.hashBlocks(messages.data(), messages.size(), messages.data()); #endif - } - static_assert(gOtExtBaseOtCount == 128, "expecting 128"); - } + } + static_assert(gOtExtBaseOtCount == 128, "expecting 128"); + + } + } + } } #endif \ No newline at end of file diff --git a/libOTe/TwoChooseOne/Iknp/IknpOtExtSender.cpp b/libOTe/TwoChooseOne/Iknp/IknpOtExtSender.cpp index b4113b1..15f5b5c 100644 --- a/libOTe/TwoChooseOne/Iknp/IknpOtExtSender.cpp +++ b/libOTe/TwoChooseOne/Iknp/IknpOtExtSender.cpp @@ -40,10 +40,6 @@ namespace osuCrypto PRNG& prng, Socket& chl) { - auto numOtExt = u64{}; - auto numSuperBlocks = u64{}; - auto step = u64{}; - auto superBlkIdx = u64{}; // a temp that will be used to transpose the sender's matrix auto t = AlignedUnVector{ 128 }; @@ -53,24 +49,14 @@ namespace osuCrypto auto choiceMask = AlignedUnVector{ 128 }; { - auto delta = block{}; - auto recvView = span{}; - auto mIter = span>::iterator{}; - auto uIter = (block*)nullptr; - auto tIter = (block*)nullptr; - auto cIter = (block*)nullptr; - auto uEnd = (block*)nullptr; - if (hasBaseOts() == false) co_await genBaseOts(prng, chl); // round up - numOtExt = roundUpTo(messages.size(), 128); - numSuperBlocks = (numOtExt / 128); - //u64 numBlocks = numSuperBlocks * superBlkSize; - + auto numBlocks = divCeil(messages.size(), 128); - delta = *(block*)mBaseChoiceBits.data(); + assert(mBaseChoiceBits.size() == 128); + auto delta = *(block*)mBaseChoiceBits.data(); for (u64 i = 0; i < 128; ++i) { @@ -78,23 +64,23 @@ namespace osuCrypto else choiceMask[i] = ZeroBlock; } - mIter = messages.begin(); - uEnd = u.data() + u.size(); - uIter = uEnd; + auto mIter = messages.data(); + block* uEnd = 0; + block* uIter = 0; - for (superBlkIdx = 0; superBlkIdx < numSuperBlocks; ++superBlkIdx) + for (auto i = 0ull; i < numBlocks; ++i) { - tIter = (block*)t.data(); - cIter = choiceMask.data(); + auto tIter = t.data(); + auto cIter = choiceMask.data(); if (uIter == uEnd) { - step = std::min(numSuperBlocks - superBlkIdx, (u64)commStepSize); - step *= 128 * sizeof(block); - recvView = span((u8*)u.data(), step); + auto step = std::min(numBlocks - i, (u64)commStepSize) * 128; + u.resize(step); uIter = u.data(); + uEnd = uIter + u.size(); - co_await(chl.recv(recvView)); + co_await(chl.recv(u)); } mGens.ecbEncCounterMode(mPrngIdx, tIter); @@ -126,41 +112,47 @@ namespace osuCrypto tIter += 8; } + assert(cIter == choiceMask.data() + choiceMask.size()); + assert(tIter == t.data() + t.size()); + // transpose our 128 columns of 1024 bits. We will have 1024 rows, // each 128 bits wide. + assert(t.size() == 128); transpose128(t.data()); - auto mEnd = mIter + std::min(128, messages.end() - mIter); - + auto size = std::min(128, messages.data() + messages.size() - mIter); tIter = t.data(); - if (mEnd - mIter == 128) + + if (size == 128) { for (u64 i = 0; i < 128; i += 8) { - mIter[i + 0][0] = tIter[i + 0]; - mIter[i + 1][0] = tIter[i + 1]; - mIter[i + 2][0] = tIter[i + 2]; - mIter[i + 3][0] = tIter[i + 3]; - mIter[i + 4][0] = tIter[i + 4]; - mIter[i + 5][0] = tIter[i + 5]; - mIter[i + 6][0] = tIter[i + 6]; - mIter[i + 7][0] = tIter[i + 7]; - mIter[i + 0][1] = tIter[i + 0] ^ delta; - mIter[i + 1][1] = tIter[i + 1] ^ delta; - mIter[i + 2][1] = tIter[i + 2] ^ delta; - mIter[i + 3][1] = tIter[i + 3] ^ delta; - mIter[i + 4][1] = tIter[i + 4] ^ delta; - mIter[i + 5][1] = tIter[i + 5] ^ delta; - mIter[i + 6][1] = tIter[i + 6] ^ delta; - mIter[i + 7][1] = tIter[i + 7] ^ delta; - + mIter[0][0] = tIter[0]; + mIter[1][0] = tIter[1]; + mIter[2][0] = tIter[2]; + mIter[3][0] = tIter[3]; + mIter[4][0] = tIter[4]; + mIter[5][0] = tIter[5]; + mIter[6][0] = tIter[6]; + mIter[7][0] = tIter[7]; + mIter[0][1] = tIter[0] ^ delta; + mIter[1][1] = tIter[1] ^ delta; + mIter[2][1] = tIter[2] ^ delta; + mIter[3][1] = tIter[3] ^ delta; + mIter[4][1] = tIter[4] ^ delta; + mIter[5][1] = tIter[5] ^ delta; + mIter[6][1] = tIter[6] ^ delta; + mIter[7][1] = tIter[7] ^ delta; + + tIter += 8; + mIter += 8; } - mIter += 128; } else { + auto mEnd = mIter + size; while (mIter != mEnd) { (*mIter)[0] = *tIter; @@ -171,6 +163,9 @@ namespace osuCrypto } } + assert(tIter == t.data() + size); + + #ifdef IKNP_DEBUG fix this... BitVector choice(128 * superBlkSize); @@ -191,6 +186,9 @@ namespace osuCrypto #endif } + assert(uIter == uEnd); + assert(mIter == messages.data() + messages.size()); + if (mHash) {