From 7e82275d5c244a06ea98e5e3da49fa6f8ff0cdb2 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Fri, 25 Oct 2024 12:06:39 -0700 Subject: [PATCH] minor co_await changes --- libOTe/Tools/Pprf/RegularPprf.h | 2 ++ libOTe/Vole/Noisy/NoisyVoleSender.h | 6 ++--- libOTe/Vole/Silent/SilentVoleReceiver.h | 30 ++++++++++++------------- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/libOTe/Tools/Pprf/RegularPprf.h b/libOTe/Tools/Pprf/RegularPprf.h index 8e2bd6a..f19f4c1 100644 --- a/libOTe/Tools/Pprf/RegularPprf.h +++ b/libOTe/Tools/Pprf/RegularPprf.h @@ -647,6 +647,8 @@ namespace osuCrypto // The OTs are used in blocks of 8, so make sure that there is a whole // number of blocks. mBaseOTs.resize(roundUpTo(mPntCount, 8), mDepth); + if (mBaseOTs.size() < baseMessages.size()) + throw RTE_LOC; memcpy(mBaseOTs.data(), baseMessages.data(), baseMessages.size() * sizeof(block)); } diff --git a/libOTe/Vole/Noisy/NoisyVoleSender.h b/libOTe/Vole/Noisy/NoisyVoleSender.h index 3297884..d1d12d2 100644 --- a/libOTe/Vole/Noisy/NoisyVoleSender.h +++ b/libOTe/Vole/Noisy/NoisyVoleSender.h @@ -63,10 +63,10 @@ namespace osuCrypto { setTimePoint("NoisyVoleSender.ot.begin"); - co_await(ot.receive(bv, otMsg, prng, chl)); + co_await ot.receive(bv, otMsg, prng, chl); setTimePoint("NoisyVoleSender.ot.end"); - co_await(send(delta, b, prng, otMsg, chl, ctx)); + co_await send(delta, b, prng, otMsg, chl, ctx); } MACORO_CATCH(eptr) { co_await chl.close(); @@ -100,7 +100,7 @@ namespace osuCrypto { // receive the the excrypted one shares. buffer.resize(xb.size() * b.size() * ctx.template byteSize()); - co_await(chl.recv(buffer)); + co_await chl.recv(buffer); ctx.resize(msg, xb.size() * b.size()); ctx.deserialize(buffer.begin(), buffer.end(), msg.begin()); diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.h b/libOTe/Vole/Silent/SilentVoleReceiver.h index ff1616f..db80634 100644 --- a/libOTe/Vole/Silent/SilentVoleReceiver.h +++ b/libOTe/Vole/Silent/SilentVoleReceiver.h @@ -196,7 +196,7 @@ namespace osuCrypto bb.randomize(prng); choice.append(bb); - co_await(mOtExtRecver->receive(choice, msg, prng, chl)); + co_await mOtExtRecver->receive(choice, msg, prng, chl); mOtExtSender->setBaseOts( span(msg).subspan( @@ -205,18 +205,18 @@ namespace osuCrypto bb); msg.resize(msg.size() - mOtExtSender->baseOtCount()); - co_await(nv.receive(noiseVals, baseAs, prng, *mOtExtSender, chl, mCtx)); + co_await nv.receive(noiseVals, baseAs, prng, *mOtExtSender, chl, mCtx); } else { auto chl2 = chl.fork(); auto prng2 = prng.fork(); - co_await( + co_await macoro::when_all_ready( nv.receive(noiseVals, baseAs, prng2, *mOtExtSender, chl2, mCtx), mOtExtRecver->receive(choice, msg, prng, chl) - )); + ); } #else throw std::runtime_error("soft spoken must be enabled"); @@ -228,11 +228,11 @@ namespace osuCrypto auto prng2 = prng.fork(); BaseOT baseOt; - co_await( + co_await macoro::when_all_ready( baseOt.receive(choice, msg, prng, chl), nv.receive(noiseVals, baseAs, prng2, baseOt, chl2, mCtx)) - ); + ; } @@ -400,7 +400,7 @@ namespace osuCrypto if (c.size() != a.size()) throw std::runtime_error("input sizes do not match." LOCATION); - co_await(silentReceiveInplace(c.size(), prng, chl)); + co_await silentReceiveInplace(c.size(), prng, chl); mCtx.copy(mC.begin(), mC.begin() + c.size(), c.begin()); mCtx.copy(mA.begin(), mA.begin() + a.size(), a.begin()); @@ -437,7 +437,7 @@ namespace osuCrypto if (hasSilentBaseOts() == false) { - co_await(genSilentBaseOts(prng, chl)); + co_await genSilentBaseOts(prng, chl); } // allocate mA @@ -473,7 +473,7 @@ namespace osuCrypto // // mA = mB + mS(mBaseC * mDelta) // - co_await(mGen.expand(chl, mA, PprfOutputFormat::Interleaved, true, mNumThreads)); + co_await mGen.expand(chl, mA, PprfOutputFormat::Interleaved, true, mNumThreads); setTimePoint("SilentVoleReceiver.expand.pprf_transpose"); @@ -488,14 +488,14 @@ namespace osuCrypto if (mDebug) { - co_await(checkRT(chl)); + co_await checkRT(chl); setTimePoint("SilentVoleReceiver.expand.checkRT"); } if (mMalType == SilentSecType::Malicious) { - co_await(chl.send(std::move(mMalCheckSeed))); + co_await chl.send(std::move(mMalCheckSeed)); if constexpr (MaliciousSupported) myHash = ferretMalCheck(); @@ -503,7 +503,7 @@ namespace osuCrypto throw std::runtime_error("malicious is currently only supported for GF128 block. " LOCATION); } - co_await(chl.recv(theirHash)); + co_await chl.recv(theirHash); if (theirHash != myHash) { @@ -601,19 +601,19 @@ namespace osuCrypto // recv delta buffer.resize(mCtx.template byteSize()); mCtx.resize(delta, 1); - co_await(chl.recv(buffer)); + co_await chl.recv(buffer); mCtx.deserialize(buffer.begin(), buffer.end(), delta.begin()); // recv B buffer.resize(mCtx.template byteSize() * mA.size()); mCtx.resize(B, mA.size()); - co_await(chl.recv(buffer)); + co_await chl.recv(buffer); mCtx.deserialize(buffer.begin(), buffer.end(), B.begin()); // recv the noisy values. buffer.resize(mCtx.template byteSize() * mBaseA.size()); mCtx.resize(baseB, mBaseA.size()); - co_await(chl.recvResize(buffer)); + co_await chl.recvResize(buffer); mCtx.deserialize(buffer.begin(), buffer.end(), baseB.begin()); // it shoudl hold that