diff --git a/CMakePresets.json b/CMakePresets.json index 3f921bdd..6fe20218 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -13,8 +13,14 @@ "ENABLE_ALL_OT": true, "ENABLE_SSE": true, "ENABLE_AVX": true, + "ENABLE_BOOST": true, "ENABLE_BITPOLYMUL": false, "ENABLE_CIRCUITS": true, + "ENABLE_SIMPLESTOT": true, + "ENABLE_MRR": true, + "ENABLE_MR": true, + "ENABLE_SIMPLESTOT": true, + "ENABLE_RELIC": true, "LIBOTE_STD_VER": "17", "CMAKE_PREFIX_PATH": "${sourceDir}/../out/install", "CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}" @@ -43,23 +49,19 @@ "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", "ENABLE_INSECURE_SILVER": false, - "ENABLE_LDPC": false, + "ENABLE_PPRF": true, + "ENABLE_SILENT_VOLE": true, "LIBOTE_STD_VER": "17", - "ENABLE_ALL_OT": true, - "ENABLE_KKRT": "ON", - "ENABLE_IKNP": "ON", - "ENABLE_MR": "ON", - "ENABLE_SIMPLESTOT": "ON", "ENABLE_RELIC": false, "ENABLE_SODIUM": true, - "ENABLE_BOOST": false, - "ENABLE_BITPOLYMUL": true, + "ENABLE_BOOST": true, + "ENABLE_BITPOLYMUL": false, "FETCH_AUTO": "ON", "ENABLE_CIRCUITS": true, "VERBOSE_FETCH": true, "ENABLE_SSE": true, "ENABLE_AVX": true, - "ENABLE_ASAN": true, + "ENABLE_ASAN": false, "COPROTO_ENABLE_BOOST": true, "CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}", "CMAKE_PREFIX_PATH": "${sourceDir}/../out/install/${presetName}" @@ -85,10 +87,10 @@ "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo", "LIBOTE_STD_VER": "17", - "ENABLE_ALL_OT": false, - "ENABLE_RELIC": false, + "ENABLE_ALL_OT": true, + "ENABLE_RELIC": true, "ENABLE_SODIUM": false, - "ENABLE_BOOST": false, + "ENABLE_BOOST": true, "ENABLE_OPENSSL": false, "FETCH_AUTO": true, "ENABLE_CIRCUITS": true, diff --git a/cmake/buildOptions.cmake b/cmake/buildOptions.cmake index 4d05ed90..cefdd3c4 100644 --- a/cmake/buildOptions.cmake +++ b/cmake/buildOptions.cmake @@ -94,9 +94,8 @@ option(ENABLE_DELTA_IKNP "Build the IKNP Delta-OT-Ext protocol." OFF) option(ENABLE_OOS "Build the OOS 1-oo-N OT-Ext protocol." OFF) option(ENABLE_KKRT "Build the KKRT 1-oo-N OT-Ext protocol." OFF) +option(ENABLE_PPRF "Build the PPRF protocol." OFF) option(ENABLE_SILENT_VOLE "Build the Silent Vole protocol." OFF) -#option(COPROTO_ENABLE_BOOST "Build with coproto boost support." OFF) -#option(COPROTO_ENABLE_OPENSSL "Build with coproto boost open ssl support." OFF) option(ENABLE_INSECURE_SILVER "Build with silver codes." OFF) option(ENABLE_LDPC "Build with ldpc functions." OFF) @@ -105,21 +104,14 @@ if(ENABLE_INSECURE_SILVER) endif() option(NO_KOS_WARNING "Build with no kos security warning." OFF) -#option(FETCH_BITPOLYMUL "download and build bitpolymul" OFF)) EVAL(FETCH_BITPOLYMUL_IMPL (DEFINED FETCH_BITPOLYMUL AND FETCH_BITPOLYMUL) OR ((NOT DEFINED FETCH_BITPOLYMUL) AND (FETCH_AUTO AND ENABLE_BITPOLYMUL))) - -#option(FETCH_BITPOLYMUL "download and build bitpolymul" OFF)) -EVAL(FETCH_BITPOLYMUL_IMPL - (DEFINED FETCH_BITPOLYMUL AND FETCH_BITPOLYMUL) OR - ((NOT DEFINED FETCH_BITPOLYMUL) AND (FETCH_AUTO AND ENABLE_BITPOLYMUL))) - - - - +if(ENABLE_SILENT_VOLE OR ENABLE_SILENTOT OR ENABLE_SOFTSPOKEN_OT) + set(ENABLE_PPRF true) +endif() option(VERBOSE_FETCH "Print build info for fetched libraries" ON) @@ -159,7 +151,8 @@ message(STATUS "Option: ENABLE_KKRT = ${ENABLE_KKRT}\n\n") message(STATUS "other \n=======================================================") -message(STATUS "Option: NO_KOS_WARNING = ${NO_KOS_WARNING}\n\n") +message(STATUS "Option: NO_KOS_WARNING = ${NO_KOS_WARNING}") +message(STATUS "Option: ENABLE_PPRF = ${ENABLE_PPRF}\n\n") ############################################# # Config Checks # diff --git a/cmake/buildOptions.cmake.in b/cmake/buildOptions.cmake.in index 1c8cf600..f3329773 100644 --- a/cmake/buildOptions.cmake.in +++ b/cmake/buildOptions.cmake.in @@ -70,6 +70,7 @@ set(ENABLE_KKRT @ENABLE_KKRT@) set(ENABLE_SILENT_VOLE @ENABLE_SILENT_VOLE@) set(NO_SILVER_WARNING @NO_SILVER_WARNING@) +set(ENABLE_PPRF @ENABLE_PPRF@) set(libOTe_boost_FOUND ${ENABLE_BOOST}) set(libOTe_relic_FOUND ${ENABLE_RELIC}) diff --git a/cryptoTools b/cryptoTools index 5f90354b..0fe05285 160000 --- a/cryptoTools +++ b/cryptoTools @@ -1 +1 @@ -Subproject commit 5f90354b499adddbcf6861a3b4463e0724e5f719 +Subproject commit 0fe05285f4f22d520a31ed226ea757d7c3dac49c diff --git a/frontend/ExampleBase.cpp b/frontend/ExampleBase.cpp new file mode 100644 index 00000000..13b2726d --- /dev/null +++ b/frontend/ExampleBase.cpp @@ -0,0 +1,154 @@ + +#include "libOTe/Base/SimplestOT.h" +#include "libOTe/Base/McRosRoyTwist.h" +#include "libOTe/Base/McRosRoy.h" +#include "libOTe/Tools/Popf/EKEPopf.h" +#include "libOTe/Tools/Popf/MRPopf.h" +#include "libOTe/Tools/Popf/FeistelPopf.h" +#include "libOTe/Tools/Popf/FeistelMulPopf.h" +#include "libOTe/Tools/Popf/FeistelRistPopf.h" +#include "libOTe/Tools/Popf/FeistelMulRistPopf.h" +#include "libOTe/Base/MasnyRindal.h" +#include "libOTe/Base/MasnyRindalKyber.h" + +#include "cryptoTools/Common/BitVector.h" +#include "cryptoTools/Common/CLP.h" +#include "util.h" +#include "coproto/Socket/AsioSocket.h" +#include "cryptoTools/Common/Timer.h" + +namespace osuCrypto +{ + + template + void baseOT_example_from_ot(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP&, BaseOT ot) + { +#ifdef COPROTO_ENABLE_BOOST + PRNG prng(sysRandomSeed()); + + if (totalOTs == 0) + totalOTs = 128; + + if (numThreads > 1) + std::cout << "multi threading for the base OT example is not implemented.\n" << std::flush; + + Timer t; + Timer::timeUnit s; + if (role == Role::Receiver) + { + auto sock = cp::asioConnect(ip, false); + BaseOT recv = ot; + + AlignedVector msg(totalOTs); + BitVector choice(totalOTs); + choice.randomize(prng); + + + s = t.setTimePoint("base OT start"); + + coproto::sync_wait(recv.receive(choice, msg, prng, sock)); + + // make sure all messages are sent. + cp::sync_wait(sock.flush()); + } + else + { + auto sock = cp::asioConnect(ip, true); + + BaseOT send = ot; + + AlignedVector> msg(totalOTs); + + s = t.setTimePoint("base OT start"); + + coproto::sync_wait(send.send(msg, prng, sock)); + + + // make sure all messages are sent. + cp::sync_wait(sock.flush()); + } + + + + auto e = t.setTimePoint("base OT end"); + auto milli = std::chrono::duration_cast(e - s).count(); + + std::cout << tag << (role == Role::Receiver ? " (receiver)" : " (sender)") + << " n=" << totalOTs << " " << milli << " ms" << std::endl; +#endif + } + + template + void baseOT_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) + { + return baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, BaseOT()); + } + + bool baseOT_examples(const CLP& cmd) + { + bool flagSet = false; + +#ifdef ENABLE_SIMPLESTOT + flagSet |= runIf(baseOT_example, cmd, simple); +#endif + +#ifdef ENABLE_SIMPLESTOT_ASM + flagSet |= runIf(baseOT_example, cmd, simpleasm); +#endif + +#ifdef ENABLE_MRR_TWIST +#ifdef ENABLE_SSE + flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) { + DomainSepEKEPopf factory; + const char* domain = "EKE POPF OT example"; + factory.Update(domain, std::strlen(domain)); + baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwist(factory)); + }, cmd, moellerpopf, { "eke" }); +#endif + + flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) { + DomainSepMRPopf factory; + const char* domain = "MR POPF OT example"; + factory.Update(domain, std::strlen(domain)); + baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwistMR(factory)); + }, cmd, moellerpopf, { "mrPopf" }); + + flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) { + DomainSepFeistelPopf factory; + const char* domain = "Feistel POPF OT example"; + factory.Update(domain, std::strlen(domain)); + baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwistFeistel(factory)); + }, cmd, moellerpopf, { "feistel" }); + + flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) { + DomainSepFeistelMulPopf factory; + const char* domain = "Feistel With Multiplication POPF OT example"; + factory.Update(domain, std::strlen(domain)); + baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwistMul(factory)); + }, cmd, moellerpopf, { "feistelMul" }); +#endif + +#ifdef ENABLE_MRR + flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) { + DomainSepFeistelRistPopf factory; + const char* domain = "Feistel POPF OT example (Risretto)"; + factory.Update(domain, std::strlen(domain)); + baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoy(factory)); + }, cmd, ristrettopopf, { "feistel" }); + + flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& clp) { + DomainSepFeistelMulRistPopf factory; + const char* domain = "Feistel With Multiplication POPF OT example (Risretto)"; + factory.Update(domain, std::strlen(domain)); + baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyMul(factory)); + }, cmd, ristrettopopf, { "feistelMul" }); +#endif + +#ifdef ENABLE_MR + flagSet |= runIf(baseOT_example, cmd, mr); +#endif + + return flagSet; + } + +} diff --git a/frontend/ExampleBase.h b/frontend/ExampleBase.h index a4684405..762d3b78 100644 --- a/frontend/ExampleBase.h +++ b/frontend/ExampleBase.h @@ -14,65 +14,8 @@ #include "cryptoTools/Common/BitVector.h" #include "cryptoTools/Common/CLP.h" -#include "util.h" -#include "coproto/Socket/AsioSocket.h" namespace osuCrypto { - - template - void baseOT_example_from_ot(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP&, BaseOT ot) - { -#ifdef COPROTO_ENABLE_BOOST - PRNG prng(sysRandomSeed()); - - if (totalOTs == 0) - totalOTs = 128; - - if (numThreads > 1) - std::cout << "multi threading for the base OT example is not implemented.\n" << std::flush; - - Timer t; - Timer::timeUnit s; - if (role == Role::Receiver) - { - auto sock = cp::asioConnect(ip, false); - BaseOT recv = ot; - - AlignedVector msg(totalOTs); - BitVector choice(totalOTs); - choice.randomize(prng); - - - s = t.setTimePoint("base OT start"); - - coproto::sync_wait(recv.receive(choice, msg, prng, sock)); - - } - else - { - auto sock = cp::asioConnect(ip, true); - - BaseOT send = ot; - - AlignedVector> msg(totalOTs); - - s = t.setTimePoint("base OT start"); - - coproto::sync_wait(send.send(msg, prng, sock)); - } - - auto e = t.setTimePoint("base OT end"); - auto milli = std::chrono::duration_cast(e - s).count(); - - std::cout << tag << (role == Role::Receiver ? " (receiver)" : " (sender)") - << " n=" << totalOTs << " " << milli << " ms" << std::endl; -#endif - } - - template - void baseOT_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) - { - return baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, BaseOT()); - } + bool baseOT_examples(const CLP& clp); } diff --git a/frontend/ExampleNChooseOne.cpp b/frontend/ExampleNChooseOne.cpp new file mode 100644 index 00000000..86bd1d5f --- /dev/null +++ b/frontend/ExampleNChooseOne.cpp @@ -0,0 +1,245 @@ + +#include "cryptoTools/Common/Matrix.h" +#include "libOTe/NChooseOne/Oos/OosNcoOtReceiver.h" +#include "libOTe/NChooseOne/Oos/OosNcoOtSender.h" +#include "libOTe/NChooseOne/Kkrt/KkrtNcoOtReceiver.h" +#include "libOTe/NChooseOne/Kkrt/KkrtNcoOtSender.h" +#include "cryptoTools/Common/Matrix.h" +#include "libOTe/Tools/Coproto.h" +#include "util.h" +#include "coproto/Socket/AsioSocket.h" + +namespace osuCrypto +{ + + + + auto chls = cp::LocalAsyncSocket::makePair(); + + template + void NChooseOne_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP&) + { +#ifdef COPROTO_ENABLE_BOOST + const u64 step = 1024; + + if (totalOTs == 0) + totalOTs = 1 << 20; + + bool randomOT = true; + u64 numOTs = (u64)totalOTs; + auto numChosenMsgs = 256; + + // get up the networking + auto chl = cp::asioConnect(ip, role == Role::Sender); + //auto chl = role == Role::Sender ? chls[0] : chls[1]; + + PRNG prng(ZeroBlock);// sysRandomSeed()); + + NcoOtSender sender; + NcoOtReceiver recver; + + // all Nco Ot extenders must have configure called first. This determines + // a variety of parameters such as how many base OTs are required. + bool maliciousSecure = false; + u64 statSecParam = 40; + u64 inputBitCount = 76; // the kkrt protocol default to 128 but oos can only do 76. + + // create a lambda function that performs the computation of a single receiver thread. + auto recvRoutine = [&]() -> task<> + { + MC_BEGIN(task<>, &, + i = u64{}, min = u64{}, + recvMsgs = std::vector{}, + choices = std::vector{} + ); + + recver.configure(maliciousSecure, statSecParam, inputBitCount); + //MC_AWAIT(sync(chl, Role::Receiver)); + + if (randomOT) + { + // once configure(...) and setBaseOts(...) are called, + // we can compute many batches of OTs. First we need to tell + // the instance how many OTs we want in this batch. This is done here. + MC_AWAIT(recver.init(numOTs, prng, chl)); + + // now we can iterate over the OTs and actually retrieve the desired + // messages. However, for efficiency we will do this in steps where + // we do some computation followed by sending off data. This is more + // efficient since data will be sent in the background :). + for (i = 0; i < numOTs; ) + { + // figure out how many OTs we want to do in this step. + min = std::min(numOTs - i, step); + + //// iterate over this step. + for (u64 j = 0; j < min; ++j, ++i) + { + // For the OT index by i, we need to pick which + // one of the N OT messages that we want. For this + // example we simply pick a random one. Note only the + // first log2(N) bits of choice is considered. + block choice = prng.get(); + + // this will hold the (random) OT message of our choice + block otMessage; + + // retrieve the desired message. + recver.encode(i, &choice, &otMessage); + + // do something cool with otMessage + //otMessage; + } + + // Note that all OTs in this region must be encode. If there are some + // that you don't actually care about, then you can skip them by calling + // + // recver.zeroEncode(i); + // + + // Now that we have gotten out the OT mMessages for this step, + // we are ready to send over network some information that + // allows the sender to also compute the OT mMessages. Since we just + // encoded "min" OT mMessages, we will tell the class to send the + // next min "correction" values. + MC_AWAIT(recver.sendCorrection(chl, min)); + } + + // once all numOTs have been encoded and had their correction values sent + // we must call check. This allows to sender to make sure we did not cheat. + // For semi-honest protocols, this can and will be skipped. + MC_AWAIT(recver.check(chl, prng.get())); + } + else + { + recvMsgs.resize(numOTs); + choices.resize(numOTs); + + // define which messages the receiver should learn. + for (i = 0; i < numOTs; ++i) + choices[i] = prng.get(); + + // the messages that were learned are written to recvMsgs. + MC_AWAIT(recver.receiveChosen(numChosenMsgs, recvMsgs, choices, prng, chl)); + } + + MC_AWAIT(chl.flush()); + MC_END(); + }; + + // create a lambda function that performs the computation of a single sender thread. + auto sendRoutine = [&]() + { + MC_BEGIN(task<>, &, + sendMessages = Matrix{}, + i = u64{}, min = u64{} + ); + + sender.configure(maliciousSecure, statSecParam, inputBitCount); + //MC_AWAIT(sync(chl, Role::Sender)); + + if (randomOT) + { + // Same explanation as above. + MC_AWAIT(sender.init(numOTs, prng, chl)); + + // Same explanation as above. + for (i = 0; i < numOTs; ) + { + // Same explanation as above. + min = std::min(numOTs - i, step); + + // unlike for the receiver, before we call encode to get + // some desired OT message, we must call recvCorrection(...). + // This receivers some information that the receiver had sent + // and allows the sender to compute any OT message that they desired. + // Note that the step size must match what the receiver used. + // If this is unknown you can use recvCorrection(chl) -> u64 + // which will tell you how many were sent. + MC_AWAIT(sender.recvCorrection(chl, min)); + + // we now encode any OT message with index less that i + min. + for (u64 j = 0; j < min; ++j, ++i) + { + // in particular, the sender can retrieve many OT messages + // at a single index, in this case we chose to retrieve 3 + // but that is arbitrary. + auto choice0 = prng.get(); + auto choice1 = prng.get(); + auto choice2 = prng.get(); + + // these we hold the actual OT messages. + block + otMessage0, + otMessage1, + otMessage2; + + // now retrieve the messages + sender.encode(i, &choice0, &otMessage0); + sender.encode(i, &choice1, &otMessage1); + sender.encode(i, &choice2, &otMessage2); + } + } + + // This call is required to make sure the receiver did not cheat. + // All corrections must be received before this is called. + MC_AWAIT(sender.check(chl, ZeroBlock)); + } + else + { + // populate this with the messages that you want to send. + sendMessages.resize(numOTs, numChosenMsgs); + prng.get(sendMessages.data(), sendMessages.size()); + + // perform the OTs with the given messages. + MC_AWAIT(sender.sendChosen(sendMessages, prng, chl)); + + } + + MC_AWAIT(chl.flush()); + MC_END(); + }; + + + Timer time; + auto s = time.setTimePoint("start"); + + + task<> proto; + if (role == Role::Sender) + proto = sendRoutine(); + else + proto = recvRoutine(); + try + { + cp::sync_wait(proto); + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + } + + auto e = time.setTimePoint("finish"); + auto milli = std::chrono::duration_cast(e - s).count(); + + if (role == Role::Sender) + std::cout << tag << " n=" << totalOTs << " " << milli << " ms " << std::endl; +#endif + } + + + bool NChooseOne_Examples(const CLP& cmd) + { + bool flagSet = false; +#ifdef ENABLE_KKRT + flagSet |= runIf(NChooseOne_example, cmd, kkrt); +#endif + +#ifdef ENABLE_OOS + flagSet |= runIf(NChooseOne_example, cmd, oos); +#endif + + return flagSet; + } + +} diff --git a/frontend/ExampleNChooseOne.h b/frontend/ExampleNChooseOne.h index 7d886cc4..571657ed 100644 --- a/frontend/ExampleNChooseOne.h +++ b/frontend/ExampleNChooseOne.h @@ -1,230 +1,8 @@ #pragma once - - -#include "cryptoTools/Common/Matrix.h" -#include "libOTe/NChooseOne/Oos/OosNcoOtReceiver.h" -#include "libOTe/NChooseOne/Oos/OosNcoOtSender.h" -#include "libOTe/NChooseOne/Kkrt/KkrtNcoOtReceiver.h" -#include "libOTe/NChooseOne/Kkrt/KkrtNcoOtSender.h" -#include "cryptoTools/Common/Matrix.h" -#include "libOTe/Tools/Coproto.h" +#include "cryptoTools/Common/CLP.h" namespace osuCrypto { - - - auto chls = cp::LocalAsyncSocket::makePair(); - - template - void NChooseOne_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP&) - { -#ifdef COPROTO_ENABLE_BOOST - const u64 step = 1024; - - if (totalOTs == 0) - totalOTs = 1 << 20; - - bool randomOT = true; - u64 numOTs = (u64)totalOTs; - auto numChosenMsgs = 256; - - // get up the networking - auto chl = cp::asioConnect(ip, role == Role::Sender); - //auto chl = role == Role::Sender ? chls[0] : chls[1]; - - PRNG prng(ZeroBlock);// sysRandomSeed()); - - NcoOtSender sender; - NcoOtReceiver recver; - - // all Nco Ot extenders must have configure called first. This determines - // a variety of parameters such as how many base OTs are required. - bool maliciousSecure = false; - u64 statSecParam = 40; - u64 inputBitCount = 76; // the kkrt protocol default to 128 but oos can only do 76. - - // create a lambda function that performs the computation of a single receiver thread. - auto recvRoutine = [&]() -> task<> - { - MC_BEGIN(task<>,&, - i = u64{}, min = u64{}, - recvMsgs = std::vector{}, - choices = std::vector{} - ); - - recver.configure(maliciousSecure, statSecParam, inputBitCount); - //MC_AWAIT(sync(chl, Role::Receiver)); - - if (randomOT) - { - // once configure(...) and setBaseOts(...) are called, - // we can compute many batches of OTs. First we need to tell - // the instance how many OTs we want in this batch. This is done here. - MC_AWAIT(recver.init(numOTs, prng, chl)); - - // now we can iterate over the OTs and actually retrieve the desired - // messages. However, for efficiency we will do this in steps where - // we do some computation followed by sending off data. This is more - // efficient since data will be sent in the background :). - for (i = 0; i < numOTs; ) - { - // figure out how many OTs we want to do in this step. - min = std::min(numOTs - i, step); - - //// iterate over this step. - for (u64 j = 0; j < min; ++j, ++i) - { - // For the OT index by i, we need to pick which - // one of the N OT messages that we want. For this - // example we simply pick a random one. Note only the - // first log2(N) bits of choice is considered. - block choice = prng.get(); - - // this will hold the (random) OT message of our choice - block otMessage; - - // retrieve the desired message. - recver.encode(i, &choice, &otMessage); - - // do something cool with otMessage - //otMessage; - } - - // Note that all OTs in this region must be encode. If there are some - // that you don't actually care about, then you can skip them by calling - // - // recver.zeroEncode(i); - // - - // Now that we have gotten out the OT mMessages for this step, - // we are ready to send over network some information that - // allows the sender to also compute the OT mMessages. Since we just - // encoded "min" OT mMessages, we will tell the class to send the - // next min "correction" values. - MC_AWAIT(recver.sendCorrection(chl, min)); - } - - // once all numOTs have been encoded and had their correction values sent - // we must call check. This allows to sender to make sure we did not cheat. - // For semi-honest protocols, this can and will be skipped. - MC_AWAIT(recver.check(chl, prng.get())); - } - else - { - recvMsgs.resize(numOTs); - choices.resize(numOTs); - - // define which messages the receiver should learn. - for (i = 0; i < numOTs; ++i) - choices[i] = prng.get(); - - // the messages that were learned are written to recvMsgs. - MC_AWAIT(recver.receiveChosen(numChosenMsgs, recvMsgs, choices, prng, chl)); - } - - MC_AWAIT(chl.flush()); - MC_END(); - }; - - // create a lambda function that performs the computation of a single sender thread. - auto sendRoutine = [&]() - { - MC_BEGIN(task<>,&, - sendMessages = Matrix{}, - i = u64{}, min = u64{} - ); - - sender.configure(maliciousSecure, statSecParam, inputBitCount); - //MC_AWAIT(sync(chl, Role::Sender)); - - if (randomOT) - { - // Same explanation as above. - MC_AWAIT(sender.init(numOTs, prng, chl)); - - // Same explanation as above. - for (i = 0; i < numOTs; ) - { - // Same explanation as above. - min = std::min(numOTs - i, step); - - // unlike for the receiver, before we call encode to get - // some desired OT message, we must call recvCorrection(...). - // This receivers some information that the receiver had sent - // and allows the sender to compute any OT message that they desired. - // Note that the step size must match what the receiver used. - // If this is unknown you can use recvCorrection(chl) -> u64 - // which will tell you how many were sent. - MC_AWAIT(sender.recvCorrection(chl, min)); - - // we now encode any OT message with index less that i + min. - for (u64 j = 0; j < min; ++j, ++i) - { - // in particular, the sender can retrieve many OT messages - // at a single index, in this case we chose to retrieve 3 - // but that is arbitrary. - auto choice0 = prng.get(); - auto choice1 = prng.get(); - auto choice2 = prng.get(); - - // these we hold the actual OT messages. - block - otMessage0, - otMessage1, - otMessage2; - - // now retrieve the messages - sender.encode(i, &choice0, &otMessage0); - sender.encode(i, &choice1, &otMessage1); - sender.encode(i, &choice2, &otMessage2); - } - } - - // This call is required to make sure the receiver did not cheat. - // All corrections must be received before this is called. - MC_AWAIT(sender.check(chl, ZeroBlock)); - } - else - { - // populate this with the messages that you want to send. - sendMessages.resize(numOTs, numChosenMsgs); - prng.get(sendMessages.data(), sendMessages.size()); - - // perform the OTs with the given messages. - MC_AWAIT(sender.sendChosen(sendMessages, prng, chl)); - - } - - MC_AWAIT(chl.flush()); - MC_END(); - }; - - - Timer time; - auto s = time.setTimePoint("start"); - - - task<> proto; - if (role == Role::Sender) - proto = sendRoutine(); - else - proto = recvRoutine(); - try - { - cp::sync_wait(proto); - } - catch (std::exception& e) - { - std::cout << e.what() << std::endl; - } - - auto e = time.setTimePoint("finish"); - auto milli = std::chrono::duration_cast(e - s).count(); - - if (role == Role::Sender) - std::cout << tag << " n=" << totalOTs << " " << milli << " ms " << std::endl; -#endif - } - + bool NChooseOne_Examples(const CLP& cmd); } diff --git a/frontend/ExampleSilent.cpp b/frontend/ExampleSilent.cpp new file mode 100644 index 00000000..c2b17d02 --- /dev/null +++ b/frontend/ExampleSilent.cpp @@ -0,0 +1,189 @@ + +#include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" +#include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" +#include "util.h" +#include + +#include "cryptoTools/Network/IOService.h" +#include "coproto/Socket/AsioSocket.h" + +namespace osuCrypto +{ + + void Silent_example(Role role, u64 numOTs, u64 numThreads, std::string ip, std::string tag, const CLP& cmd) + { +#if defined(ENABLE_SILENTOT) && defined(COPROTO_ENABLE_BOOST) + + if (numOTs == 0) + numOTs = 1 << 20; + + // get up the networking + auto chl = cp::asioConnect(ip, role == Role::Sender); + + + PRNG prng(sysRandomSeed()); + + bool fakeBase = cmd.isSet("fakeBase"); + u64 trials = cmd.getOr("trials", 1); + auto malicious = cmd.isSet("mal") ? SilentSecType::Malicious : SilentSecType::SemiHonest; + + auto multType = (MultType)cmd.getOr("multType", (int)DefaultMultType); + + std::vector types; + if (cmd.isSet("base")) + types.push_back(SilentBaseType::Base); + else + types.push_back(SilentBaseType::BaseExtend); + + macoro::thread_pool threadPool; + auto work = threadPool.make_work(); + if (numThreads > 1) + threadPool.create_threads(numThreads); + + for (auto type : types) + { + for (u64 tt = 0; tt < trials; ++tt) + { + Timer timer; + auto start = timer.setTimePoint("start"); + if (role == Role::Sender) + { + SilentOtExtSender sender; + + // optionally request the LPN encoding matrix. + sender.mMultType = multType; + + // optionally configure the sender. default is semi honest security. + sender.configure(numOTs, 2, numThreads, malicious); + + if (fakeBase) + { + auto nn = sender.baseOtCount(); + BitVector bits(nn); + bits.randomize(prng); + std::vector> baseSendMsgs(bits.size()); + std::vector baseRecvMsgs(bits.size()); + + auto commonPrng = PRNG(ZeroBlock); + commonPrng.get(baseSendMsgs.data(), baseSendMsgs.size()); + for (u64 i = 0; i < bits.size(); ++i) + baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; + + sender.setBaseOts(baseRecvMsgs, bits); + } + else + { + // optional. You can request that the base ot are generated either + // using just base OTs (few rounds, more computation) or 128 base OTs and then extend those. + // The default is the latter, base + extension. + cp::sync_wait(sender.genSilentBaseOts(prng, chl, type == SilentBaseType::BaseExtend)); + } + + std::vector> messages(numOTs); + + // create the protocol object. + auto protocol = sender.silentSend(messages, prng, chl); + + // run the protocol + if (numThreads <= 1) + cp::sync_wait(protocol); + else + // launch the protocol on the thread pool. + cp::sync_wait(std::move(protocol) | macoro::start_on(threadPool)); + + // messages has been populated with random OT messages. + // See the header for other options. + } + else + { + + SilentOtExtReceiver recver; + + // optionally request the LPN encoding matrix. + recver.mMultType = multType; + + // configure the sender. optional for semi honest security... + recver.configure(numOTs, 2, numThreads, malicious); + + if (fakeBase) + { + auto nn = recver.baseOtCount(); + BitVector bits(nn); + bits.randomize(prng); + std::vector> baseSendMsgs(bits.size()); + std::vector baseRecvMsgs(bits.size()); + + auto commonPrng = PRNG(ZeroBlock); + commonPrng.get(baseSendMsgs.data(), baseSendMsgs.size()); + for (u64 i = 0; i < bits.size(); ++i) + baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; + + recver.setBaseOts(baseSendMsgs); + } + else + { + // optional. You can request that the base ot are generated either + // using just base OTs (few rounds, more computation) or 128 base OTs and then extend those. + // The default is the latter, base + extension. + cp::sync_wait(recver.genSilentBaseOts(prng, chl, type == SilentBaseType::BaseExtend)); + } + + std::vector messages(numOTs); + BitVector choices(numOTs); + + // create the protocol object. + auto protocol = recver.silentReceive(choices, messages, prng, chl); + + // run the protocol + if (numThreads <= 1) + cp::sync_wait(protocol); + else + // launch the protocol on the thread pool. + cp::sync_wait(std::move(protocol) | macoro::start_on(threadPool)); + + // choices, messages has been populated with random OT messages. + // messages[i] = sender.message[i][choices[i]] + // See the header for other options. + } + auto end = timer.setTimePoint("end"); + auto milli = std::chrono::duration_cast(end - start).count(); + + u64 com = chl.bytesReceived() + chl.bytesSent(); + + if (role == Role::Sender) + { + std::string typeStr = type == SilentBaseType::Base ? "b " : "be "; + lout << tag << + " n:" << Color::Green << std::setw(6) << std::setfill(' ') << numOTs << Color::Default << + " type: " << Color::Green << typeStr << Color::Default << + " || " << Color::Green << + std::setw(6) << std::setfill(' ') << milli << " ms " << + std::setw(6) << std::setfill(' ') << com << " bytes" << std::endl << Color::Default; + + if (cmd.getOr("v", 0) > 1) + lout << gTimer << std::endl; + } + + if (cmd.isSet("v")) + { + if (role == Role::Sender) + lout << " **** sender ****\n" << timer << std::endl; + + if (role == Role::Receiver) + lout << " **** receiver ****\n" << timer << std::endl; + } + } + + } + + cp::sync_wait(chl.flush()); + +#endif + } + bool Silent_Examples(const CLP& cmd) + { + return runIf(Silent_example, cmd, Silent); + } + + +} diff --git a/frontend/ExampleSilent.h b/frontend/ExampleSilent.h index e792ba73..abc4122f 100644 --- a/frontend/ExampleSilent.h +++ b/frontend/ExampleSilent.h @@ -1,186 +1,7 @@ #pragma once - -#include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" -#include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" -#include "util.h" -#include - -#include "cryptoTools/Network/IOService.h" -#include "coproto/Socket/AsioSocket.h" +#include "cryptoTools/Common/CLP.h" namespace osuCrypto { - - void Silent_example(Role role, u64 numOTs, u64 numThreads, std::string ip, std::string tag, CLP& cmd) - { -#if defined(ENABLE_SILENTOT) && defined(COPROTO_ENABLE_BOOST) - - if (numOTs == 0) - numOTs = 1 << 20; - - // get up the networking - auto chl = cp::asioConnect(ip, role == Role::Sender); - - - PRNG prng(sysRandomSeed()); - - bool fakeBase = cmd.isSet("fakeBase"); - u64 trials = cmd.getOr("trials", 1); - auto malicious = cmd.isSet("mal") ? SilentSecType::Malicious : SilentSecType::SemiHonest; - - auto multType = (MultType)cmd.getOr("multType", (int)DefaultMultType); - - std::vector types; - if (cmd.isSet("base")) - types.push_back(SilentBaseType::Base); - else - types.push_back(SilentBaseType::BaseExtend); - - macoro::thread_pool threadPool; - auto work = threadPool.make_work(); - if (numThreads > 1) - threadPool.create_threads(numThreads); - - for (auto type : types) - { - for (u64 tt = 0; tt < trials; ++tt) - { - Timer timer; - auto start = timer.setTimePoint("start"); - if (role == Role::Sender) - { - SilentOtExtSender sender; - - // optionally request the LPN encoding matrix. - sender.mMultType = multType; - - // optionally configure the sender. default is semi honest security. - sender.configure(numOTs, 2, numThreads, malicious); - - if (fakeBase) - { - auto nn = sender.baseOtCount(); - BitVector bits(nn); - bits.randomize(prng); - std::vector> baseSendMsgs(bits.size()); - std::vector baseRecvMsgs(bits.size()); - - auto commonPrng = PRNG(ZeroBlock); - commonPrng.get(baseSendMsgs.data(), baseSendMsgs.size()); - for (u64 i = 0; i < bits.size(); ++i) - baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; - - sender.setBaseOts(baseRecvMsgs, bits); - } - else - { - // optional. You can request that the base ot are generated either - // using just base OTs (few rounds, more computation) or 128 base OTs and then extend those. - // The default is the latter, base + extension. - cp::sync_wait(sender.genSilentBaseOts(prng, chl, type == SilentBaseType::BaseExtend)); - } - - std::vector> messages(numOTs); - - // create the protocol object. - auto protocol = sender.silentSend(messages, prng, chl); - - // run the protocol - if (numThreads <= 1) - cp::sync_wait(protocol); - else - // launch the protocol on the thread pool. - cp::sync_wait(std::move(protocol) | macoro::start_on(threadPool)); - - // messages has been populated with random OT messages. - // See the header for other options. - } - else - { - - SilentOtExtReceiver recver; - - // optionally request the LPN encoding matrix. - recver.mMultType = multType; - - // configure the sender. optional for semi honest security... - recver.configure(numOTs, 2, numThreads, malicious); - - if (fakeBase) - { - auto nn = recver.baseOtCount(); - BitVector bits(nn); - bits.randomize(prng); - std::vector> baseSendMsgs(bits.size()); - std::vector baseRecvMsgs(bits.size()); - - auto commonPrng = PRNG(ZeroBlock); - commonPrng.get(baseSendMsgs.data(), baseSendMsgs.size()); - for (u64 i = 0; i < bits.size(); ++i) - baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; - - recver.setBaseOts(baseSendMsgs); - } - else - { - // optional. You can request that the base ot are generated either - // using just base OTs (few rounds, more computation) or 128 base OTs and then extend those. - // The default is the latter, base + extension. - cp::sync_wait(recver.genSilentBaseOts(prng, chl, type == SilentBaseType::BaseExtend)); - } - - std::vector messages(numOTs); - BitVector choices(numOTs); - - // create the protocol object. - auto protocol = recver.silentReceive(choices, messages, prng, chl); - - // run the protocol - if (numThreads <= 1) - cp::sync_wait(protocol); - else - // launch the protocol on the thread pool. - cp::sync_wait(std::move(protocol) | macoro::start_on(threadPool)); - - // choices, messages has been populated with random OT messages. - // messages[i] = sender.message[i][choices[i]] - // See the header for other options. - } - auto end = timer.setTimePoint("end"); - auto milli = std::chrono::duration_cast(end - start).count(); - - u64 com = chl.bytesReceived() + chl.bytesSent(); - - if (role == Role::Sender) - { - std::string typeStr = type == SilentBaseType::Base ? "b " : "be "; - lout << tag << - " n:" << Color::Green << std::setw(6) << std::setfill(' ') << numOTs << Color::Default << - " type: " << Color::Green << typeStr << Color::Default << - " || " << Color::Green << - std::setw(6) << std::setfill(' ') << milli << " ms " << - std::setw(6) << std::setfill(' ') << com << " bytes" << std::endl << Color::Default; - - if (cmd.getOr("v", 0) > 1) - lout << gTimer << std::endl; - } - - if (cmd.isSet("v")) - { - if (role == Role::Sender) - lout << " **** sender ****\n" << timer << std::endl; - - if (role == Role::Receiver) - lout << " **** receiver ****\n" << timer << std::endl; - } - } - - } - - cp::sync_wait(chl.flush()); - -#endif - } - - + bool Silent_Examples(const CLP& cmd); } diff --git a/frontend/ExampleTwoChooseOne.cpp b/frontend/ExampleTwoChooseOne.cpp new file mode 100644 index 00000000..4dc5527b --- /dev/null +++ b/frontend/ExampleTwoChooseOne.cpp @@ -0,0 +1,356 @@ + +#include "libOTe/Base/BaseOT.h" +#include "libOTe/TwoChooseOne/Kos/KosOtExtReceiver.h" +#include "libOTe/TwoChooseOne/Kos/KosOtExtSender.h" +#include "libOTe/TwoChooseOne/KosDot/KosDotExtReceiver.h" +#include "libOTe/TwoChooseOne/KosDot/KosDotExtSender.h" +#include "libOTe/TwoChooseOne/Iknp/IknpOtExtReceiver.h" +#include "libOTe/TwoChooseOne/Iknp/IknpOtExtSender.h" + + +#include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" +#include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" +#include "util.h" +#include "coproto/Socket/AsioSocket.h" + +namespace osuCrypto +{ +#ifdef ENABLE_IKNP + void noHash(IknpOtExtSender& s, IknpOtExtReceiver& r) + { + s.mHash = false; + r.mHash = false; + } +#endif + + template + void noHash(Sender&, Receiver&) + { + throw std::runtime_error("This protocol does not support noHash"); + } + +#ifdef ENABLE_SOFTSPOKEN_OT + // soft spoken takes an extra parameter as input what determines + // the computation/communication trade-off. + template + using is_SoftSpoken = typename std::conditional< + //std::is_same::value || + //std::is_same::value || + //std::is_same::value || + //std::is_same::value || + //std::is_same::value || + //std::is_same::value || + //std::is_same::value || + //std::is_same::value + false + , + std::true_type, std::false_type>::type; +#else + template + using is_SoftSpoken = std::false_type; +#endif + + template + typename std::enable_if::value, T>::type + construct(CLP& cmd) + { + return T(cmd.getOr("f", 2)); + } + + template + typename std::enable_if::value, T>::type + construct(CLP& cmd) + { + return T{}; + } + + template + void TwoChooseOne_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& cmd) + { + +#ifdef COPROTO_ENABLE_BOOST + if (totalOTs == 0) + totalOTs = 1 << 20; + + bool randomOT = true; + + // get up the networking + auto chl = cp::asioConnect(ip, role == Role::Sender); + + PRNG prng(sysRandomSeed()); + + + OtExtSender sender; + OtExtRecver receiver; + + +#ifdef LIBOTE_HAS_BASE_OT + // Now compute the base OTs, we need to set them on the first pair of extenders. + // In real code you would only have a sender or reciever, not both. But we do + // here just showing the example. + if (role == Role::Receiver) + { + DefaultBaseOT base; + std::array, 128> baseMsg; + + // perform the base To, call sync_wait to block until they have completed. + cp::sync_wait(base.send(baseMsg, prng, chl)); + receiver.setBaseOts(baseMsg); + } + else + { + + DefaultBaseOT base; + BitVector bv(128); + std::array baseMsg; + bv.randomize(prng); + + // perform the base To, call sync_wait to block until they have completed. + cp::sync_wait(base.receive(bv, baseMsg, prng, chl)); + sender.setBaseOts(baseMsg, bv); + } + +#else + if (!cmd.isSet("fakeBase")) + std::cout << "warning, base ots are not enabled. Fake base OTs will be used. " << std::endl; + PRNG commonPRNG(oc::ZeroBlock); + std::array, 128> sendMsgs; + commonPRNG.get(sendMsgs.data(), sendMsgs.size()); + if (role == Role::Receiver) + { + receiver.setBaseOts(sendMsgs); + } + else + { + BitVector bv(128); + bv.randomize(commonPRNG); + std::array recvMsgs; + for (u64 i = 0; i < 128; ++i) + recvMsgs[i] = sendMsgs[i][bv[i]]; + sender.setBaseOts(recvMsgs, bv); + } +#endif + + if (cmd.isSet("noHash")) + noHash(sender, receiver); + + Timer timer, sendTimer, recvTimer; + sendTimer.setTimePoint("start"); + recvTimer.setTimePoint("start"); + auto s = timer.setTimePoint("start"); + + if (numThreads == 1) + { + if (role == Role::Receiver) + { + // construct the choices that we want. + BitVector choice(totalOTs); + // in this case pick random messages. + choice.randomize(prng); + + // construct a vector to stored the received messages. + AlignedUnVector rMsgs(totalOTs); + + try { + + if (randomOT) + { + // perform totalOTs random OTs, the results will be written to msgs. + cp::sync_wait(receiver.receive(choice, rMsgs, prng, chl)); + } + else + { + // perform totalOTs chosen message OTs, the results will be written to msgs. + cp::sync_wait(receiver.receiveChosen(choice, rMsgs, prng, chl)); + } + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + chl.close(); + } + } + else + { + // construct a vector to stored the random send messages. + AlignedUnVector> sMsgs(totalOTs); + + + // if delta OT is used, then the user can call the following + // to set the desired XOR difference between the zero messages + // and the one messages. + // + // senders[i].setDelta(some 128 bit delta); + // + try + { + if (randomOT) + { + // perform the OTs and write the random OTs to msgs. + cp::sync_wait(sender.send(sMsgs, prng, chl)); + } + else + { + // Populate msgs with something useful... + prng.get(sMsgs.data(), sMsgs.size()); + + // perform the OTs. The receiver will learn one + // of the messages stored in msgs. + cp::sync_wait(sender.sendChosen(sMsgs, prng, chl)); + } + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + chl.close(); + } + } + + // make sure all messages have been sent. + cp::sync_wait(chl.flush()); + } + else + { + + // for multi threading, we only show example for random OTs. + // We first need to construct the inputs + // that each thread will use. Note that the actual protocol + // is not thread safe so everything needs to be independent. + std::vector> tasks(numThreads); + std::vector threadPrngs(numThreads); + std::vector threadChls(numThreads); + + macoro::thread_pool::work work; + macoro::thread_pool threadPool(numThreads, work); + + if (role == Role::Receiver) + { + std::vector receivers(numThreads); + std::vector threadChoices(numThreads); + std::vector> threadMsgs(numThreads); + + for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) + { + u64 beginIndex = oc::roundUpTo(totalOTs * threadIndex / numThreads, 128); + u64 endIndex = oc::roundUpTo((totalOTs + 1) * threadIndex / numThreads, 128); + + threadChoices[threadIndex].resize(endIndex - beginIndex); + threadChoices[threadIndex].randomize(prng); + + threadMsgs[threadIndex].resize(endIndex - beginIndex); + + // create a copy of the receiver so that each can run + // independently. A single receiver is not thread safe. + receivers[threadIndex] = receiver.splitBase(); + + // create a PRNG for this thread. + threadPrngs[threadIndex].SetSeed(prng.get()); + + // create a socket for this thread. This is done by calling fork(). + threadChls[threadIndex] = chl.fork(); + + // start the receive protocol on the thread pool + tasks[threadIndex] = + receivers[threadIndex].receive( + threadChoices[threadIndex], + threadMsgs[threadIndex], + threadPrngs[threadIndex], + threadChls[threadIndex]) + | macoro::start_on(threadPool); + } + + // block this thread until the receive operations + // have completed. + for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) + cp::sync_wait(tasks[threadIndex]); + } + else + { + std::vector senders(numThreads); + std::vector>> threadMsgs(numThreads); + + for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) + { + u64 beginIndex = oc::roundUpTo(totalOTs * threadIndex / numThreads, 128); + u64 endIndex = oc::roundUpTo((totalOTs + 1) * threadIndex / numThreads, 128); + + threadMsgs[threadIndex].resize(endIndex - beginIndex); + + // create a copy of the receiver so that each can run + // independently. A single receiver is not thread safe. + senders[threadIndex] = sender.splitBase(); + + // create a PRNG for this thread. + threadPrngs[threadIndex].SetSeed(prng.get()); + + // create a socket for this thread. This is done by calling fork(). + threadChls[threadIndex] = chl.fork(); + + // start the send protocol on the thread pool + tasks[threadIndex] = + senders[threadIndex].send( + threadMsgs[threadIndex], + threadPrngs[threadIndex], + threadChls[threadIndex]) + | macoro::start_on(threadPool); + } + + // block this thread until the receive operations + // have completed. + for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) + cp::sync_wait(tasks[threadIndex]); + } + + work.reset(); + } + + + + auto e = timer.setTimePoint("finish"); + auto milli = std::chrono::duration_cast(e - s).count(); + + auto com = 0;// (chls[0].getTotalDataRecv() + chls[0].getTotalDataSent())* numThreads; + + if (role == Role::Sender) + lout << tag << " n=" << Color::Green << totalOTs << " " << milli << " ms " << com << " bytes" << std::endl << Color::Default; + + + if (cmd.isSet("v") && role == Role::Sender) + { + if (role == Role::Sender) + lout << " **** sender ****\n" << sendTimer << std::endl; + + if (role == Role::Receiver) + lout << " **** receiver ****\n" << recvTimer << std::endl; + } + +#else + throw std::runtime_error("This example requires coproto to enable boost support. Please build libOTe with `-DCOPROTO_ENABLE_BOOST=ON`. " LOCATION); +#endif + } + + + + + bool TwoChooseOne_Examples(const CLP& cmd) + { + bool flagSet = false; +#ifdef ENABLE_IKNP + flagSet |= runIf(TwoChooseOne_example, cmd, iknp); +#endif + +#ifdef ENABLE_KOS + flagSet |= runIf(TwoChooseOne_example, cmd, kos); +#endif + +#ifdef ENABLE_DELTA_KOS + flagSet |= runIf(TwoChooseOne_example, cmd, dkos); +#endif + +#ifdef ENABLE_SOFTSPOKEN_OT + flagSet |= runIf(TwoChooseOne_example, SoftSpokenShOtReceiver<>>, cmd, sshonest); + flagSet |= runIf(TwoChooseOne_example, cmd, smalicious); +#endif + return flagSet; + } + } diff --git a/frontend/ExampleTwoChooseOne.h b/frontend/ExampleTwoChooseOne.h index d532c8f3..c90ad94d 100644 --- a/frontend/ExampleTwoChooseOne.h +++ b/frontend/ExampleTwoChooseOne.h @@ -1,317 +1,8 @@ #pragma once - -#include "libOTe/Base/BaseOT.h" -#include "libOTe/TwoChooseOne/Kos/KosOtExtReceiver.h" -#include "libOTe/TwoChooseOne/Kos/KosOtExtSender.h" -#include "libOTe/TwoChooseOne/KosDot/KosDotExtReceiver.h" -#include "libOTe/TwoChooseOne/KosDot/KosDotExtSender.h" -#include "libOTe/TwoChooseOne/Iknp/IknpOtExtReceiver.h" -#include "libOTe/TwoChooseOne/Iknp/IknpOtExtSender.h" - - -#include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" -#include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" - -//#include "libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalLeakyDotExt.h" -//#include "libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.h" -//#include "libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenShOtExt.h" -//#include "libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenShDotExt.h" +#include "cryptoTools/Common/CLP.h" namespace osuCrypto { -#ifdef ENABLE_IKNP - void noHash(IknpOtExtSender& s, IknpOtExtReceiver& r) - { - s.mHash = false; - r.mHash = false; - } -#endif - - template - void noHash(Sender&, Receiver&) - { - throw std::runtime_error("This protocol does not support noHash"); - } - -#ifdef ENABLE_SOFTSPOKEN_OT - // soft spoken takes an extra parameter as input what determines - // the computation/communication trade-off. - template - using is_SoftSpoken = typename std::conditional< - //std::is_same::value || - //std::is_same::value || - //std::is_same::value || - //std::is_same::value || - //std::is_same::value || - //std::is_same::value || - //std::is_same::value || - //std::is_same::value - false - , - std::true_type, std::false_type>::type; -#else - template - using is_SoftSpoken = std::false_type; -#endif - - template - typename std::enable_if::value,T>::type - construct(CLP& cmd) - { - return T( cmd.getOr("f", 2) ); - } - - template - typename std::enable_if::value, T>::type - construct(CLP& cmd) - { - return T{}; - } - - template - void TwoChooseOne_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& cmd) - { - -#ifdef COPROTO_ENABLE_BOOST - if (totalOTs == 0) - totalOTs = 1 << 20; - - bool randomOT = true; - - // get up the networking - auto chl = cp::asioConnect(ip, role == Role::Sender); - - PRNG prng(sysRandomSeed()); - - - OtExtSender sender; - OtExtRecver receiver; - - -#ifdef LIBOTE_HAS_BASE_OT - // Now compute the base OTs, we need to set them on the first pair of extenders. - // In real code you would only have a sender or reciever, not both. But we do - // here just showing the example. - if (role == Role::Receiver) - { - DefaultBaseOT base; - std::array, 128> baseMsg; - - // perform the base To, call sync_wait to block until they have completed. - cp::sync_wait(base.send(baseMsg, prng, chl)); - receiver.setBaseOts(baseMsg); - } - else - { - - DefaultBaseOT base; - BitVector bv(128); - std::array baseMsg; - bv.randomize(prng); - - // perform the base To, call sync_wait to block until they have completed. - cp::sync_wait(base.receive(bv, baseMsg, prng, chl)); - sender.setBaseOts(baseMsg, bv); - } - -#else - if (!cmd.isSet("fakeBase")) - std::cout << "warning, base ots are not enabled. Fake base OTs will be used. " << std::endl; - PRNG commonPRNG(oc::ZeroBlock); - std::array, 128> sendMsgs; - commonPRNG.get(sendMsgs.data(), sendMsgs.size()); - if (role == Role::Receiver) - { - receiver.setBaseOts(sendMsgs); - } - else - { - BitVector bv(128); - bv.randomize(commonPRNG); - std::array recvMsgs; - for (u64 i = 0; i < 128; ++i) - recvMsgs[i] = sendMsgs[i][bv[i]]; - sender.setBaseOts(recvMsgs, bv); - } -#endif - - if (cmd.isSet("noHash")) - noHash(sender, receiver); - - Timer timer, sendTimer, recvTimer; - sendTimer.setTimePoint("start"); - recvTimer.setTimePoint("start"); - auto s = timer.setTimePoint("start"); - - if (numThreads == 1) - { - if (role == Role::Receiver) - { - // construct the choices that we want. - BitVector choice(totalOTs); - // in this case pick random messages. - choice.randomize(prng); - - // construct a vector to stored the received messages. - std::vector rMsgs(totalOTs); - - if (randomOT) - { - // perform totalOTs random OTs, the results will be written to msgs. - cp::sync_wait(receiver.receive(choice, rMsgs, prng, chl)); - } - else - { - // perform totalOTs chosen message OTs, the results will be written to msgs. - cp::sync_wait(receiver.receiveChosen(choice, rMsgs, prng, chl)); - } - } - else - { - // construct a vector to stored the random send messages. - std::vector> sMsgs(totalOTs); - - - // if delta OT is used, then the user can call the following - // to set the desired XOR difference between the zero messages - // and the one messages. - // - // senders[i].setDelta(some 128 bit delta); - // - - if (randomOT) - { - // perform the OTs and write the random OTs to msgs. - cp::sync_wait(sender.send(sMsgs, prng, chl)); - } - else - { - // Populate msgs with something useful... - prng.get(sMsgs.data(), sMsgs.size()); - - // perform the OTs. The receiver will learn one - // of the messages stored in msgs. - cp::sync_wait(sender.sendChosen(sMsgs, prng, chl)); - } - } - - } - else - { - - // for multi threading, we only show example for random OTs. - // We first need to construct the inputs - // that each thread will use. Note that the actual protocol - // is not thread safe so everything needs to be independent. - std::vector> tasks(numThreads); - std::vector threadPrngs(numThreads); - std::vector threadChls(numThreads); - - macoro::thread_pool::work work; - macoro::thread_pool threadPool(numThreads, work); - - if (role == Role::Receiver) - { - std::vector receivers(numThreads); - std::vector threadChoices(numThreads); - std::vector> threadMsgs(numThreads); - - for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) - { - u64 beginIndex = oc::roundUpTo(totalOTs * threadIndex / numThreads, 128); - u64 endIndex = oc::roundUpTo((totalOTs + 1) * threadIndex / numThreads, 128); - - threadChoices[threadIndex].resize(endIndex - beginIndex); - threadChoices[threadIndex].randomize(prng); - - threadMsgs[threadIndex].resize(endIndex - beginIndex); - - // create a copy of the receiver so that each can run - // independently. A single receiver is not thread safe. - receivers[threadIndex] = receiver.splitBase(); - - // create a PRNG for this thread. - threadPrngs[threadIndex].SetSeed(prng.get()); - - // create a socket for this thread. This is done by calling fork(). - threadChls[threadIndex] = chl.fork(); - - // start the receive protocol on the thread pool - tasks[threadIndex] = - receivers[threadIndex].receive( - threadChoices[threadIndex], - threadMsgs[threadIndex], - threadPrngs[threadIndex], - threadChls[threadIndex]) - | macoro::start_on(threadPool); - } - - // block this thread until the receive operations - // have completed. - for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) - cp::sync_wait(tasks[threadIndex]); - } - else - { - std::vector senders(numThreads); - std::vector>> threadMsgs(numThreads); - - for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) - { - u64 beginIndex = oc::roundUpTo(totalOTs * threadIndex / numThreads, 128); - u64 endIndex = oc::roundUpTo((totalOTs + 1) * threadIndex / numThreads, 128); - - threadMsgs[threadIndex].resize(endIndex - beginIndex); - - // create a copy of the receiver so that each can run - // independently. A single receiver is not thread safe. - senders[threadIndex] = sender.splitBase(); - - // create a PRNG for this thread. - threadPrngs[threadIndex].SetSeed(prng.get()); - - // create a socket for this thread. This is done by calling fork(). - threadChls[threadIndex] = chl.fork(); - - // start the send protocol on the thread pool - tasks[threadIndex] = - senders[threadIndex].send( - threadMsgs[threadIndex], - threadPrngs[threadIndex], - threadChls[threadIndex]) - | macoro::start_on(threadPool); - } - - // block this thread until the receive operations - // have completed. - for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) - cp::sync_wait(tasks[threadIndex]); - } - - work.reset(); - } - - auto e = timer.setTimePoint("finish"); - auto milli = std::chrono::duration_cast(e - s).count(); - - auto com = 0;// (chls[0].getTotalDataRecv() + chls[0].getTotalDataSent())* numThreads; - - if (role == Role::Sender) - lout << tag << " n=" << Color::Green << totalOTs << " " << milli << " ms " << com << " bytes" << std::endl << Color::Default; - - - if (cmd.isSet("v") && role == Role::Sender) - { - if (role == Role::Sender) - lout << " **** sender ****\n" << sendTimer << std::endl; - - if (role == Role::Receiver) - lout << " **** receiver ****\n" << recvTimer << std::endl; - } - -#else - throw std::runtime_error("This example requires coproto to enable boost support. Please build libOTe with `-DCOPROTO_ENABLE_BOOST=ON`. " LOCATION); -#endif - } + bool TwoChooseOne_Examples(const CLP& cmd); } diff --git a/frontend/ExampleVole.cpp b/frontend/ExampleVole.cpp new file mode 100644 index 00000000..2ef71894 --- /dev/null +++ b/frontend/ExampleVole.cpp @@ -0,0 +1,111 @@ + +#include "libOTe/Vole/Silent/SilentVoleReceiver.h" +#include "libOTe/Vole/Silent/SilentVoleSender.h" +#include "util.h" +#include "coproto/Socket/AsioSocket.h" + +namespace osuCrypto +{ + + + template + void Vole_example(Role role, int numVole, int numThreads, std::string ip, std::string tag, const CLP& cmd) + { +#if defined(ENABLE_SILENT_VOLE) && defined(COPROTO_ENABLE_BOOST) + + if (numVole == 0) + numVole = 1 << 20; + + // get up the networking + auto chl = cp::asioConnect(ip, role == Role::Sender); + + // get a random number generator seeded from the system + PRNG prng(sysRandomSeed()); + + auto mulType = (MultType)cmd.getOr("multType", (int)DefaultMultType); + + u64 milli; + Timer timer; + + gTimer.setTimePoint("begin"); + if (role == Role::Receiver) + { + // construct a vector to stored the received messages. + // A = B + C * delta + AlignedUnVector C(numVole); + AlignedUnVector A(numVole); + gTimer.setTimePoint("recver.msg.alloc"); + + SilentVoleReceiver receiver; + receiver.mMultType = mulType; + receiver.configure(numVole); + gTimer.setTimePoint("recver.config"); + + // block until both parties are ready (optional). + cp::sync_wait(sync(chl, role)); + auto b = timer.setTimePoint("start"); + receiver.setTimePoint("start"); + gTimer.setTimePoint("recver.genBase"); + + // perform numVole random OTs, the results will be written to msgs. + cp::sync_wait(receiver.silentReceive(C, A, prng, chl)); + + // record the time. + receiver.setTimePoint("finish"); + auto e = timer.setTimePoint("finish"); + milli = std::chrono::duration_cast(e - b).count(); + } + else + { + gTimer.setTimePoint("sender.thrd.begin"); + + // A = B + C * delta + AlignedUnVector B(numVole); + block delta = prng.get(); + + gTimer.setTimePoint("sender.msg.alloc"); + + SilentVoleSender sender; + sender.mMultType = mulType; + sender.configure(numVole); + gTimer.setTimePoint("sender.config"); + timer.setTimePoint("start"); + + // block until both parties are ready (optional). + cp::sync_wait(sync(chl, role)); + auto b = sender.setTimePoint("start"); + gTimer.setTimePoint("sender.genBase"); + + // perform the OTs and write the random OTs to msgs. + cp::sync_wait(sender.silentSend(delta, B, prng, chl)); + + sender.setTimePoint("finish"); + auto e = timer.setTimePoint("finish"); + milli = std::chrono::duration_cast(e - b).count(); + } + if (role == Role::Sender) + { + + lout << tag << + " n:" << Color::Green << std::setw(6) << std::setfill(' ') << numVole << Color::Default << + " || " << Color::Green << + std::setw(6) << std::setfill(' ') << milli << " ms " << + //std::setw(6) << std::setfill(' ') << com << " bytes" << + std::endl << Color::Default; + + if (cmd.getOr("v", 0) > 1) + lout << gTimer << std::endl; + + } + + // make sure all messages are sent. + cp::sync_wait(chl.flush()); +#endif + } + bool Vole_Examples(const CLP& cmd) + { + return + runIf(Vole_example, cmd, vole); + } + +} diff --git a/frontend/ExampleVole.h b/frontend/ExampleVole.h index c8891879..da6e58a5 100644 --- a/frontend/ExampleVole.h +++ b/frontend/ExampleVole.h @@ -1,142 +1,10 @@ #pragma once - -#include "libOTe/Vole/Silent/SilentVoleReceiver.h" -#include "libOTe/Vole/Silent/SilentVoleSender.h" +#include "cryptoTools/Common/CLP.h" namespace osuCrypto { - - //template - void Vole_example(Role role, int numOTs, int numThreads, std::string ip, std::string tag, CLP& cmd) - { -#if defined(ENABLE_SILENT_VOLE) && defined(COPROTO_ENABLE_BOOST) - - if (numOTs == 0) - numOTs = 1 << 20; - using OtExtSender = SilentVoleSender; - using OtExtRecver = SilentVoleReceiver; - - // get up the networking - auto chl = cp::asioConnect(ip, role == Role::Sender); - - // get a random number generator seeded from the system - PRNG prng(sysRandomSeed()); - - auto mulType = (MultType)cmd.getOr("multType", (int)DefaultMultType); - bool fakeBase = cmd.isSet("fakeBase"); - - u64 milli; - Timer timer; - - gTimer.setTimePoint("begin"); - if (role == Role::Receiver) - { - // construct a vector to stored the received messages. - std::unique_ptr backing0(new block[numOTs]); - std::unique_ptr backing1(new block[numOTs]); - span choice(backing0.get(), numOTs); - span msgs(backing1.get(), numOTs); - gTimer.setTimePoint("recver.msg.alloc"); - - OtExtRecver receiver; - receiver.mMultType = mulType; - receiver.configure(numOTs); - gTimer.setTimePoint("recver.config"); - - // generate base OTs - if (fakeBase) - { - auto nn = receiver.baseOtCount(); - std::vector> baseSendMsgs(nn); - PRNG pp(oc::ZeroBlock); - pp.get(baseSendMsgs.data(), baseSendMsgs.size()); - receiver.setBaseOts(baseSendMsgs); - } - else - { - cp::sync_wait(receiver.genSilentBaseOts(prng, chl)); - } - - // block until both parties are ready (optional). - cp::sync_wait(sync(chl, role)); - auto b = timer.setTimePoint("start"); - receiver.setTimePoint("start"); - gTimer.setTimePoint("recver.genBase"); - - // perform numOTs random OTs, the results will be written to msgs. - cp::sync_wait(receiver.silentReceive(choice, msgs, prng, chl)); - - // record the time. - receiver.setTimePoint("finish"); - auto e = timer.setTimePoint("finish"); - milli = std::chrono::duration_cast(e - b).count(); - } - else - { - gTimer.setTimePoint("sender.thrd.begin"); - - - std::unique_ptr backing(new block[numOTs]); - span msgs(backing.get(), numOTs); - - gTimer.setTimePoint("sender.msg.alloc"); - - OtExtSender sender; - sender.mMultType = mulType; - sender.configure(numOTs); - gTimer.setTimePoint("sender.config"); - timer.setTimePoint("start"); - - // generate base OTs - if (fakeBase) - { - auto nn = sender.baseOtCount(); - BitVector bits(nn); bits.randomize(prng); - std::vector> baseSendMsgs(nn); - std::vector baseRecvMsgs(nn); - PRNG pp(oc::ZeroBlock); - pp.get(baseSendMsgs.data(), baseSendMsgs.size()); - for (u64 i = 0; i < nn; ++i) - baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; - sender.setBaseOts(baseRecvMsgs, bits); - } - else - { - cp::sync_wait(sender.genSilentBaseOts(prng, chl)); - } - - // block until both parties are ready (optional). - cp::sync_wait(sync(chl, role)); - auto b = sender.setTimePoint("start"); - gTimer.setTimePoint("sender.genBase"); - - // perform the OTs and write the random OTs to msgs. - block delta = prng.get(); - cp::sync_wait(sender.silentSend(delta, msgs, prng, chl)); - - sender.setTimePoint("finish"); - auto e = timer.setTimePoint("finish"); - milli = std::chrono::duration_cast(e - b).count(); - } - if (role == Role::Sender) - { - - lout << tag << - " n:" << Color::Green << std::setw(6) << std::setfill(' ') << numOTs << Color::Default << - " || " << Color::Green << - std::setw(6) << std::setfill(' ') << milli << " ms " << - //std::setw(6) << std::setfill(' ') << com << " bytes" << - std::endl << Color::Default; - - if (cmd.getOr("v", 0) > 1) - lout << gTimer << std::endl; - - } - - cp::sync_wait(chl.flush()); -#endif - } + bool Vole_Examples(const CLP& cmd); } diff --git a/frontend/benchmark.h b/frontend/benchmark.h index 1f70a490..25043f5e 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -4,13 +4,15 @@ #include #include "libOTe/Tools/Tools.h" -#include "libOTe/Tools/LDPC/LdpcEncoder.h" #include "libOTe/Tools/EACode/EACode.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" #include "libOTe/Tools/QuasiCyclicCode.h" #include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" #include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" +#include "libOTe/Vole/Silent/SilentVoleSender.h" +#include "libOTe/Vole/Silent/SilentVoleReceiver.h" +#include "libOTe/Tools/CoeffCtx.h" namespace osuCrypto { @@ -98,7 +100,7 @@ namespace osuCrypto timer.setTimePoint("_____________________"); for (u64 i = 0; i < trials; ++i) { - code.dualEncode(x, y); + code.dualEncode(x, y, {}); timer.setTimePoint("encode"); } @@ -130,6 +132,8 @@ namespace osuCrypto // size for the accumulator (# random transitions) u64 a = cmd.getOr("a", roundUpTo(log2ceil(n), 8)); + bool gf128 = cmd.isSet("gf128"); + // verbose flag. bool v = cmd.isSet("v"); bool sys = cmd.isSet("sys"); @@ -153,10 +157,10 @@ namespace osuCrypto timer.setTimePoint("_____________________"); for (u64 i = 0; i < trials; ++i) { - if (sys) - code.dualEncode(x); + if(gf128) + code.dualEncode(x.begin(), {}); else - code.dualEncode(x, y); + code.dualEncode(x.begin(), {}); timer.setTimePoint("encode"); } @@ -168,74 +172,12 @@ namespace osuCrypto std::cout << verbose << std::endl; } - inline void encodeBench(CLP& cmd) - { -#ifdef ENABLE_INSECURE_SILVER - u64 trials = cmd.getOr("t", 10); - - // the message length of the code. - // The noise vector will have size n=2*m. - // the user can use - // -m X - // to state that exactly X rows should be used or - // -mm X - // to state that 2^X rows should be used. - u64 m = cmd.getOr("m", 1ull << cmd.getOr("mm", 10)); - - // the weight of the code, must be 5 or 11. - u64 w = cmd.getOr("w", 5); - - // verbose flag. - bool v = cmd.isSet("v"); - - - SilverCode code; - if (w == 11) - code = SilverCode::Weight11; - else if (w == 5) - code = SilverCode::Weight5; - else - { - std::cout << "invalid weight" << std::endl; - throw RTE_LOC; - } - - PRNG prng(ZeroBlock); - SilverEncoder encoder; - encoder.init(m, code); - - - std::vector x(encoder.cols()); - Timer timer, verbose; - - if (v) - encoder.setTimer(verbose); - - timer.setTimePoint("_____________________"); - for (u64 i = 0; i < trials; ++i) - { - encoder.dualEncode(x); - timer.setTimePoint("encode"); - } - - std::cout << timer << std::endl; - - if (v) - std::cout << verbose << std::endl; -#else - std::cout << "disabled, ENABLE_INSECURE_SILVER not defined " << std::endl; -#endif - } - - - inline void transpose(const CLP& cmd) { #ifdef ENABLE_AVX u64 trials = cmd.getOr("trials", 1ull << 18); { - AlignedArray data; Timer timer; @@ -373,6 +315,110 @@ namespace osuCrypto { std::cout << e.what() << std::endl; } +#else + std::cout << "ENABLE_SILENTOT = false" << std::endl; +#endif + } + + inline void VoleBench2(const CLP& cmd) + { +#ifdef ENABLE_SILENT_VOLE + try + { + + SilentVoleSender sender; + SilentVoleReceiver recver; + + u64 trials = cmd.getOr("t", 10); + + u64 n = cmd.getOr("n", 1ull << cmd.getOr("nn", 20)); + MultType multType = (MultType)cmd.getOr("m", (int)MultType::ExConv7x24); + std::cout << multType << std::endl; + + recver.mMultType = multType; + sender.mMultType = multType; + + std::vector> baseSend(128); + std::vector baseRecv(128); + BitVector baseChoice(128); + PRNG prng(CCBlock); + baseChoice.randomize(prng); + for (u64 i = 0; i < 128; ++i) + { + baseSend[i] = prng.get(); + baseRecv[i] = baseSend[i][baseChoice[i]]; + } + +#ifdef ENABLE_SOFTSPOKEN_OT + sender.mOtExtRecver.emplace(); + sender.mOtExtSender.emplace(); + recver.mOtExtRecver.emplace(); + recver.mOtExtSender.emplace(); + sender.mOtExtRecver->setBaseOts(baseSend); + recver.mOtExtRecver->setBaseOts(baseSend); + sender.mOtExtSender->setBaseOts(baseRecv, baseChoice); + recver.mOtExtSender->setBaseOts(baseRecv, baseChoice); +#endif // ENABLE_SOFTSPOKEN_OT + + PRNG prng0(ZeroBlock), prng1(ZeroBlock); + block delta = prng0.get(); + + auto sock = coproto::LocalAsyncSocket::makePair(); + + Timer sTimer; + Timer rTimer; + sTimer.setTimePoint("start"); + rTimer.setTimePoint("start"); + + auto t0 = std::thread([&] { + for (u64 t = 0; t < trials; ++t) + { + auto p0 = sender.silentSendInplace(delta, n, prng0, sock[0]); + + char c; + + coproto::sync_wait(sock[0].send(std::move(c))); + coproto::sync_wait(sock[0].recv(c)); + sTimer.setTimePoint("__"); + coproto::sync_wait(sock[0].send(std::move(c))); + coproto::sync_wait(sock[0].recv(c)); + sTimer.setTimePoint("s start"); + coproto::sync_wait(p0); + sTimer.setTimePoint("s done"); + } + }); + + + for (u64 t = 0; t < trials; ++t) + { + auto p1 = recver.silentReceiveInplace(n, prng1, sock[1]); + char c; + coproto::sync_wait(sock[1].send(std::move(c))); + coproto::sync_wait(sock[1].recv(c)); + + rTimer.setTimePoint("__"); + coproto::sync_wait(sock[1].send(std::move(c))); + coproto::sync_wait(sock[1].recv(c)); + + rTimer.setTimePoint("r start"); + coproto::sync_wait(p1); + rTimer.setTimePoint("r done"); + + } + + + t0.join(); + std::cout << sTimer << std::endl; + std::cout << rTimer << std::endl; + + std::cout << sock[0].bytesReceived() / trials << " " << sock[1].bytesReceived() / trials << " bytes per " << std::endl; + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + } +#else + std::cout << "ENABLE_Silent_VOLE = false" << std::endl; #endif } } \ No newline at end of file diff --git a/frontend/main.cpp b/frontend/main.cpp index cfed28d5..f98584f1 100644 --- a/frontend/main.cpp +++ b/frontend/main.cpp @@ -11,9 +11,6 @@ using namespace osuCrypto; #include #include -#include -#include -#include #include #include @@ -28,35 +25,16 @@ using namespace osuCrypto; #include "ExampleSilent.h" #include "ExampleVole.h" #include "ExampleMessagePassing.h" -#include "libOTe/Tools/LDPC/LdpcImpulseDist.h" #include "libOTe/Tools/LDPC/Util.h" #include "cryptoTools/Crypto/RandomOracle.h" #include "libOTe/Tools/EACode/EAChecker.h" -static const std::vector -unitTestTag{ "u", "unitTest" }, -kos{ "k", "kos" }, -dkos{ "d", "dkos" }, -ssdelta{ "ssd", "ssdelta" }, -sshonest{ "ss", "sshonest" }, -smleakydelta{ "smld", "smleakydelta" }, -smalicious{ "sm", "smalicious" }, -kkrt{ "kk", "kkrt" }, -iknp{ "i", "iknp" }, -diknp{ "diknp" }, -oos{ "o", "oos" }, -moellerpopf{ "p", "moellerpopf" }, -ristrettopopf{ "r", "ristrettopopf" }, -mr{ "mr" }, -mrb{ "mrb" }, -Silent{ "s", "Silent" }, -vole{ "vole" }, -akn{ "a", "akn" }, -np{ "np" }, -simple{ "simplest" }, -simpleasm{ "simplest-asm" }; +#include "libOTe/TwoChooseOne/Iknp/IknpOtExtSender.h" +#include "libOTe/TwoChooseOne/Iknp/IknpOtExtReceiver.h" #ifdef ENABLE_IKNP +using namespace oc; + void minimal() { // Setup networking. See cryptoTools\frontend_cryptoTools\Tutorials\Network.cpp @@ -107,7 +85,6 @@ void minimal() #include "cryptoTools/Crypto/RandomOracle.h" int main(int argc, char** argv) { - CLP cmd; cmd.parse(argc, argv); bool flagSet = false; @@ -115,12 +92,12 @@ int main(int argc, char** argv) // various benchmarks if (cmd.isSet("bench")) { - if (cmd.isSet("silver")) - encodeBench(cmd); - else if (cmd.isSet("QC")) + if (cmd.isSet("QC")) QCCodeBench(cmd); else if (cmd.isSet("silent")) SilentOtBench(cmd); + else if (cmd.isSet("vole2")) + VoleBench2(cmd); else if (cmd.isSet("ea")) EACodeBench(cmd); else @@ -128,21 +105,6 @@ int main(int argc, char** argv) return 0; } - - // minimum distance checker for EA codes. - if (cmd.isSet("ea")) - { - EAChecker(cmd); - return 0; - } -#ifdef ENABLE_LDPC - if (cmd.isSet("ldpc")) - { - LdpcDecode_impulse(cmd); - return 0; - } -#endif - // unit tests. if (cmd.isSet(unitTestTag)) { @@ -165,96 +127,11 @@ int main(int argc, char** argv) // run various examples. - - -#ifdef ENABLE_SIMPLESTOT - flagSet |= runIf(baseOT_example, cmd, simple); -#endif - -#ifdef ENABLE_SIMPLESTOT_ASM - flagSet |= runIf(baseOT_example, cmd, simpleasm); -#endif - -#ifdef ENABLE_MRR_TWIST -#ifdef ENABLE_SSE - flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) { - DomainSepEKEPopf factory; - const char* domain = "EKE POPF OT example"; - factory.Update(domain, std::strlen(domain)); - baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwist(factory)); - }, cmd, moellerpopf, { "eke" }); -#endif - - flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) { - DomainSepMRPopf factory; - const char* domain = "MR POPF OT example"; - factory.Update(domain, std::strlen(domain)); - baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwistMR(factory)); - }, cmd, moellerpopf, { "mrPopf" }); - - flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) { - DomainSepFeistelPopf factory; - const char* domain = "Feistel POPF OT example"; - factory.Update(domain, std::strlen(domain)); - baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwistFeistel(factory)); - }, cmd, moellerpopf, { "feistel" }); - - flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) { - DomainSepFeistelMulPopf factory; - const char* domain = "Feistel With Multiplication POPF OT example"; - factory.Update(domain, std::strlen(domain)); - baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyTwistMul(factory)); -}, cmd, moellerpopf, { "feistelMul" }); -#endif - -#ifdef ENABLE_MRR - flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) { - DomainSepFeistelRistPopf factory; - const char* domain = "Feistel POPF OT example (Risretto)"; - factory.Update(domain, std::strlen(domain)); - baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoy(factory)); - }, cmd, ristrettopopf, { "feistel" }); - - flagSet |= runIf([&](Role role, int totalOTs, int numThreads, std::string ip, std::string tag, CLP& clp) { - DomainSepFeistelMulRistPopf factory; - const char* domain = "Feistel With Multiplication POPF OT example (Risretto)"; - factory.Update(domain, std::strlen(domain)); - baseOT_example_from_ot(role, totalOTs, numThreads, ip, tag, clp, McRosRoyMul(factory)); - }, cmd, ristrettopopf, { "feistelMul" }); -#endif - -#ifdef ENABLE_MR - flagSet |= runIf(baseOT_example, cmd, mr); -#endif - -#ifdef ENABLE_IKNP - flagSet |= runIf(TwoChooseOne_example, cmd, iknp); -#endif - -#ifdef ENABLE_KOS - flagSet |= runIf(TwoChooseOne_example, cmd, kos); -#endif - -#ifdef ENABLE_DELTA_KOS - flagSet |= runIf(TwoChooseOne_example, cmd, dkos); -#endif - -#ifdef ENABLE_SOFTSPOKEN_OT - flagSet |= runIf(TwoChooseOne_example, SoftSpokenShOtReceiver<>>, cmd, sshonest); - flagSet |= runIf(TwoChooseOne_example, cmd, smalicious); -#endif - -#ifdef ENABLE_KKRT - flagSet |= runIf(NChooseOne_example, cmd, kkrt); -#endif - -#ifdef ENABLE_OOS - flagSet |= runIf(NChooseOne_example, cmd, oos); -#endif - - flagSet |= runIf(Silent_example, cmd, Silent); - flagSet |= runIf(Vole_example, cmd, vole); - + flagSet |= baseOT_examples(cmd); + flagSet |= TwoChooseOne_Examples(cmd); + flagSet |= NChooseOne_Examples(cmd); + flagSet |= Silent_Examples(cmd); + flagSet |= Vole_Examples(cmd); if (cmd.isSet("messagePassing")) { @@ -263,8 +140,6 @@ int main(int argc, char** argv) } - - if (flagSet == false) { diff --git a/frontend/util.cpp b/frontend/util.cpp index ac7a95ce..fa6bc91b 100644 --- a/frontend/util.cpp +++ b/frontend/util.cpp @@ -7,6 +7,7 @@ #include #include +#include "util.h" namespace osuCrypto { @@ -67,6 +68,7 @@ namespace osuCrypto else { u8 dummy[1]; + dummy[0] = 0; chl.asyncSend(dummy, 1); chl.recv(dummy, 1); chl.asyncSend(dummy, 1); @@ -79,6 +81,7 @@ namespace osuCrypto { u8 dummy[1]; + dummy[0] = 0; chl.asyncSend(dummy, 1); diff --git a/frontend/util.h b/frontend/util.h index 599a922c..27fb08fe 100644 --- a/frontend/util.h +++ b/frontend/util.h @@ -10,11 +10,34 @@ #include #include #include "libOTe/Tools/Coproto.h" - +#include namespace osuCrypto { + + static const std::vector + unitTestTag{ "u", "unitTest" }, + kos{ "k", "kos" }, + dkos{ "d", "dkos" }, + ssdelta{ "ssd", "ssdelta" }, + sshonest{ "ss", "sshonest" }, + smleakydelta{ "smld", "smleakydelta" }, + smalicious{ "sm", "smalicious" }, + kkrt{ "kk", "kkrt" }, + iknp{ "i", "iknp" }, + diknp{ "diknp" }, + oos{ "o", "oos" }, + moellerpopf{ "p", "moellerpopf" }, + ristrettopopf{ "r", "ristrettopopf" }, + mr{ "mr" }, + mrb{ "mrb" }, + Silent{ "s", "Silent" }, + vole{ "vole" }, + akn{ "a", "akn" }, + np{ "np" }, + simple{ "simplest" }, + simpleasm{ "simplest-asm" }; enum class Role { Sender, @@ -36,7 +59,7 @@ namespace osuCrypto //using ProtocolFunc = std::function; template - inline bool runIf(ProtocolFunc protocol, CLP & cmd, std::vector tag, + inline bool runIf(ProtocolFunc protocol, const CLP & cmd, std::vector tag, std::vector tag2 = std::vector()) { auto n = cmd.isSet("nn") diff --git a/libOTe/CMakeLists.txt b/libOTe/CMakeLists.txt index 0e690f9f..6e1b8c27 100644 --- a/libOTe/CMakeLists.txt +++ b/libOTe/CMakeLists.txt @@ -4,7 +4,7 @@ file(GLOB_RECURSE SRCS *.cpp *.c) set(SRCS "${SRCS}") -add_library(libOTe STATIC ${SRCS} "Tools/EACode/EACodeInstantiations.cpp") +add_library(libOTe STATIC ${SRCS} ) # make projects that include libOTe use this as an include folder diff --git a/libOTe/Tools/CoeffCtx.h b/libOTe/Tools/CoeffCtx.h new file mode 100644 index 00000000..ddf64eed --- /dev/null +++ b/libOTe/Tools/CoeffCtx.h @@ -0,0 +1,479 @@ +#pragma once +#include "libOTe/Vole/Noisy/NoisyVoleSender.h" +#include "cryptoTools/Common/BitIterator.h" +#include "cryptoTools/Common/BitVector.h" +#include +#include + +namespace osuCrypto { + + /* + * Primitive CoeffCtx for integers-like types + * + * This class implements the required functions to perform a vole + * + * The core functions are plus, minus, mul. However, additional function + * and types are required. + */ + struct CoeffCtxInteger + { + template + OC_FORCEINLINE void plus(R&& ret, F1&& lhs, F2&& rhs) { + ret = lhs + rhs; + } + + template + OC_FORCEINLINE void minus(R&& ret, F1&& lhs, F2&& rhs) { + ret = lhs - rhs; + } + template + OC_FORCEINLINE void mul(R&& ret, F1&& lhs, F2&& rhs) { + ret = lhs * rhs; + } + + template + OC_FORCEINLINE bool eq(F&& lhs, F&& rhs) { + return lhs == rhs; + } + + // is G a field? + template + OC_FORCEINLINE bool isField() { + return false; // default. + } + + // For the base field G is an extension fields, + // mulConst should multiply x by some constant in G to linearly + // mix the components. Most of the LPN codes this library + // uses are binary and so for extension field this would + // result in the componets not interactive. This can lead + // to a splitting attack. To fix this we multiply by some + // non-zero G element. + // + // If your type is a scaler, e.g. Fp or Z2k, just return x. + template + OC_FORCEINLINE void mulConst(F& ret, const F& x) + { + ret = x; + } + + // the bit size require to prepresent F + // the protocol will perform binary decomposition + // of F using this many bits + template + u64 bitSize() + { + return sizeof(F) * 8; + } + + // return the binary decomposition of x. This will be used to + // reconstruct x as + // + // x = sum_{i = 0,...,n} 2^i * binaryDecomposition(x)[i] + // + template + OC_FORCEINLINE BitVector binaryDecomposition(F& x) { + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + return { (u8*)&x, bitSize() }; + } + + // derive an F using the randomness b. + template + OC_FORCEINLINE void fromBlock(F& ret, const block& b) { + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + + if constexpr (sizeof(F) <= sizeof(block)) + { + // if small, just return the bytes of b + memcpy(&ret, &b, sizeof(F)); + } + else + { + // if large, we need to expand the seed. using fixed key AES in counter mode + // with b as the IV. + auto constexpr size = (sizeof(F) + sizeof(block) - 1) / sizeof(block); + std::array buffer; + mAesFixedKey.ecbEncCounterMode(b, buffer); + memcpy(&ret, buffer.data(), sizeof(ret)); + } + } + + // return the F element with value 2^power + template + OC_FORCEINLINE void powerOfTwo(F& ret, u64 power) { + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + memset(&ret, 0, sizeof(F)); + *BitIterator((u8*)&ret, power) = 1; + } + + // A vector like type that can be used to store + // temporaries. + // + // must have: + // * size() + // * operator[i] that returns the i'th F element reference. + // * begin() iterator over the F elements + // * end() iterator over the F elements + template + using Vec = AlignedUnVector; + + // resize Vec + template + void resize(VecF& f, u64 size) + { + f.resize(size); + } + + // the size of F when serialized. + template + u64 byteSize() + { + return sizeof(F); + } + + // copy a single F element. + template + OC_FORCEINLINE void copy(F& dst, const F& src) + { + dst = src; + } + + // copy [begin,...,end) into [dstBegin, ...) + // the iterators will point to the same types, i.e. F + template + OC_FORCEINLINE void copy( + SrcIter begin, + SrcIter end, + DstIter dstBegin) + { + using F1 = std::remove_reference_t; + using F2 = std::remove_reference_t; + static_assert(std::is_trivially_copyable::value, "memcpy is used so must be trivially_copyable."); + static_assert(std::is_same_v, "src and destication types are not the same."); + + std::copy(begin, end, dstBegin); + } + + // deserialize [begin,...,end) into [dstBegin, ...) + // begin will be a u8 pointer/iterator. + // dstBegin will be an F pointer/iterator + template + void deserialize(SrcIter&& begin, SrcIter&& end, DstIter&& dstBegin) + { + // as written this function is a bit more general than strictly neccessary + // due to serialize(...) redirecting here. + using SrcType = std::remove_reference_t; + using DstType = std::remove_reference_t; + static_assert(std::is_trivially_copyable::value, "source serialization types must be trivially_copyable."); + static_assert(std::is_trivially_copyable::value, "destination serialization types must be trivially_copyable."); + + // how many source elem do we have? + auto srcN = std::distance(begin, end); + if (srcN) + { + // the source size in bytes + auto n = srcN * sizeof(SrcType); + + // The byte size must be a multiple of the destination element byte size. + if (n % sizeof(DstType)) + { + std::cout << "bad buffer size. the source buffer (byte) size is not a multiple of the distination value type size." LOCATION << std::endl; + std::terminate(); + } + // the number of destination elements. + auto dstN = n / sizeof(DstType); + + // make sure the pointer math work with this iterator type. + auto beginU8 = (u8*)&*begin; + auto dstBeginU8 = (u8*)&*dstBegin; + + auto dstBackPtr = dstBeginU8 + (n - sizeof(DstType)); + auto dstBackIter = dstBegin + (dstN - 1); + + // try to deref the back. might bounds check. + // And check that the pointer math works + if (dstBackPtr != (u8*)&*dstBackIter) + { + std::cout << "bad destination iterator type. pointer arithemtic not correct. " LOCATION << std::endl; + std::terminate(); + } + + auto srcBackPtr = beginU8 + (n - sizeof(SrcType)); + auto srcBackIter = begin + (srcN - 1); + + // try to deref the back. might bounds check. + // And check that the pointer math works + if (srcBackPtr != (u8*)&*srcBackIter) + { + std::cout << "bad source iterator type. pointer arithemtic not correct. " LOCATION << std::endl; + std::terminate(); + } + + // memcpy the bytes + std::memcpy(dstBeginU8, beginU8, n); + } + } + + // serialize [begin,...,end) into [dstBegin, ...) + // begin will be an F pointer/iterator + // dstBegin will be a byte pointer/iterator. + template + void serialize(SrcIter&& begin, SrcIter&& end, DstIter&& dstBegin) + { + // for primitive types serialization and deserializaion + // are the same, a memcpy. + deserialize(begin, end, dstBegin); + } + + + // If unsure you can just return iter. iter will be + // a Vec::iterator or const of it. + // This function allows for some compiler optimziations/ + // The idea is to return a pointer with the __restrict + // attibute. If this does not make sense for your Vec::iterator, + // just return the iterator. + template + F* __restrict restrictPtr(Iter iter) + { + return &*iter; + } + + + // fill the range [begin,..., end) with zeros. + // begin will be an F pointer/iterator. + template + void zero(Iter begin, Iter end) + { + using F = std::remove_reference_t; + static_assert(std::is_trivially_copyable::value, "memset is used so must be trivially_copyable."); + + if (begin != end) + { + auto n = std::distance(begin, end); + assert(n > 0); + memset(&*begin, 0, n * sizeof(F)); + } + } + + // fill the range [begin,..., end) with ones. + // begin will be an F pointer/iterator. + template + void one(Iter begin, Iter end) + { + using F = std::remove_reference_t; + static_assert(std::is_trivially_copyable::value, "memset is used so must be trivially_copyable."); + + + if (begin != end) + { + auto n = std::distance(begin, end); + assert(n > 0); + memset(&*begin, 0, n * sizeof(F)); + while (begin != end) + { + auto& v = *begin++; + *(u8*)&v = 1; + } + } + } + + // convert F into a string + template + std::string str(F&& f) + { + std::stringstream ss; + if constexpr (std::is_same_v, u8>) + ss << int(f); + else + ss << f; + + return ss.str(); + } + + }; + + + + // block does not use operator* + struct CoeffCtxGF2 : CoeffCtxInteger + { + template + OC_FORCEINLINE void plus(F& ret, const F& lhs, const F& rhs) { + ret = lhs ^ rhs; + } + template + OC_FORCEINLINE void minus(F& ret, const F& lhs, const F& rhs) { + ret = lhs ^ rhs; + } + template + OC_FORCEINLINE void mul(F& ret, const F& lhs, const F& rhs) { + ret = lhs & rhs; + } + + template + OC_FORCEINLINE void mul(F& ret, const F& lhs, const bool& rhs) { + ret = rhs ? lhs : zeroElem(); + } + + // the bit size require to prepresent F + // the protocol will perform binary decomposition + // of F using this many bits + template + u64 bitSize() + { + if (std::is_same::value) + return 1; + else + return sizeof(F) * 8; + } + + // is F a field? + template + OC_FORCEINLINE bool isField() { + return true; // default. + } + + template + static OC_FORCEINLINE constexpr F zeroElem() + { + static_assert(std::is_trivially_copyable::value, "memset is used so must be trivially_copyable."); + F r; + memset(&r, 0, sizeof(F)); + return r; + } + }; + + + // block does not use operator* + struct CoeffCtxGF128 : CoeffCtxGF2 + { + + OC_FORCEINLINE void mul(block& ret, const block& lhs, const block& rhs) { + ret = lhs.gf128Mul(rhs); + } + + // ret = x * 4234123421 mod 2^127 - 135 + OC_FORCEINLINE void mulConst(block& ret, const block& x) + { + // multiplication y modulo mod + block y(0, 4234123421); + +#ifdef ENABLE_SSE + static const constexpr std::uint64_t mod = 0b10000111; + const __m128i modulus = _mm_loadl_epi64((const __m128i*) & (mod)); + + block xy1 = _mm_clmulepi64_si128(x, y, (int)0x00); + block xy2 = _mm_clmulepi64_si128(x, y, 0x01); + xy1 = xy1 ^ _mm_slli_si128(xy2, 8); + xy2 = _mm_srli_si128(xy2, 8); + + /* reduce w.r.t. high half of mul256_high */ + auto tmp = _mm_clmulepi64_si128(xy2, modulus, 0x00); + ret = _mm_xor_si128(xy1, tmp); +#else + ret = x.gf128Mul(y); +#endif + } + }; + + + template + struct CoeffCtxArray : CoeffCtxInteger + { + using F = std::array; + + OC_FORCEINLINE void plus(F& ret, const F& lhs, const F& rhs) { + for (u64 i = 0; i < lhs.size(); ++i) { + ret[i] = lhs[i] + rhs[i]; + } + } + + OC_FORCEINLINE void plus(G& ret, const G& lhs, const G& rhs) { + ret = lhs + rhs; + } + + OC_FORCEINLINE void minus(F& ret, const F& lhs, const F& rhs) + { + for (u64 i = 0; i < lhs.size(); ++i) { + ret[i] = lhs[i] - rhs[i]; + } + } + + OC_FORCEINLINE void minus(G& ret, const G& lhs, const G& rhs) { + ret = lhs - rhs; + } + + OC_FORCEINLINE void mul(F& ret, const F& lhs, const G& rhs) + { + for (u64 i = 0; i < lhs.size(); ++i) { + ret[i] = lhs[i] * rhs; + } + } + + OC_FORCEINLINE bool eq(const F& lhs, const F& rhs) + { + for (u64 i = 0; i < lhs.size(); ++i) { + if (lhs[i] != rhs[i]) + return false; + } + return true; + } + + OC_FORCEINLINE bool eq(const G& lhs, const G& rhs) + { + return lhs == rhs; + } + + // convert F into a string + std::string str(const F& f) + { + auto delim = "{ "; + std::stringstream ss; + for (u64 i = 0; i < f.size(); ++i) + { + ss << std::exchange(delim, ", "); + + if constexpr (std::is_same_v, u8>) + ss << int(f[i]); + else + ss << f[i]; + + } + ss << " }"; + + return ss.str(); + } + + // convert G into a string + std::string str(const G& g) + { + std::stringstream ss; + if constexpr (std::is_same_v, u8>) + ss << int(g); + else + ss << g; + + return ss.str(); + } + + }; + + template + struct DefaultCoeffCtx_t { + using type = CoeffCtxInteger; + }; + + // GF128 vole + template<> + struct DefaultCoeffCtx_t { + using type = CoeffCtxGF128; + }; + + // OT, gf2 + template<> struct DefaultCoeffCtx_t { + using type = CoeffCtxGF2; + }; + + template + using DefaultCoeffCtx = typename DefaultCoeffCtx_t::type; +} diff --git a/libOTe/Tools/EACode/EACode.cpp b/libOTe/Tools/EACode/EACode.cpp deleted file mode 100644 index 3ae76cd9..00000000 --- a/libOTe/Tools/EACode/EACode.cpp +++ /dev/null @@ -1,142 +0,0 @@ - -#include "EACode.h" - -namespace osuCrypto -{ - - // Compute w = G * e. - template - void EACode::dualEncode(span e, span w) - { - if (mCodeSize == 0) - throw RTE_LOC; - if (e.size() != mCodeSize) - throw RTE_LOC; - if (w.size() != mMessageSize) - throw RTE_LOC; - - - setTimePoint("EACode.encode.begin"); - - accumulate(e); - - setTimePoint("EACode.encode.accumulate"); - - expand(e, w); - setTimePoint("EACode.encode.expand"); - } - - - - - template - void EACode::dualEncode2( - span e0, - span w0, - span e1, - span w1 - ) - { - if (mCodeSize == 0) - throw RTE_LOC; - if (e0.size() != mCodeSize) - throw RTE_LOC; - if (e1.size() != mCodeSize) - throw RTE_LOC; - if (w0.size() != mMessageSize) - throw RTE_LOC; - if (w1.size() != mMessageSize) - throw RTE_LOC; - - setTimePoint("EACode.encode.begin"); - - accumulate(e0, e1); - //accumulate(e1); - - setTimePoint("EACode.encode.accumulate"); - - expand(e0, w0); - expand(e1, w1); - setTimePoint("EACode.encode.expand"); - } - - - - template - void EACode::accumulate(span x) - { - if (x.size() != mCodeSize) - throw RTE_LOC; - auto main = (u64)std::max(0, mCodeSize - 1); - T* __restrict xx = x.data(); - - for (u64 i = 0; i < main; ++i) - { - auto xj = xx[i + 1] ^ xx[i]; - xx[i + 1] = xj; - } - } - - template - void EACode::accumulate(span x, span x2) - { - if (x.size() != mCodeSize) - throw RTE_LOC; - if (x2.size() != mCodeSize) - throw RTE_LOC; - - - auto main = (u64)std::max(0, mCodeSize - 1); - T* __restrict xx1 = x.data(); - T2* __restrict xx2 = x2.data(); - - for (u64 i = 0; i < main; ++i) - { - auto x1j = xx1[i + 1] ^ xx1[i]; - auto x2j = xx2[i + 1] ^ xx2[i]; - xx1[i + 1] = x1j; - xx2[i + 1] = x2j; - } - } - -#ifndef EACODE_INSTANTIONATIONS - - SparseMtx EACode::getAPar() const - { - PointList AP(mCodeSize, mCodeSize);; - for (u64 i = 0; i < mCodeSize; ++i) - { - AP.push_back(i, i); - if (i + 1 < mCodeSize) - AP.push_back(i + 1, i); - } - return AP; - } - - SparseMtx EACode::getA() const - { - auto APar = getAPar(); - - auto A = DenseMtx::Identity(mCodeSize); - - for (u64 i = 0; i < mCodeSize; ++i) - { - for (auto y : APar.col(i)) - { - //std::cout << y << " "; - if (y != i) - { - auto ay = A.row(y); - auto ai = A.row(i); - ay ^= ai; - - } - } - - //std::cout << "\n" << A << std::endl; - } - - return A.sparse(); - } -#endif -} \ No newline at end of file diff --git a/libOTe/Tools/EACode/EACode.h b/libOTe/Tools/EACode/EACode.h index bbe76ac4..e99ddf67 100644 --- a/libOTe/Tools/EACode/EACode.h +++ b/libOTe/Tools/EACode/EACode.h @@ -8,16 +8,11 @@ #include "cryptoTools/Common/Defines.h" #include "cryptoTools/Common/Timer.h" -#include "Expander.h" +#include "libOTe/Tools/ExConvCode/Expander.h" #include "Util.h" - namespace osuCrypto { -#if __cplusplus >= 201703L -#define EA_CONSTEXPR constexpr -#else -#define EA_CONSTEXPR -#endif + // The encoder for the generator matrix G = B * A. // B is the expander while A is the accumulator. // @@ -30,27 +25,25 @@ namespace osuCrypto { public: using ExpanderCode::config; - using ExpanderCode::getB; - // Compute w = G * e. - template - void dualEncode(span e, span w); + template + void dualEncode(span e, span w, Ctx ctx); - template + template void dualEncode2( span e0, span w0, span e1, - span w1 + span w1, Ctx ctx ); - template - void accumulate(span x); - template - void accumulate(span x, span x2); + template + void accumulate(span x, Ctx ctx); + template + void accumulate(span x, span x2, Ctx ctx); // Get the parity check version of the accumulator SparseMtx getAPar() const; @@ -58,5 +51,137 @@ namespace osuCrypto SparseMtx getA() const; }; -#undef EA_CONSTEXPR + // Compute w = G * e. + template + void EACode::dualEncode(span e, span w, Ctx ctx) + { + if (mCodeSize == 0) + throw RTE_LOC; + if (e.size() != mCodeSize) + throw RTE_LOC; + if (w.size() != mMessageSize) + throw RTE_LOC; + + + setTimePoint("EACode.encode.begin"); + + accumulate(e, ctx); + + setTimePoint("EACode.encode.accumulate"); + + expand(e.begin(), w.begin(), ctx); + setTimePoint("EACode.encode.expand"); + } + + + + + template + void EACode::dualEncode2( + span e0, + span w0, + span e1, + span w1, Ctx ctx + ) + { + if (mCodeSize == 0) + throw RTE_LOC; + if (e0.size() != mCodeSize) + throw RTE_LOC; + if (e1.size() != mCodeSize) + throw RTE_LOC; + if (w0.size() != mMessageSize) + throw RTE_LOC; + if (w1.size() != mMessageSize) + throw RTE_LOC; + + setTimePoint("EACode.encode.begin"); + + accumulate(e0, e1, ctx); + //accumulate(e1); + + setTimePoint("EACode.encode.accumulate"); + + expand(e0.begin(), w0.begin(), ctx); + expand(e1.begin(), w1.begin(), ctx); + setTimePoint("EACode.encode.expand"); + } + + + + template + void EACode::accumulate(span x, Ctx ctx) + { + if (x.size() != mCodeSize) + throw RTE_LOC; + auto main = (u64)std::max(0, mCodeSize - 1); + T* __restrict xx = x.data(); + + for (u64 i = 0; i < main; ++i) + { + ctx.plus(xx[i + 1], xx[i + 1], xx[i]); + ctx.mulConst(xx[i + 1], xx[i + 1]); + } + } + + template + void EACode::accumulate(span x, span x2, Ctx ctx) + { + if (x.size() != mCodeSize) + throw RTE_LOC; + if (x2.size() != mCodeSize) + throw RTE_LOC; + + auto main = (u64)std::max(0, mCodeSize - 1); + T* __restrict xx1 = x.data(); + T2* __restrict xx2 = x2.data(); + + for (u64 i = 0; i < main; ++i) + { + ctx.plus(xx1[i + 1], xx1[i + 1], xx1[i]); + ctx.plus(xx2[i + 1], xx2[i + 1], xx2[i]); + ctx.mulConst(xx1[i + 1], xx1[i + 1]); + ctx.mulConst(xx2[i + 1], xx2[i + 1]); + + } + } + + + inline SparseMtx EACode::getAPar() const + { + PointList AP(mCodeSize, mCodeSize);; + for (u64 i = 0; i < mCodeSize; ++i) + { + AP.push_back(i, i); + if (i + 1 < mCodeSize) + AP.push_back(i + 1, i); + } + return AP; + } + + inline SparseMtx EACode::getA() const + { + auto APar = getAPar(); + + auto A = DenseMtx::Identity(mCodeSize); + + for (u64 i = 0; i < mCodeSize; ++i) + { + for (auto y : APar.col(i)) + { + //std::cout << y << " "; + if (y != i) + { + auto ay = A.row(y); + auto ai = A.row(i); + ay ^= ai; + + } + } + + //std::cout << "\n" << A << std::endl; + } + + return A.sparse(); + } } diff --git a/libOTe/Tools/EACode/EACodeInstantiations.cpp b/libOTe/Tools/EACode/EACodeInstantiations.cpp deleted file mode 100644 index 502398c3..00000000 --- a/libOTe/Tools/EACode/EACodeInstantiations.cpp +++ /dev/null @@ -1,15 +0,0 @@ - -#define EACODE_INSTANTIONATIONS -#include "EACode.cpp" - -namespace osuCrypto -{ - template void EACode::dualEncode(span e, span w); - template void EACode::dualEncode(span e, span w); - template void EACode::accumulate(span e); - template void EACode::accumulate(span e); - - template void EACode::dualEncode2(span e, span w, span e2, span w2); - template void EACode::dualEncode2(span e, span w, span e2, span w2); - template void EACode::accumulate(spane0, span e); -} \ No newline at end of file diff --git a/libOTe/Tools/EACode/Expander.cpp b/libOTe/Tools/EACode/Expander.cpp deleted file mode 100644 index 98bbb99e..00000000 --- a/libOTe/Tools/EACode/Expander.cpp +++ /dev/null @@ -1,433 +0,0 @@ - -#include "Expander.h" -#include "cryptoTools/Common/Range.h" - -namespace osuCrypto -{ - template - typename std::enable_if::type - ExpanderCode::expandOne(const T* __restrict ee, detail::ExpanderModd& prng) const - { - auto r = prng.get(); - return ee[r]; - } - - template - typename std::enable_if::type - ExpanderCode::expandOne( - const T* __restrict ee1, - const T2* __restrict ee2, - T* __restrict y1, - T2* __restrict y2, - detail::ExpanderModd& prng) const - { - auto r = prng.get(); - - if (Add) - { - *y1 = *y1 ^ ee1[r]; - *y2 = *y2 ^ ee2[r]; - } - else - { - - *y1 = ee1[r]; - *y2 = ee2[r]; - } - } - - - template - OC_FORCEINLINE typename std::enable_if<(count > 1), T>::type - ExpanderCode::expandOne(const T* __restrict ee, detail::ExpanderModd& prng)const - { - if constexpr (count >= 8) - { - u64 rr[8]; - T w[8]; - rr[0] = prng.get(); - rr[1] = prng.get(); - rr[2] = prng.get(); - rr[3] = prng.get(); - rr[4] = prng.get(); - rr[5] = prng.get(); - rr[6] = prng.get(); - rr[7] = prng.get(); - - w[0] = ee[rr[0]]; - w[1] = ee[rr[1]]; - w[2] = ee[rr[2]]; - w[3] = ee[rr[3]]; - w[4] = ee[rr[4]]; - w[5] = ee[rr[5]]; - w[6] = ee[rr[6]]; - w[7] = ee[rr[7]]; - - auto ww = - w[0] ^ - w[1] ^ - w[2] ^ - w[3] ^ - w[4] ^ - w[5] ^ - w[6] ^ - w[7]; - - if constexpr (count > 8) - ww = ww ^ expandOne(ee, prng); - return ww; - } - else - { - - auto r = prng.get(); - auto ww = expandOne(ee, prng); - return ww ^ ee[r]; - } - } - - - template - OC_FORCEINLINE typename std::enable_if<(count > 1)>::type - ExpanderCode::expandOne( - const T* __restrict ee1, - const T2* __restrict ee2, - T* __restrict y1, - T2* __restrict y2, - detail::ExpanderModd& prng)const - { - if constexpr (count >= 8) - { - u64 rr[8]; - T w1[8]; - T2 w2[8]; - rr[0] = prng.get(); - rr[1] = prng.get(); - rr[2] = prng.get(); - rr[3] = prng.get(); - rr[4] = prng.get(); - rr[5] = prng.get(); - rr[6] = prng.get(); - rr[7] = prng.get(); - - w1[0] = ee1[rr[0]]; - w1[1] = ee1[rr[1]]; - w1[2] = ee1[rr[2]]; - w1[3] = ee1[rr[3]]; - w1[4] = ee1[rr[4]]; - w1[5] = ee1[rr[5]]; - w1[6] = ee1[rr[6]]; - w1[7] = ee1[rr[7]]; - - w2[0] = ee2[rr[0]]; - w2[1] = ee2[rr[1]]; - w2[2] = ee2[rr[2]]; - w2[3] = ee2[rr[3]]; - w2[4] = ee2[rr[4]]; - w2[5] = ee2[rr[5]]; - w2[6] = ee2[rr[6]]; - w2[7] = ee2[rr[7]]; - - auto ww1 = - w1[0] ^ - w1[1] ^ - w1[2] ^ - w1[3] ^ - w1[4] ^ - w1[5] ^ - w1[6] ^ - w1[7]; - auto ww2 = - w2[0] ^ - w2[1] ^ - w2[2] ^ - w2[3] ^ - w2[4] ^ - w2[5] ^ - w2[6] ^ - w2[7]; - - if constexpr (count > 8) - { - T yy1; - T2 yy2; - expandOne(ee1, ee2, &yy1, &yy2, prng); - ww1 = ww1 ^ yy1; - ww2 = ww2 ^ yy2; - } - - if constexpr (Add) - { - *y1 = *y1 ^ ww1; - *y2 = *y2 ^ ww2; - } - else - { - *y1 = ww1; - *y2 = ww2; - } - - } - else - { - - auto r = prng.get(); - if constexpr (Add) - { - auto w1 = ee1[r]; - auto w2 = ee2[r]; - expandOne(ee1, ee2, y1, y2, prng); - *y1 = *y1 ^ w1; - *y2 = *y2 ^ w2; - - } - else - { - - T yy1; - T2 yy2; - expandOne(ee1, ee2, &yy1, &yy2, prng); - *y1 = ee1[r] ^ yy1; - *y2 = ee2[r] ^ yy2; - } - } - } - - - - template - void ExpanderCode::expand( - span e, - span w) const - { - assert(w.size() == mMessageSize); - assert(e.size() == mCodeSize); - detail::ExpanderModd prng(mSeed, mCodeSize); - - const T* __restrict ee = e.data(); - T* __restrict ww = w.data(); - - auto main = mMessageSize / 8 * 8; - u64 i = 0; - - for (; i < main; i += 8) - { -#define CASE(I) \ - case I:\ - if constexpr(Add)\ - {\ - ww[i + 0] = ww[i + 0] ^ expandOne(ee, prng);\ - ww[i + 1] = ww[i + 1] ^ expandOne(ee, prng);\ - ww[i + 2] = ww[i + 2] ^ expandOne(ee, prng);\ - ww[i + 3] = ww[i + 3] ^ expandOne(ee, prng);\ - ww[i + 4] = ww[i + 4] ^ expandOne(ee, prng);\ - ww[i + 5] = ww[i + 5] ^ expandOne(ee, prng);\ - ww[i + 6] = ww[i + 6] ^ expandOne(ee, prng);\ - ww[i + 7] = ww[i + 7] ^ expandOne(ee, prng);\ - }\ - else\ - {\ - ww[i + 0] = expandOne(ee, prng);\ - ww[i + 1] = expandOne(ee, prng);\ - ww[i + 2] = expandOne(ee, prng);\ - ww[i + 3] = expandOne(ee, prng);\ - ww[i + 4] = expandOne(ee, prng);\ - ww[i + 5] = expandOne(ee, prng);\ - ww[i + 6] = expandOne(ee, prng);\ - ww[i + 7] = expandOne(ee, prng);\ - }\ - break - - switch (mExpanderWeight) - { - CASE(5); - CASE(7); - CASE(9); - CASE(11); - CASE(21); - CASE(40); - default: - for (u64 jj = 0; jj < 8; ++jj) - { - auto r = prng.get(); - auto wv = ee[r]; - - for (auto j = 1ull; j < mExpanderWeight; ++j) - { - r = prng.get(); - wv = wv ^ ee[r]; - } - if constexpr (Add) - ww[i + jj] = ww[i + jj] ^ wv; - else - ww[i + jj] = wv; - - } - } -#undef CASE - } - - for (; i < mMessageSize; ++i) - { - auto wv = ee[prng.get()]; - for (auto j = 1ull; j < mExpanderWeight; ++j) - wv = wv ^ ee[prng.get()]; - - if constexpr (Add) - ww[i] = ww[i] ^ wv; - else - ww[i] = wv; - } - } - - - - template - void ExpanderCode::expand( - span e1, - span e2, - span w1, - span w2 - ) const - { - assert(w1.size() == mMessageSize); - assert(w2.size() == mMessageSize); - assert(e1.size() == mCodeSize); - assert(e2.size() == mCodeSize); - detail::ExpanderModd prng(mSeed, mCodeSize); - - const T* __restrict ee1 = e1.data(); - const T2* __restrict ee2 = e2.data(); - T* __restrict ww1 = w1.data(); - T2* __restrict ww2 = w2.data(); - - auto main = mMessageSize / 8 * 8; - u64 i = 0; - - for (; i < main; i += 8) - { -#define CASE(I) \ - case I:\ - expandOne(ee1, ee2, &ww1[i + 0], &ww2[i + 0], prng);\ - expandOne(ee1, ee2, &ww1[i + 1], &ww2[i + 1], prng);\ - expandOne(ee1, ee2, &ww1[i + 2], &ww2[i + 2], prng);\ - expandOne(ee1, ee2, &ww1[i + 3], &ww2[i + 3], prng);\ - expandOne(ee1, ee2, &ww1[i + 4], &ww2[i + 4], prng);\ - expandOne(ee1, ee2, &ww1[i + 5], &ww2[i + 5], prng);\ - expandOne(ee1, ee2, &ww1[i + 6], &ww2[i + 6], prng);\ - expandOne(ee1, ee2, &ww1[i + 7], &ww2[i + 7], prng);\ - break - - switch (mExpanderWeight) - { - CASE(5); - CASE(7); - CASE(9); - CASE(11); - CASE(21); - CASE(40); - default: - for (u64 jj = 0; jj < 8; ++jj) - { - auto r = prng.get(); - auto wv1 = ee1[r]; - auto wv2 = ee2[r]; - - for (auto j = 1ull; j < mExpanderWeight; ++j) - { - r = prng.get(); - wv1 = wv1 ^ ee1[r]; - wv2 = wv2 ^ ee2[r]; - } - if constexpr (Add) - { - ww1[i + jj] = ww1[i + jj] ^ wv1; - ww2[i + jj] = ww2[i + jj] ^ wv2; - } - else - { - - ww1[i + jj] = wv1; - ww2[i + jj] = wv2; - } - } - } -#undef CASE - } - - for (; i < mMessageSize; ++i) - { - auto r = prng.get(); - auto wv1 = ee1[r]; - auto wv2 = ee2[r]; - for (auto j = 1ull; j < mExpanderWeight; ++j) - { - r = prng.get(); - wv1 = wv1 ^ ee1[r]; - wv2 = wv2 ^ ee2[r]; - - } - if constexpr (Add) - { - ww1[i] = ww1[i] ^ wv1; - ww2[i] = ww2[i] ^ wv2; - } - else - { - ww1[i] = wv1; - ww2[i] = wv2; - } - } - } - -#ifndef EXPANDER_INSTANTATIONS - SparseMtx ExpanderCode::getB() const - { - //PRNG prng(mSeed); - detail::ExpanderModd prng(mSeed, mCodeSize); - PointList points(mMessageSize, mCodeSize); - - std::vector row(mExpanderWeight); - - { - - for (auto i : rng(mMessageSize)) - { - row[0] = prng.get(); - //points.push_back(i, row[0]); - for (auto j : rng(1, mExpanderWeight)) - { - //do { - row[j] = prng.get(); - //} while - auto iter = std::find(row.data(), row.data() + j, row[j]); - if (iter != row.data() + j) - { - row[j] = ~0ull; - *iter = ~0ull; - } - //throw RTE_LOC; - - } - for (auto j : rng(mExpanderWeight)) - { - - if (row[j] != ~0ull) - { - //std::cout << row[j] << " "; - points.push_back(i, row[j]); - } - else - { - //std::cout << "* "; - } - } - //std::cout << std::endl; - } - } - - return points; - } -#endif - -} diff --git a/libOTe/Tools/EACode/Expander.h b/libOTe/Tools/EACode/Expander.h deleted file mode 100644 index 5cd0e737..00000000 --- a/libOTe/Tools/EACode/Expander.h +++ /dev/null @@ -1,98 +0,0 @@ -// © 2023 Peter Rindal. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -#pragma once - -#include "cryptoTools/Common/Defines.h" -#include "libOTe/Tools/LDPC/Mtx.h" -#include "Util.h" - -namespace osuCrypto -{ - - // The encoder for the expander matrix B. - // B has mMessageSize rows and mCodeSize columns. It is sampled uniformly - // with fixed row weight mExpanderWeight. - class ExpanderCode - { - public: - - void config( - u64 messageSize, - u64 codeSize = 0 /* default is 5* messageSize */, - u64 expanderWeight = 21, - block seed = block(33333, 33333)) - { - mMessageSize = messageSize; - mCodeSize = codeSize; - mExpanderWeight = expanderWeight; - mSeed = seed; - - } - - // the seed that generates the code. - block mSeed = block(0, 0); - - // The message size of the code. K. - u64 mMessageSize = 0; - - // The codeword size of the code. n. - u64 mCodeSize = 0; - - // The row weight of the B matrix. - u64 mExpanderWeight = 0; - - u64 parityRows() const { return mCodeSize - mMessageSize; } - u64 parityCols() const { return mCodeSize; } - - u64 generatorRows() const { return mMessageSize; } - u64 generatorCols() const { return mCodeSize; } - - - - template - typename std::enable_if<(count > 1), T>::type - expandOne(const T* __restrict ee, detail::ExpanderModd& prng)const; - - template - typename std::enable_if<(count > 1)>::type - expandOne( - const T* __restrict ee1, - const T2* __restrict ee2, - T* __restrict y1, - T2* __restrict y2, - detail::ExpanderModd& prng)const; - - template - typename std::enable_if::type - expandOne(const T* __restrict ee, detail::ExpanderModd& prng) const; - - template - typename std::enable_if::type - expandOne( - const T* __restrict ee1, - const T2* __restrict ee2, - T* __restrict y1, - T2* __restrict y2, - detail::ExpanderModd& prng) const; - - template - void expand( - span e, - span w) const; - - template - void expand( - span e1, - span e2, - span w1, - span w2 - ) const; - - SparseMtx getB() const; - - }; -} diff --git a/libOTe/Tools/EACode/ExpanderInstantiations.cpp b/libOTe/Tools/EACode/ExpanderInstantiations.cpp deleted file mode 100644 index 130e095d..00000000 --- a/libOTe/Tools/EACode/ExpanderInstantiations.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#define EXPANDER_INSTANTATIONS -#include "Expander.cpp" - -namespace osuCrypto -{ - - template void ExpanderCode::expand(span e, span w) const; - template void ExpanderCode::expand(span e, span w) const; - template void ExpanderCode::expand(span e, span w) const; - template void ExpanderCode::expand(span e, span w) const; - - template void ExpanderCode::expand( - span e, span e2, - span w, span w2) const; - template void ExpanderCode::expand( - span e, span e2, - span w, span w2) const; - - - template void ExpanderCode::expand( - span e, span e2, - span w, span w2) const; - //template void ExpanderCode::expand( - // span e, span e2, - // span w, span w2) const; -} \ No newline at end of file diff --git a/libOTe/Tools/EACode/Util.h b/libOTe/Tools/EACode/Util.h index e53a7ff1..8d08fe0b 100644 --- a/libOTe/Tools/EACode/Util.h +++ b/libOTe/Tools/EACode/Util.h @@ -86,14 +86,6 @@ namespace osuCrypto b[6] = AES::roundEnc(b[6], k[6]); b[7] = AES::roundEnc(b[7], k[7]); - b[0] = b[0] ^ k[0]; - b[1] = b[1] ^ k[1]; - b[2] = b[2] ^ k[2]; - b[3] = b[3] ^ k[3]; - b[4] = b[4] ^ k[4]; - b[5] = b[5] ^ k[5]; - b[6] = b[6] ^ k[6]; - b[7] = b[7] ^ k[7]; } auto src = prng.mBuffer.data(); @@ -121,7 +113,6 @@ namespace osuCrypto else { memcpy(dst, src, vals.size() * sizeof(value_type)); - //throw RTE_LOC; //assert(vals.size() % 32 == 0); for (u64 i = 0; i < vals.size(); i += 32) doMod32(vals.data() + i, &mod, modVal); diff --git a/libOTe/Tools/ExConvCode/ExConvCode.cpp b/libOTe/Tools/ExConvCode/ExConvCode.cpp index c28cefed..7f843ecb 100644 --- a/libOTe/Tools/ExConvCode/ExConvCode.cpp +++ b/libOTe/Tools/ExConvCode/ExConvCode.cpp @@ -1,557 +1,85 @@ #include "ExConvCode.h" - namespace osuCrypto { -#ifdef ENABLE_SSE - - using My__m128 = __m128; - -#else - using My__m128 = block; - - inline My__m128 _mm_load_ps(float* b) { return *(block*)b; } - - // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_blendv_ps&ig_expand=557 - inline My__m128 _mm_blendv_ps(My__m128 a, My__m128 b, My__m128 mask) - { - My__m128 dst; - for (u64 j = 0; j < 4; ++j) - { - if (mask.get(j) < 0) - dst.set(j, b.get(j)); - else - dst.set(j, a.get(j)); - } - return dst; - } - - - inline My__m128 _mm_setzero_ps() { return ZeroBlock; } -#endif - - // Compute e = G * e. - template - void ExConvCode::dualEncode(span e) - { - if (e.size() != mCodeSize) - throw RTE_LOC; - - if (mSystematic) - { - auto d = e.subspan(mMessageSize); - setTimePoint("ExConv.encode.begin"); - accumulate(d); - setTimePoint("ExConv.encode.accumulate"); - mExpander.expand(d, e.subspan(0, mMessageSize)); - setTimePoint("ExConv.encode.expand"); - } - else - { - oc::AlignedUnVector w(mMessageSize); - dualEncode(e, w); - memcpy(e.data(), w.data(), w.size() * sizeof(T)); - setTimePoint("ExConv.encode.memcpy"); - - } - } - - - // Compute e = G * e. - template - void ExConvCode::dualEncode2(span e0, span e1) - { - if (e0.size() != mCodeSize) - throw RTE_LOC; - if (e1.size() != mCodeSize) - throw RTE_LOC; - - if (mSystematic) - { - auto d0 = e0.subspan(mMessageSize); - auto d1 = e1.subspan(mMessageSize); - setTimePoint("ExConv.encode.begin"); - accumulate(d0, d1); - setTimePoint("ExConv.encode.accumulate"); - mExpander.expand( - d0, d1, - e0.subspan(0, mMessageSize), - e1.subspan(0, mMessageSize)); - setTimePoint("ExConv.encode.expand"); - } - else - { - //oc::AlignedUnVector w0(mMessageSize); - //dualEncode(e, w); - //memcpy(e.data(), w.data(), w.size() * sizeof(T)); - //setTimePoint("ExConv.encode.memcpy"); - - // not impl. - throw RTE_LOC; - - } - } - - // Compute w = G * e. - template - void ExConvCode::dualEncode(span e, span w) - { - if (e.size() != mCodeSize) - throw RTE_LOC; - - if (w.size() != mMessageSize) - throw RTE_LOC; - - if (mSystematic) - { - dualEncode(e); - memcpy(w.data(), e.data(), w.size() * sizeof(T)); - setTimePoint("ExConv.encode.memcpy"); - } - else - { - - setTimePoint("ExConv.encode.begin"); - - accumulate(e); - - setTimePoint("ExConv.encode.accumulate"); - - mExpander.expand(e, w); - setTimePoint("ExConv.encode.expand"); - } - } - - inline void refill(PRNG& prng) - { - assert(prng.mBuffer.size() == 256); - //block b[8]; - for (u64 i = 0; i < 256; i += 8) - { - //auto idx = mPrng.mBuffer[i].get(); - block* __restrict b = prng.mBuffer.data() + i; - block* __restrict k = prng.mBuffer.data() + (u8)(i - 8); - //for (u64 j = 0; j < 8; ++j) - //{ - // b = b ^ mPrng.mBuffer.data()[idx[j]]; - //} - b[0] = AES::roundEnc(b[0], k[0]); - b[1] = AES::roundEnc(b[1], k[1]); - b[2] = AES::roundEnc(b[2], k[2]); - b[3] = AES::roundEnc(b[3], k[3]); - b[4] = AES::roundEnc(b[4], k[4]); - b[5] = AES::roundEnc(b[5], k[5]); - b[6] = AES::roundEnc(b[6], k[6]); - b[7] = AES::roundEnc(b[7], k[7]); - - b[0] = b[0] ^ k[0]; - b[1] = b[1] ^ k[1]; - b[2] = b[2] ^ k[2]; - b[3] = b[3] ^ k[3]; - b[4] = b[4] ^ k[4]; - b[5] = b[5] ^ k[5]; - b[6] = b[6] ^ k[6]; - b[7] = b[7] ^ k[7]; - } - } - -#ifndef EXCONVCODE_INSTANTIATIONS - - void ExConvCode::accOne( - PointList& pl, - u64 i, - u8* __restrict& ptr, - PRNG& prng, - block& rnd, - u64& q, - u64 qe, - u64 size) const - { - u64 j = i + 1; - pl.push_back(i, i); - - if (q + mAccumulatorSize > qe) - { - refill(prng); - ptr = (u8*)prng.mBuffer.data(); - q = 0; - } - - - for (u64 k = 0; k < mAccumulatorSize; k += 8, q += 8, j += 8) - { - assert(ptr < (u8*)(prng.mBuffer.data() + prng.mBuffer.size())); - rnd = block::allSame(*ptr); - ++ptr; - - //std::cout << "r " << rnd << std::endl; - auto b0 = rnd; - auto b1 = rnd.slli_epi32<1>(); - auto b2 = rnd.slli_epi32<2>(); - auto b3 = rnd.slli_epi32<3>(); - auto b4 = rnd.slli_epi32<4>(); - auto b5 = rnd.slli_epi32<5>(); - auto b6 = rnd.slli_epi32<6>(); - auto b7 = rnd.slli_epi32<7>(); - //rnd = rnd.mm_slli_epi32<8>(); - - if (j + 0 < size && b0.get(0) < 0) pl.push_back(j + 0, i); - if (j + 1 < size && b1.get(0) < 0) pl.push_back(j + 1, i); - if (j + 2 < size && b2.get(0) < 0) pl.push_back(j + 2, i); - if (j + 3 < size && b3.get(0) < 0) pl.push_back(j + 3, i); - if (j + 4 < size && b4.get(0) < 0) pl.push_back(j + 4, i); - if (j + 5 < size && b5.get(0) < 0) pl.push_back(j + 5, i); - if (j + 6 < size && b6.get(0) < 0) pl.push_back(j + 6, i); - if (j + 7 < size && b7.get(0) < 0) pl.push_back(j + 7, i); - } - - - //if (mWrapping) - { - if (j < size) - pl.push_back(j, i); - ++j; - } - - } -#endif - - - template - OC_FORCEINLINE void accOneHelper( - T* __restrict xx, - My__m128 xii, - u64 j, u64 i, u64 size, - block* b - ) - { - My__m128 Zero = _mm_setzero_ps(); - - if constexpr (std::is_same::value) - { - My__m128 bb[8]; - bb[0] = _mm_load_ps((float*)&b[0]); - bb[1] = _mm_load_ps((float*)&b[1]); - bb[2] = _mm_load_ps((float*)&b[2]); - bb[3] = _mm_load_ps((float*)&b[3]); - bb[4] = _mm_load_ps((float*)&b[4]); - bb[5] = _mm_load_ps((float*)&b[5]); - bb[6] = _mm_load_ps((float*)&b[6]); - bb[7] = _mm_load_ps((float*)&b[7]); - - - bb[0] = _mm_blendv_ps(Zero, xii, bb[0]); - bb[1] = _mm_blendv_ps(Zero, xii, bb[1]); - bb[2] = _mm_blendv_ps(Zero, xii, bb[2]); - bb[3] = _mm_blendv_ps(Zero, xii, bb[3]); - bb[4] = _mm_blendv_ps(Zero, xii, bb[4]); - bb[5] = _mm_blendv_ps(Zero, xii, bb[5]); - bb[6] = _mm_blendv_ps(Zero, xii, bb[6]); - bb[7] = _mm_blendv_ps(Zero, xii, bb[7]); - block tt[8]; - memcpy(tt, bb, 8 * 16); - if (!rangeCheck || j + 0 < size) xx[j + 0] = xx[j + 0] ^ tt[0]; - if (!rangeCheck || j + 1 < size) xx[j + 1] = xx[j + 1] ^ tt[1]; - if (!rangeCheck || j + 2 < size) xx[j + 2] = xx[j + 2] ^ tt[2]; - if (!rangeCheck || j + 3 < size) xx[j + 3] = xx[j + 3] ^ tt[3]; - if (!rangeCheck || j + 4 < size) xx[j + 4] = xx[j + 4] ^ tt[4]; - if (!rangeCheck || j + 5 < size) xx[j + 5] = xx[j + 5] ^ tt[5]; - if (!rangeCheck || j + 6 < size) xx[j + 6] = xx[j + 6] ^ tt[6]; - if (!rangeCheck || j + 7 < size) xx[j + 7] = xx[j + 7] ^ tt[7]; - } - else - { - auto bb0 = xx[i] * (b[0].get(0) < 0); - auto bb1 = xx[i] * (b[1].get(0) < 0); - auto bb2 = xx[i] * (b[2].get(0) < 0); - auto bb3 = xx[i] * (b[3].get(0) < 0); - auto bb4 = xx[i] * (b[4].get(0) < 0); - auto bb5 = xx[i] * (b[5].get(0) < 0); - auto bb6 = xx[i] * (b[6].get(0) < 0); - auto bb7 = xx[i] * (b[7].get(0) < 0); + // configure the code. The default parameters are choses to balance security and performance. + // For additional parameter choices see the paper. - if (!rangeCheck || j + 0 < size) xx[j + 0] = xx[j + 0] ^ bb0; - if (!rangeCheck || j + 1 < size) xx[j + 1] = xx[j + 1] ^ bb1; - if (!rangeCheck || j + 2 < size) xx[j + 2] = xx[j + 2] ^ bb2; - if (!rangeCheck || j + 3 < size) xx[j + 3] = xx[j + 3] ^ bb3; - if (!rangeCheck || j + 4 < size) xx[j + 4] = xx[j + 4] ^ bb4; - if (!rangeCheck || j + 5 < size) xx[j + 5] = xx[j + 5] ^ bb5; - if (!rangeCheck || j + 6 < size) xx[j + 6] = xx[j + 6] ^ bb6; - if (!rangeCheck || j + 7 < size) xx[j + 7] = xx[j + 7] ^ bb7; - } - } + //// get the expander matrix + //SparseMtx ExConvCode::getB() const + //{ + // throw RTE_LOC; + // //if (mSystematic) + // //{ + // // PointList R(mMessageSize, mCodeSize); + // // auto B = mExpander.getB().points(); - template - OC_FORCEINLINE void ExConvCode::accOne( - T* __restrict xx, - u64 i, - u8*& ptr, - PRNG& prng, - u64& q, - u64 qe, - u64 size) - { - u64 j = i + 1; - if (width) - { - My__m128 xii; - if constexpr (std::is_same::value) - xii = _mm_load_ps((float*)(xx + i)); - else - xii = _mm_setzero_ps(); + // // for (auto p : B) + // // { + // // R.push_back(p.mRow, mMessageSize + p.mCol); + // // } + // // for (u64 i = 0; i < mMessageSize; ++i) + // // R.push_back(i, i); - if (q + width > qe) - { - refill(prng); - ptr = (u8*)prng.mBuffer.data(); - q = 0; + // // return R; + // //} + // //else + // //{ + // // return mExpander.getB(); + // //} + //} - } - q += width; - - for (u64 k = 0; k < width; ++k, j += 8) - { - assert(ptr < (u8*)(prng.mBuffer.data() + prng.mBuffer.size())); - block rnd = block::allSame(*(u8*)ptr++); - - - block b[8]; - b[0] = rnd; - b[1] = rnd.slli_epi32<1>(); - b[2] = rnd.slli_epi32<2>(); - b[3] = rnd.slli_epi32<3>(); - b[4] = rnd.slli_epi32<4>(); - b[5] = rnd.slli_epi32<5>(); - b[6] = rnd.slli_epi32<6>(); - b[7] = rnd.slli_epi32<7>(); - - accOneHelper(xx, xii, j, i, size, b); - } - } - - if (!rangeCheck || j < size) - { - auto xj = xx[j] ^ xx[i]; - xx[j] = xj; - } - } - - template - OC_FORCEINLINE void ExConvCode::accOne( - T0* __restrict xx0, - T1* __restrict xx1, - u64 i, - u8*& ptr, - PRNG& prng, - u64& q, - u64 qe, - u64 size) - { - u64 j = i + 1; - if (width) - { - My__m128 xii0, xii1; - if constexpr (std::is_same::value) - xii0 = _mm_load_ps((float*)(xx0 + i)); - else - xii0 = _mm_setzero_ps(); - if constexpr (std::is_same::value) - xii1 = _mm_load_ps((float*)(xx1 + i)); - else - xii1 = _mm_setzero_ps(); - - if (q + width > qe) - { - refill(prng); - ptr = (u8*)prng.mBuffer.data(); - q = 0; - - } - q += width; - - for (u64 k = 0; k < width; ++k, j += 8) - { - assert(ptr < (u8*)(prng.mBuffer.data() + prng.mBuffer.size())); - block rnd = block::allSame(*(u8*)ptr++); - - block b[8]; - b[0] = rnd; - b[1] = rnd.slli_epi32<1>(); - b[2] = rnd.slli_epi32<2>(); - b[3] = rnd.slli_epi32<3>(); - b[4] = rnd.slli_epi32<4>(); - b[5] = rnd.slli_epi32<5>(); - b[6] = rnd.slli_epi32<6>(); - b[7] = rnd.slli_epi32<7>(); - - accOneHelper(xx0, xii0, j, i, size, b); - accOneHelper(xx1, xii1, j, i, size, b); - } - } - - if (!rangeCheck || j < size) - { - auto xj0 = xx0[j] ^ xx0[i]; - auto xj1 = xx1[j] ^ xx1[i]; - xx0[j] = xj0; - xx1[j] = xj1; - } - } - - - - template - void ExConvCode::accumulate(span x) - { - PRNG prng(mSeed ^ OneBlock); - - u64 i = 0; - auto size = x.size(); - auto main = (u64)std::max(0, size - 1 - mAccumulatorSize); - u8* ptr = (u8*)prng.mBuffer.data(); - auto qe = prng.mBuffer.size() * 128 / 8; - u64 q = 0; - T* __restrict xx = x.data(); - - { - -#define CASE(I) case I:\ - for (; i < main; ++i)\ - accOne(xx, i, ptr, prng, q, qe, size);\ - for (; i < size; ++i)\ - accOne(xx, i, ptr, prng, q, qe, size);\ - break - - switch (mAccumulatorSize / 8) - { - CASE(0); - CASE(1); - CASE(2); - CASE(3); - CASE(4); - default: - throw RTE_LOC; - break; - } -#undef CASE - } - } - - - template - void ExConvCode::accumulate(span x0, span x1) - { - PRNG prng(mSeed ^ OneBlock); - - u64 i = 0; - auto size = x0.size(); - auto main = (u64)std::max(0, size - 1 - mAccumulatorSize); - u8* ptr = (u8*)prng.mBuffer.data(); - auto qe = prng.mBuffer.size() * 128 / 8; - u64 q = 0; - T0* __restrict xx0 = x0.data(); - T1* __restrict xx1 = x1.data(); - - { - -#define CASE(I) case I:\ - for (; i < main; ++i)\ - accOne(xx0,xx1, i, ptr, prng, q, qe, size);\ - for (; i < size; ++i)\ - accOne(xx0, xx1, i, ptr, prng, q, qe, size);\ - break - - switch (mAccumulatorSize / 8) - { - CASE(0); - CASE(1); - CASE(2); - CASE(3); - CASE(4); - default: - throw RTE_LOC; - break; - } -#undef CASE - } - } - - -#ifndef EXCONVCODE_INSTANTIATIONS - - SparseMtx ExConvCode::getB() const - { - if (mSystematic) - { - PointList R(mMessageSize, mCodeSize); - auto B = mExpander.getB().points(); - - for (auto p : B) - { - R.push_back(p.mRow, mMessageSize + p.mCol); - } - for (u64 i = 0; i < mMessageSize; ++i) - R.push_back(i, i); - - return R; - } - else - { - return mExpander.getB(); - } - - } // Get the parity check version of the accumulator - SparseMtx ExConvCode::getAPar() const - { - PRNG prng(mSeed ^ OneBlock); - - auto n = mCodeSize - mSystematic * mMessageSize; - - PointList AP(n, n);; - DenseMtx A = DenseMtx::Identity(n); - - block rnd; - u8* __restrict ptr = (u8*)prng.mBuffer.data(); - auto qe = prng.mBuffer.size() * 128; - u64 q = 0; - - for (u64 i = 0; i < n; ++i) - { - accOne(AP, i, ptr, prng, rnd, q, qe, n); - } - return AP; - } - - SparseMtx ExConvCode::getA() const - { - auto APar = getAPar(); - - auto A = DenseMtx::Identity(mCodeSize); - - u64 offset = mSystematic ? mMessageSize : 0ull; - - for (u64 i = 0; i < APar.rows(); ++i) - { - for (auto y : APar.col(i)) - { - //std::cout << y << " "; - if (y != i) - { - auto ay = A.row(y + offset); - auto ai = A.row(i + offset); - ay ^= ai; - } - } + //SparseMtx ExConvCode::getAPar() const + //{ + // throw RTE_LOC; + // //PRNG prng(mSeed ^ OneBlock); + + // //auto n = mCodeSize - mSystematic * mMessageSize; + + // //PointList AP(n, n);; + // //DenseMtx A = DenseMtx::Identity(n); + + // //block rnd; + // //u8* __restrict ptr = (u8*)prng.mBuffer.data(); + // //auto qe = prng.mBuffer.size() * 128; + // //u64 q = 0; + + // //for (u64 i = 0; i < n; ++i) + // //{ + // // accOne(AP, i, ptr, prng, rnd, q, qe, n); + // //} + // //return AP; + //} + + //// get the accumulator matrix + //SparseMtx ExConvCode::getA() const + //{ + // auto APar = getAPar(); + + // auto A = DenseMtx::Identity(mCodeSize); + + // u64 offset = mSystematic ? mMessageSize : 0ull; + + // for (u64 i = 0; i < APar.rows(); ++i) + // { + // for (auto y : APar.col(i)) + // { + // if (y != i) + // { + // auto ay = A.row(y + offset); + // auto ai = A.row(i + offset); + // ay ^= ai; + // } + // } + // } + + // return A.sparse(); + //} - //std::cout << "\n" << A << std::endl; - } - return A.sparse(); - } -#endif } \ No newline at end of file diff --git a/libOTe/Tools/ExConvCode/ExConvCode.h b/libOTe/Tools/ExConvCode/ExConvCode.h index 13a25cad..0a574fcc 100644 --- a/libOTe/Tools/ExConvCode/ExConvCode.h +++ b/libOTe/Tools/ExConvCode/ExConvCode.h @@ -1,4 +1,4 @@ -// © 2023 Visa. +// © 2023 Visa. // Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. @@ -8,12 +8,39 @@ #include "cryptoTools/Common/Defines.h" #include "cryptoTools/Common/Timer.h" -#include "libOTe/Tools/EACode/Expander.h" +#include "libOTe/Tools/ExConvCode/Expander.h" #include "libOTe/Tools/EACode/Util.h" namespace osuCrypto { + template + struct has_operator_star : std::false_type + {}; + + template + struct has_operator_star < T, std::void_t< + // must have a operator*() member fn + decltype(std::declval().operator*()) + >> + : std::true_type{}; + + + template + struct is_iterator : std::false_type + {}; + + template + struct is_iterator < T, std::void_t< + // must have a operator*() member fn + // or be a pointer + std::enable_if_t< + has_operator_star::value || + std::is_pointer_v> + > + >> + : std::true_type{}; + // The encoder for the generator matrix G = B * A. dualEncode(...) is the main function // config(...) should be called first. // @@ -39,25 +66,11 @@ namespace osuCrypto // For additional parameter choices see the paper. void config( u64 messageSize, - u64 codeSize = 0 /*2 * messageSize is default */, + u64 codeSize, u64 expanderWeight = 7, - u64 accumulatorSize = 16, + u64 accumulatorWeight = 16, bool systematic = true, - block seed = block(99999, 88888)) - { - if (codeSize == 0) - codeSize = 2 * messageSize; - - if (accumulatorSize % 8) - throw std::runtime_error("ExConvCode accumulator size must be a multiple of 8." LOCATION); - - mSeed = seed; - mMessageSize = messageSize; - mCodeSize = codeSize; - mAccumulatorSize = accumulatorSize; - mSystematic = systematic; - mExpander.config(messageSize, codeSize - messageSize * systematic, expanderWeight, seed ^ CCBlock); - } + block seed = block(9996754675674599, 56756745976768754)); // the seed that generates the code. block mSeed = ZeroBlock; @@ -86,73 +99,472 @@ namespace osuCrypto // return code size n. u64 generatorCols() const { return mCodeSize; } - // Compute w = G * e. e will be modified in the computation. - template - void dualEncode(span e, span w); - // Compute e[0,...,k-1] = G * e. - template - void dualEncode(span e); - + // the computation will be done over F using ctx.plus + template< + typename F, + typename CoeffCtx, + typename Iter + > + void dualEncode(Iter&& e, CoeffCtx ctx); // Compute e[0,...,k-1] = G * e. - template - void dualEncode2(span e0, span e1); + template< + typename F, + typename G, + typename CoeffCtx, + typename IterF, + typename IterG + > + void dualEncode2(IterF&& e0, IterG&& e1, CoeffCtx ctx) + { + dualEncode(e0, ctx); + dualEncode(e1, ctx); + } + + // Private functions ------------------------------------ - // get the expander matrix - SparseMtx getB() const; + static void refill(PRNG& prng) + { + assert(prng.mBuffer.size() == 256); + //block b[8]; + for (u64 i = 0; i < 256; i += 8) + { + block* __restrict b = prng.mBuffer.data() + i; + block* __restrict k = prng.mBuffer.data() + (u8)(i - 8); - // Get the parity check version of the accumulator - SparseMtx getAPar() const; + b[0] = AES::roundEnc(b[0], k[0]); + b[1] = AES::roundEnc(b[1], k[1]); + b[2] = AES::roundEnc(b[2], k[2]); + b[3] = AES::roundEnc(b[3], k[3]); + b[4] = AES::roundEnc(b[4], k[4]); + b[5] = AES::roundEnc(b[5], k[5]); + b[6] = AES::roundEnc(b[6], k[6]); + b[7] = AES::roundEnc(b[7], k[7]); + } + } - // get the accumulator matrix - SparseMtx getA() const; + // take x[i] and add it to the next 8 positions if the flag b is 1. + // + // xx[j] += b[j] * x[i] + // + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + typename Iter + > + OC_FORCEINLINE void accOne8( + Iter&& xi, + Iter&& xj, + Iter&& end, + u8 b, + CoeffCtx& ctx); - // Private functions ------------------------------------ + // accumulating row i. + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + int AccumulatorSize, + typename Iter + > + OC_FORCEINLINE void accOne( + Iter&& xi, + Iter&& end, + u8* matrixCoeff, + CoeffCtx& ctx); - // generate the point list for accumulating row i. - void accOne( - PointList& pl, - u64 i, - u8* __restrict& ptr, - PRNG& prng, - block& rnd, - u64& q, - u64 qe, - u64 size) const; - - // accumulating row i. - template - void accOne( - T* __restrict xx, - u64 i, - u8*& ptr, - PRNG& prng, - u64& q, - u64 qe, - u64 size); - - - // accumulating row i. - template - void accOne( - T0* __restrict xx0, - T1* __restrict xx1, - u64 i, - u8*& ptr, - PRNG& prng, - u64& q, - u64 qe, - u64 size); + // accumulating row i. generic version + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + typename Iter + > + OC_FORCEINLINE void accOneGen( + Iter&& xi, + Iter&& end, + u8* matrixCoeff, + CoeffCtx& ctx); // accumulate x onto itself. - template - void accumulate(span x); - + template< + typename F, + typename CoeffCtx, + typename Iter + > + void accumulate( + Iter x, + CoeffCtx& ctx) + { + switch (mAccumulatorSize) + { + case 16: + accumulateFixed(std::forward(x), ctx); + break; + case 24: + accumulateFixed(std::forward(x), ctx); + break; + default: + // generic case + accumulateFixed(std::forward(x), ctx); + } + } // accumulate x onto itself. - template - void accumulate(span x0, span x1); + template< + typename F, + typename CoeffCtx, + u64 AccumulatorSize, + typename Iter + > + void accumulateFixed(Iter x, + CoeffCtx& ctx); + }; -} + + + inline void ExConvCode::config( + u64 messageSize, + u64 codeSize, + u64 expanderWeight, + u64 accumulatorSize, + bool systematic, + block seed) + { + if (codeSize == 0) + codeSize = 2 * messageSize; + + mSeed = seed; + mMessageSize = messageSize; + mCodeSize = codeSize; + mAccumulatorSize = accumulatorSize; + mSystematic = systematic; + mExpander.config(messageSize, codeSize - messageSize * systematic, expanderWeight, seed ^ CCBlock); + } + + // Compute e[0,...,k-1] = G * e. + template + void ExConvCode::dualEncode( + Iter&& e_, + CoeffCtx ctx) + { + static_assert(is_iterator::value, "must pass in an iterator to the data"); + + (void)*(e_ + mCodeSize - 1); + + auto e = ctx.template restrictPtr(e_); + + if (mSystematic) + { + auto d = e + mMessageSize; + setTimePoint("ExConv.encode.begin"); + accumulate(d, ctx); + setTimePoint("ExConv.encode.accumulate"); + mExpander.expand(d, e, ctx); + setTimePoint("ExConv.encode.expand"); + } + else + { + + setTimePoint("ExConv.encode.begin"); + accumulate(e, ctx); + setTimePoint("ExConv.encode.accumulate"); + + typename CoeffCtx::template Vec w; + ctx.resize(w, mMessageSize); + auto wIter = ctx.template restrictPtr(w.begin()); + + mExpander.expand(e, wIter, ctx); + setTimePoint("ExConv.encode.expand"); + + ctx.copy(w.begin(), w.end(), e); + setTimePoint("ExConv.encode.memcpy"); + + } + } + + // take x[i] and add it to the next 8 positions if the flag b is 1. + // + // xx[j] += b[j] * x[i] + // + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + typename Iter + > + OC_FORCEINLINE void ExConvCode::accOne8( + Iter&& xi, + Iter&& xj, + Iter&& end, + u8 b, + CoeffCtx& ctx) + { + +#ifdef ENABLE_SSE + if constexpr (std::is_same::value) + { + block rnd = block::allSame(b); + + block bshift[8]; + bshift[0] = rnd.slli_epi32<7>(); + bshift[1] = rnd.slli_epi32<6>(); + bshift[2] = rnd.slli_epi32<5>(); + bshift[3] = rnd.slli_epi32<4>(); + bshift[4] = rnd.slli_epi32<3>(); + bshift[5] = rnd.slli_epi32<2>(); + bshift[6] = rnd.slli_epi32<1>(); + bshift[7] = rnd; + + __m128 bb[8]; + auto xii = _mm_load_ps((float*)(&*xi)); + __m128 Zero = _mm_setzero_ps(); + + // bbj = bj + bb[0] = _mm_load_ps((float*)&bshift[0]); + bb[1] = _mm_load_ps((float*)&bshift[1]); + bb[2] = _mm_load_ps((float*)&bshift[2]); + bb[3] = _mm_load_ps((float*)&bshift[3]); + bb[4] = _mm_load_ps((float*)&bshift[4]); + bb[5] = _mm_load_ps((float*)&bshift[5]); + bb[6] = _mm_load_ps((float*)&bshift[6]); + bb[7] = _mm_load_ps((float*)&bshift[7]); + + // bbj = bj * xi + bb[0] = _mm_blendv_ps(Zero, xii, bb[0]); + bb[1] = _mm_blendv_ps(Zero, xii, bb[1]); + bb[2] = _mm_blendv_ps(Zero, xii, bb[2]); + bb[3] = _mm_blendv_ps(Zero, xii, bb[3]); + bb[4] = _mm_blendv_ps(Zero, xii, bb[4]); + bb[5] = _mm_blendv_ps(Zero, xii, bb[5]); + bb[6] = _mm_blendv_ps(Zero, xii, bb[6]); + bb[7] = _mm_blendv_ps(Zero, xii, bb[7]); + + block tt[8]; + memcpy(tt, bb, 8 * 16); + + assert((((b >> 0) & 1) ? *xi : ZeroBlock) == tt[0]); + assert((((b >> 1) & 1) ? *xi : ZeroBlock) == tt[1]); + assert((((b >> 2) & 1) ? *xi : ZeroBlock) == tt[2]); + assert((((b >> 3) & 1) ? *xi : ZeroBlock) == tt[3]); + assert((((b >> 4) & 1) ? *xi : ZeroBlock) == tt[4]); + assert((((b >> 5) & 1) ? *xi : ZeroBlock) == tt[5]); + assert((((b >> 6) & 1) ? *xi : ZeroBlock) == tt[6]); + assert((((b >> 7) & 1) ? *xi : ZeroBlock) == tt[7]); + + // xj += bj * xi + if (rangeCheck && xj + 0 == end) return; + ctx.plus(*(xj + 0), *(xj + 0), tt[0]); + if (rangeCheck && xj + 1 == end) return; + ctx.plus(*(xj + 1), *(xj + 1), tt[1]); + if (rangeCheck && xj + 2 == end) return; + ctx.plus(*(xj + 2), *(xj + 2), tt[2]); + if (rangeCheck && xj + 3 == end) return; + ctx.plus(*(xj + 3), *(xj + 3), tt[3]); + if (rangeCheck && xj + 4 == end) return; + ctx.plus(*(xj + 4), *(xj + 4), tt[4]); + if (rangeCheck && xj + 5 == end) return; + ctx.plus(*(xj + 5), *(xj + 5), tt[5]); + if (rangeCheck && xj + 6 == end) return; + ctx.plus(*(xj + 6), *(xj + 6), tt[6]); + if (rangeCheck && xj + 7 == end) return; + ctx.plus(*(xj + 7), *(xj + 7), tt[7]); + } + else +#endif + { + auto b0 = b & 1; + auto b1 = b & 2; + auto b2 = b & 4; + auto b3 = b & 8; + auto b4 = b & 16; + auto b5 = b & 32; + auto b6 = b & 64; + auto b7 = b & 128; + + if (rangeCheck && xj + 0 == end) return; + if (b0) ctx.plus(*(xj + 0), *(xj + 0), *xi); + if (rangeCheck && xj + 1 == end) return; + if (b1) ctx.plus(*(xj + 1), *(xj + 1), *xi); + if (rangeCheck && xj + 2 == end) return; + if (b2) ctx.plus(*(xj + 2), *(xj + 2), *xi); + if (rangeCheck && xj + 3 == end) return; + if (b3) ctx.plus(*(xj + 3), *(xj + 3), *xi); + if (rangeCheck && xj + 4 == end) return; + if (b4) ctx.plus(*(xj + 4), *(xj + 4), *xi); + if (rangeCheck && xj + 5 == end) return; + if (b5) ctx.plus(*(xj + 5), *(xj + 5), *xi); + if (rangeCheck && xj + 6 == end) return; + if (b6) ctx.plus(*(xj + 6), *(xj + 6), *xi); + if (rangeCheck && xj + 7 == end) return; + if (b7) ctx.plus(*(xj + 7), *(xj + 7), *xi); + } + } + + + + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + typename Iter + > + OC_FORCEINLINE void ExConvCode::accOneGen( + Iter&& xi, + Iter&& end, + u8* matrixCoeff, + CoeffCtx& ctx) + { + + // xj += xi + std::remove_reference_t xj = xi + 1; + if (!rangeCheck || xj < end) + { + ctx.plus(*xj, *xj, *xi); + ctx.mulConst(*xj, *xj); + ++xj; + } + + // xj += bj * xi + u64 k = 0; + for (; k < mAccumulatorSize - 7; k += 8) + { + accOne8(xi, xj, end, *matrixCoeff++, ctx); + + if constexpr (rangeCheck) + { + auto r = end - xj; + xj += std::min(r, 8); + } + else + { + xj += 8; + } + } + for (; k < mAccumulatorSize; ) + { + auto b = *matrixCoeff++; + + for (u64 j = 0; j < 8 && k < mAccumulatorSize; ++j, ++k) + { + + if (rangeCheck == false || (xj != end)) + { + if (b & 1) + ctx.plus(*xj, *xj, *xi); + + ++xj; + b >>= 1; + } + + } + } + } + + + // add xi to all of the future locations + template< + typename F, + typename CoeffCtx, + bool rangeCheck, + int AccumulatorSize, + typename Iter + > + OC_FORCEINLINE void ExConvCode::accOne( + Iter&& xi, + Iter&& end, + u8* matrixCoeff, + CoeffCtx& ctx) + { + static_assert(AccumulatorSize, "should have called the other overload"); + static_assert(AccumulatorSize % 8 == 0, "must be a multiple of 8"); + + // xj += xi + std::remove_reference_t xj = xi + 1; + if (!rangeCheck || xj < end) + { + ctx.plus(*xj, *xj, *xi); + ctx.mulConst(*xj, *xj); + ++xj; + } + + // xj += bj * xi + for (u64 k = 0; k < AccumulatorSize; k += 8) + { + accOne8(xi, xj, end, *matrixCoeff++, ctx); + + if constexpr (rangeCheck) + { + auto r = end - xj; + xj += std::min(r, 8); + } + else + { + xj += 8; + } + } + } + + // accumulate x onto itself. + template< + typename F, + typename CoeffCtx, + u64 AccumulatorSize, + typename Iter + > + void ExConvCode::accumulateFixed( + Iter xi, + CoeffCtx& ctx) + { + auto end = xi + (mCodeSize - mSystematic * mMessageSize); + auto main = end - 1 - mAccumulatorSize; + + PRNG prng(mSeed ^ OneBlock); + u8* mtxCoeffIter = (u8*)prng.mBuffer.data(); + auto mtxCoeffEnd = mtxCoeffIter + prng.mBuffer.size() * sizeof(block) - divCeil(mAccumulatorSize, 8); + + // AccumulatorSize == 0 is the generic case, otherwise + // AccumulatorSize should be equal to mAccumulatorSize. + static_assert(AccumulatorSize % 8 == 0); + if (AccumulatorSize && mAccumulatorSize != AccumulatorSize) + throw RTE_LOC; + + while (xi < main) + { + if (mtxCoeffIter > mtxCoeffEnd) + { + // generate more mtx coefficients + refill(prng); + mtxCoeffIter = (u8*)prng.mBuffer.data(); + } + + // add xi to the next positions + if constexpr(AccumulatorSize == 0) + accOneGen(xi, end, mtxCoeffIter++, ctx); + else + accOne(xi, end, mtxCoeffIter++, ctx); + ++xi; + } + + while (xi < end) + { + if (mtxCoeffIter > mtxCoeffEnd) + { + // generate more mtx coefficients + refill(prng); + mtxCoeffIter = (u8*)prng.mBuffer.data(); + } + + // add xi to the next positions + if constexpr (AccumulatorSize == 0) + accOneGen(xi, end, mtxCoeffIter++, ctx); + else + accOne(xi, end, mtxCoeffIter++, ctx); + ++xi; + } + } + +} \ No newline at end of file diff --git a/libOTe/Tools/ExConvCode/ExConvCodeInstantiations.cpp b/libOTe/Tools/ExConvCode/ExConvCodeInstantiations.cpp deleted file mode 100644 index 1b91615e..00000000 --- a/libOTe/Tools/ExConvCode/ExConvCodeInstantiations.cpp +++ /dev/null @@ -1,17 +0,0 @@ - -#define EXCONVCODE_INSTANTIATIONS -#include "ExConvCode.cpp" - -namespace osuCrypto -{ - - template void ExConvCode::dualEncode(span e); - template void ExConvCode::dualEncode(span e); - template void ExConvCode::dualEncode(span e, span w); - template void ExConvCode::dualEncode(span e, span w); - template void ExConvCode::dualEncode2(span, span e); - template void ExConvCode::dualEncode2(span, span e); - - template void ExConvCode::accumulate(span, span e); - template void ExConvCode::accumulate(span, span e); -} diff --git a/libOTe/Tools/ExConvCode/Expander.h b/libOTe/Tools/ExConvCode/Expander.h new file mode 100644 index 00000000..1bd9803e --- /dev/null +++ b/libOTe/Tools/ExConvCode/Expander.h @@ -0,0 +1,249 @@ +// � 2023 Peter Rindal. +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +#pragma once + +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/Range.h" +#include "libOTe/Tools/LDPC/Mtx.h" +#include "libOTe/Tools/EACode/Util.h" + +namespace osuCrypto +{ + + // The encoder for the expander matrix B. + // B has mMessageSize rows and mCodeSize columns. It is sampled uniformly + // with fixed row weight mExpanderWeight. + class ExpanderCode + { + public: + + void config( + u64 messageSize, + u64 codeSize, + u64 expanderWeight, + block seed = block(33333, 33333)) + { + mMessageSize = messageSize; + mCodeSize = codeSize; + mExpanderWeight = expanderWeight; + mSeed = seed; + } + + // the seed that generates the code. + block mSeed = block(0, 0); + + // The message size of the code. K. + u64 mMessageSize = 0; + + // The codeword size of the code. n. + u64 mCodeSize = 0; + + // The row weight of the B matrix. + u64 mExpanderWeight = 0; + + u64 parityRows() const { return mCodeSize - mMessageSize; } + u64 parityCols() const { return mCodeSize; } + + u64 generatorRows() const { return mMessageSize; } + u64 generatorCols() const { return mCodeSize; } + + + template< + typename F, + typename CoeffCtx, + bool add, + typename SrcIter, + typename DstIter + > + void expand( + SrcIter&& input, + DstIter&& output, + CoeffCtx ctx = {} + ) const; + + + //template< + // bool Add, + // typename CoeffCtx, + // typename... F, + // typename... SrcDstIterPair + //> + //void expandMany( + // std::tuple out, + // CoeffCtx ctx = {})const; + + }; + + + + template< + typename F, + typename CoeffCtx, + bool Add, + typename SrcIter, + typename DstIter + > + void ExpanderCode::expand( + SrcIter&& input, + DstIter&& output, + CoeffCtx ctx) const + { + (void)*(input + (mCodeSize - 1)); + (void)*(output + (mMessageSize - 1)); + + detail::ExpanderModd prng(mSeed, mCodeSize); + + auto rInput = ctx.template restrictPtr(input); + auto rOutput = ctx.template restrictPtr(output); + + auto main = mMessageSize / 8 * 8; + u64 i = 0; + + for (; i < main; i += 8, rOutput+= 8) + { + if constexpr (Add == false) + { + ctx.zero(rOutput, rOutput + 8); + } + + for (auto j = 0ull; j < mExpanderWeight; ++j) + { + u64 rr[8]; + rr[0] = prng.get(); + rr[1] = prng.get(); + rr[2] = prng.get(); + rr[3] = prng.get(); + rr[4] = prng.get(); + rr[5] = prng.get(); + rr[6] = prng.get(); + rr[7] = prng.get(); + + ctx.plus(*(rOutput + 0), *(rOutput + 0), *(rInput + rr[0])); + ctx.plus(*(rOutput + 1), *(rOutput + 1), *(rInput + rr[1])); + ctx.plus(*(rOutput + 2), *(rOutput + 2), *(rInput + rr[2])); + ctx.plus(*(rOutput + 3), *(rOutput + 3), *(rInput + rr[3])); + ctx.plus(*(rOutput + 4), *(rOutput + 4), *(rInput + rr[4])); + ctx.plus(*(rOutput + 5), *(rOutput + 5), *(rInput + rr[5])); + ctx.plus(*(rOutput + 6), *(rOutput + 6), *(rInput + rr[6])); + ctx.plus(*(rOutput + 7), *(rOutput + 7), *(rInput + rr[7])); + } + } + + if constexpr (Add == false) + { + ctx.zero(rOutput, rOutput + (mMessageSize - i)); + } + + for (; i < mMessageSize; ++i, ++rOutput) + { + for (auto j = 0ull; j < mExpanderWeight; ++j) + { + ctx.plus(*rOutput, *rOutput, *(input + prng.get())); + } + } + } + + //template< + // bool Add, + // typename CoeffCtx, + // typename... F, + // typename... SrcDstIterPair + //> + //void ExpanderCode::expandMany( + // std::tuple inOuts, + // CoeffCtx ctx)const + //{ + + // std::apply([&](auto&&... inOut) {( + // [&] { + // (void)*(std::get<0>(inOut) + (mCodeSize - 1)); + // (void)*(std::get<1>(inOut) + (mMessageSize - 1)); + // }(), ...); }, inOuts); + + + // detail::ExpanderModd prng(mSeed, mCodeSize); + + // auto main = mMessageSize / 8 * 8; + // u64 i = 0; + + // std::vector rr(8 * mExpanderWeight); + + // for (; i < main; i += 8) + // { + // for (auto j = 0ull; j < mExpanderWeight; ++j) + // { + // rr[j * 8 + 0] = prng.get(); + // rr[j * 8 + 1] = prng.get(); + // rr[j * 8 + 2] = prng.get(); + // rr[j * 8 + 3] = prng.get(); + // rr[j * 8 + 4] = prng.get(); + // rr[j * 8 + 5] = prng.get(); + // rr[j * 8 + 6] = prng.get(); + // rr[j * 8 + 7] = prng.get(); + // } + + // std::apply([&](auto&&... inOut) {([&] { + + // auto& input = std::get<0>(inOut); + // auto& output = std::get<1>(inOut); + + // if constexpr (Add == false) + // { + // ctx.zero(output, output + 8); + // } + + // for (auto j = 0ull; j < mExpanderWeight; ++j) + // { + // ctx.plus(*(output + 0), *(output + 0), *(input + rr[j * 8 + 0])); + // ctx.plus(*(output + 1), *(output + 1), *(input + rr[j * 8 + 1])); + // ctx.plus(*(output + 2), *(output + 2), *(input + rr[j * 8 + 2])); + // ctx.plus(*(output + 3), *(output + 3), *(input + rr[j * 8 + 3])); + // ctx.plus(*(output + 4), *(output + 4), *(input + rr[j * 8 + 4])); + // ctx.plus(*(output + 5), *(output + 5), *(input + rr[j * 8 + 5])); + // ctx.plus(*(output + 6), *(output + 6), *(input + rr[j * 8 + 6])); + // ctx.plus(*(output + 7), *(output + 7), *(input + rr[j * 8 + 7])); + // } + + // output += 8; + // }(), ...); }, inOuts); + + // } + + // for (auto j = 0ull; j < mExpanderWeight; ++j) + // { + // rr[j * 8 + 0] = prng.get(); + // rr[j * 8 + 1] = prng.get(); + // rr[j * 8 + 2] = prng.get(); + // rr[j * 8 + 3] = prng.get(); + // rr[j * 8 + 4] = prng.get(); + // rr[j * 8 + 5] = prng.get(); + // rr[j * 8 + 6] = prng.get(); + // rr[j * 8 + 7] = prng.get(); + // } + + // std::apply([&](auto&&... inOut) {([&] { + + // auto& input = std::get<0>(inOut); + // auto& output = std::get<1>(inOut); + // if constexpr (Add == false) + // { + // ctx.zero(output, output + (mMessageSize - i)); + // } + + // for (u64 k = 0; i < mMessageSize; ++i, ++output, ++k) + // { + // for (auto j = 0ull; j < mExpanderWeight; ++j) + // { + // ctx.plus(*output, *output, *(input + rr[j*8 + k])); + // } + // } + // }(), ...); }, inOuts); + //} + + + +} diff --git a/libOTe/Tools/LDPC/LdpcDecoder.cpp b/libOTe/Tools/LDPC/LdpcDecoder.cpp deleted file mode 100644 index c0fb7edd..00000000 --- a/libOTe/Tools/LDPC/LdpcDecoder.cpp +++ /dev/null @@ -1,613 +0,0 @@ -#include "LdpcDecoder.h" -#ifdef ENABLE_LDPC - -#include -#include "Mtx.h" -#include "LdpcEncoder.h" -#include "LdpcSampler.h" -#include "Util.h" -#include -#include -#include -#include - -#include "LdpcImpulseDist.h" -namespace osuCrypto { - - - auto nan = std::nan(""); - - void LdpcDecoder::init(SparseMtx& H) - { - mH = H; - auto n = mH.cols(); - auto m = mH.rows(); - - assert(n > m); - mK = n - m; - - mR.resize(m, n); - mM.resize(m, n); - - mW.resize(n); - } - - - std::vector LdpcDecoder::bpDecode(span codeword, u64 maxIter) - { - auto n = mH.cols(); - - - assert(codeword.size() == n); - - std::array wVal{ { mP / (1 - mP), (1 - mP) / mP} }; - - // #1 - for (u64 i = 0; i < n; ++i) - { - assert(codeword[i] < 2); - mW[i] = wVal[codeword[i]]; - } - - return bpDecode(mW); - } - - std::vector LdpcDecoder::bpDecode(span lr, u64 maxIter) - { - - - auto n = mH.cols(); - auto m = mH.rows(); - - // #1 - for (u64 i = 0; i < n; ++i) - { - //assert(codeword[i] < 2); - mW[i] = lr[i]; - - for (auto j : mH.mCols[i]) - { - mR(j, i) = mW[i]; - } - } - - std::vector c(n); - std::vector rr; rr.reserve(100); - for (u64 ii = 0; ii < maxIter; ii++) - { - // #2 - for (u64 j = 0; j < m; ++j) - { - rr.resize(mH.mRows[j].size()); - for (u64 i : mH.mRows[j]) - { - // \Pi_{k in Nj \ {i} } (r_k^j + 1)/(r_k^j - 1) - double v = 1; - auto jj = 0; - for (u64 k : mH.mRows[j]) - { - rr[jj++] = mR(j, k); - if (k != i) - { - auto r = mR(j, k); - v *= (r + 1) / (r - 1); - } - } - - // m_j^i - auto mm = (v + 1) / (v - 1); - mM(j, i) = mm; - } - } - - // i indexes a column, [1,...,n] - for (u64 i = 0; i < n; ++i) - { - // j indexes a row, [1,...,m] - for (u64 j : mH.mCols[i]) - { - // r_i^j = w_i * Pi_{k in Ni \ {j} } m_k^i - mR(j, i) = mW[i]; - - // j indexes a row, [1,...,m] - for (u64 k : mH.mCols[i]) - { - if (k != j) - { - mR(j, i) *= mM(k, i); - } - } - } - } - - mL.resize(n); - // i indexes a column, [1,...,n] - for (u64 i = 0; i < n; ++i) - { - //L(ci | wi, m^i) - mL[i] = mW[i]; - - // k indexes a row, [1,...,m] - for (u64 k : mH.mCols[i]) - { - assert(mM(k, i) != nan); - mL[i] *= mM(k, i); - } - - c[i] = (mL[i] >= 1) ? 0 : 1; - - - mL[i] = std::log(mL[i]); - } - - if (check(c)) - { - c.resize(n - m); - return c; - } - } - - return {}; - } - - double sgn(double x) - { - if (x >= 0) - return 1; - return -1; - } - - u8 sgnBool(double x) - { - if (x >= 0) - return 0; - return 1; - } - - double phi(double x) - { - assert(x > 0); - x = std::min(20.0, x); - auto t = std::tanh(x * 0.5); - return -std::log(t); - } - - std::ostream& operator<<(std::ostream& o, const Matrix& m) - { - for (u64 i = 0; i < m.rows(); ++i) - { - for (u64 j = 0; j < m.cols(); ++j) - { - o << std::setw(4) << std::setfill(' ') << m(i, j) << " "; - } - - o << std::endl; - } - - return o; - } - - - std::vector LdpcDecoder::logbpDecode2(span codeword, u64 maxIter) - { - - auto n = mH.cols(); - //auto m = mH.rows(); - - assert(codeword.size() == n); - - std::array wVal{ - {std::log(mP / (1 - mP)), - std::log((1 - mP) / mP) - } - }; - - std::vector w(n); - for (u64 i = 0; i < n; ++i) - w[i] = wVal[codeword[i]]; - - return logbpDecode2(w, maxIter); - - } - std::vector LdpcDecoder::logbpDecode2(span llr, u64 maxIter) - { - auto n = mH.cols(); - auto m = mH.rows(); - std::vector c(n); - mL.resize(c.size()); - - - for (u64 i = 0; i < n; ++i) - { - mW[i] = llr[i]; - - for (auto j : mH.mCols[i]) - { - mR(j, i) = mW[i]; - } - } - - for (u64 ii = 0; ii < maxIter; ii++) - { - // #2 - for (u64 j = 0; j < m; ++j) - { - double v = 0; - u8 s = 1; - for (u64 k : mH.mRows[j]) - { - auto rr = mR(j, k); - v += phi(abs(rr)); - s ^= sgnBool(rr); - } - - - for (u64 k : mH.mRows[j]) - { - auto vv = phi(abs(mR(j, k))); - auto ss = sgnBool(mR(j, k)); - vv = phi(v - vv); - - mM(j, k) = (s ^ ss) ? vv : -vv; - } - } - - // i indexes a column, [1,...,n] - for (u64 i = 0; i < n; ++i) - { - mL[i] = mW[i]; - for (u64 k : mH.mCols[i]) - { - mL[i] += mM(k, i); - } - for (u64 k : mH.mCols[i]) - { - // r_i^j = w_i * Pi_{k in Ni \ {j} } m_k^i - mR(k, i) = mL[i] - mM(k, i); - } - - c[i] = (mL[i] >= 0) ? 0 : 1; - } - - if (check(c)) - { - if (mAllowZero == false && isZero(c)) - continue; - - c.resize(n - m); - return c; - } - } - - return {}; - } - - std::vector LdpcDecoder::altDecode(span codeword, bool minSum, u64 maxIter) - { - auto _N = mH.cols(); - std::array wVal{ - {std::log(mP / (1 - mP)), - std::log((1 - mP) / mP) - } - }; - - std::vector w(_N); - for (u64 i = 0; i < _N; ++i) - { - w[i] = wVal[codeword[i]]; - } - return altDecode(w, minSum, maxIter); - } - - std::vector LdpcDecoder::altDecode(span w, bool min_sum, u64 maxIter) - { - - auto _N = mH.cols(); - auto _M = mH.rows(); - - for (u64 i = 0; i < _N; ++i) - { - mW[i] = w[i]; - } - - mL = mW; - std::vector decoded_cw(_N); - - - std::vector > forward_msg(_M); - std::vector > back_msg(_M); - for (u64 r = 0; r < _M; ++r) { - forward_msg[r].resize(mH.row(r).size()); - back_msg[r].resize(mH.row(r).size()); - } - auto maxLL = 20.0; - - for (u64 iter = 0; iter < maxIter; ++iter) { - - for (u64 r = 0; r < _M; ++r) { - - for (u64 c1 = 0; c1 < (u64)mH.row(r).size(); ++c1) { - double tmp = 1; - if (min_sum) - tmp = maxLL; - - for (u64 c2 = 0; c2 < (u64)mH.row(r).size(); ++c2) { - if (c1 == c2) - continue; - - auto i_col2 = mH.row(r)[c2]; - - double l1 = mL[i_col2] - back_msg[r][c2]; - l1 = std::min(l1, maxLL); - l1 = std::max(l1, -maxLL); - - if (min_sum) { - double sign_tmp = tmp < 0 ? -1 : 1; - double sign_l1 = l1 < 0.0 ? -1 : 1; - - tmp = sign_tmp * sign_l1 * std::min(std::abs(l1), std::abs(tmp)); - } - else - tmp = tmp * tanh(l1 / 2); - } - - - if (min_sum) { - forward_msg[r][c1] = tmp; - } - else { - forward_msg[r][c1] = 2 * atanh(tmp); - } - } - } - - back_msg = forward_msg; - - mL = mW; - - for (u64 r = 0; r < _M; ++r) { - - for (u64 i = 0; i < (u64)mH.row(r).size(); ++i) { - auto c = mH.row(r)[i]; - mL[c] += back_msg[r][i]; - } - } - - for (u64 c = 0; c < _N; ++c) { - decoded_cw[c] = mL[c] > 0 ? 0 : 1; - } - - if (check(decoded_cw)) { - decoded_cw.resize(_N - _M); - return decoded_cw; - } - - } // Iteration loop end - - return {}; - - //} - - } - - std::vector LdpcDecoder::minSumDecode(span codeword, u64 maxIter) - { - - auto n = mH.cols(); - auto m = mH.rows(); - - assert(codeword.size() == n); - - std::array wVal{ - {std::log(mP / (1 - mP)), - std::log((1 - mP) / mP)} }; - - auto nan = std::nan(""); - std::fill(mR.begin(), mR.end(), nan); - std::fill(mM.begin(), mM.end(), nan); - - // #1 - for (u64 i = 0; i < n; ++i) - { - assert(codeword[i] < 2); - mW[i] = wVal[codeword[i]]; - - for (auto j : mH.mCols[i]) - { - mR(j, i) = mW[i]; - } - } - - std::vector c(n); - std::vector rr; rr.reserve(100); - for (u64 ii = 0; ii < maxIter; ii++) - { - // #2 - for (u64 j = 0; j < m; ++j) - { - rr.resize(mH.mRows[j].size()); - for (u64 i : mH.mRows[j]) - { - // \Pi_{k in Nj \ {i} } (r_k^j + 1)/(r_k^j - 1) - double v = std::numeric_limits::max(); - double s = 1; - - for (u64 k : mH.mRows[j]) - { - if (k != i) - { - assert(mR(j, k) != nan); - - v = std::min(v, std::abs(mR(j, k))); - - s *= sgn(mR(j, k)); - } - } - - // m_j^i - mM(j, i) = s * v; - } - } - - // i indexes a column, [1,...,n] - for (u64 i = 0; i < n; ++i) - { - // j indexes a row, [1,...,m] - for (u64 j : mH.mCols[i]) - { - // r_i^j = w_i * Pi_{k in Ni \ {j} } m_k^i - mR(j, i) = mW[i]; - - // j indexes a row, [1,...,m] - for (u64 k : mH.mCols[i]) - { - if (k != j) - { - assert(mM(k, i) != nan); - mR(j, i) += mM(k, i); - } - } - } - } - mL.resize(n); - // i indexes a column, [1,...,n] - for (u64 i = 0; i < n; ++i) - { - //log L(ci | wi, m^i) - mL[i] = mW[i]; - - // k indexes a row, [1,...,m] - for (u64 k : mH.mCols[i]) - { - assert(mM(k, i) != nan); - mL[i] += mM(k, i); - } - - c[i] = (mL[i] >= 0) ? 0 : 1; - } - - if (check(c)) - { - c.resize(n - m); - return c; - } - } - - return {}; - } - - bool LdpcDecoder::check(const span& data) { - - // j indexes a row, [1,...,m] - for (u64 j = 0; j < mH.rows(); ++j) - { - u8 sum = 0; - - // i indexes a column, [1,...,n] - for (u64 i : mH.mRows[j]) - { - sum ^= data[i]; - } - - if (sum) - { - return false; - } - } - return true; - - } - - void tests::LdpcDecode_pb_test(const oc::CLP& cmd) - { - u64 rows = cmd.getOr("r", 40); - u64 cols = static_cast(rows * cmd.getOr("e", 2.0)); - u64 colWeight = cmd.getOr("cw", 3); - u64 dWeight = cmd.getOr("dw", 3); - u64 gap = cmd.getOr("g", 2); - - auto k = cols - rows; - - SparseMtx H; - LdpcEncoder E; - LdpcDecoder D; - - for (u64 i = 0; i < 2; ++i) - { - oc::PRNG prng(block(i, 1)); - bool b = true; - u64 tries = 0; - while (b) - { - H = sampleTriangularBand(rows, cols, - colWeight, gap, dWeight, false, prng); - // H = sampleTriangular(rows, cols, colWeight, gap, prng); - b = !E.init(H, gap); - - ++tries; - } - - D.init(H); - std::vector m(k), m2, code(cols); - - for (auto& mm : m) - mm = prng.getBit(); - - E.encode(code, m); - auto ease = 1ull; - - - u64 min = 9999999; - u64 ee = 3; - while (true) - { - auto c = code; - for (u64 j = 0; j < ease; ++j) - { - c[j] ^= 1; - } - - u64 e = 0; - m2 = D.logbpDecode2(c); - - - if (m2 != m) - { - ++e; - min = std::min(min, ease); - } - m2 = D.altDecode(c, false); - - if (m2 != m) - { - ++e; - min = std::min(min, ease); - } - - - m2 = D.altDecode(c, true); - - if (m2 != m) - { - min = std::min(min, ease); - ++e; - } - if (e == ee) - break; - ++ease; - } - if (ease < 4 || min < 4) - { - throw std::runtime_error(LOCATION); - } - - //std::cout << "high " << ease << std::endl; - } - return; - - } - - - -} -#endif \ No newline at end of file diff --git a/libOTe/Tools/LDPC/LdpcDecoder.h b/libOTe/Tools/LDPC/LdpcDecoder.h deleted file mode 100644 index 9b7009a3..00000000 --- a/libOTe/Tools/LDPC/LdpcDecoder.h +++ /dev/null @@ -1,121 +0,0 @@ -#pragma once -// © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). -#include "libOTe/config.h" - -#ifdef ENABLE_LDPC - -#include -#include "cryptoTools/Common/Defines.h" -#include "cryptoTools/Common/Matrix.h" -#include "cryptoTools/Common/CLP.h" -#include "Mtx.h" -#include -namespace osuCrypto -{ - - class LdpcDecoder - { - public: - u64 mK = 0; - - bool mAllowZero = true; - double mP = 0.9; - Matrix mM, mR; - - std::vector> mMM, mRR; - std::vector mMData, mRData; - std::vector mW, mL; - - SparseMtx mH; - - LdpcDecoder() = default; - LdpcDecoder(const LdpcDecoder&) = default; - LdpcDecoder(LdpcDecoder&&) = default; - - - LdpcDecoder(SparseMtx& H) - { - init(H); - } - - void init(SparseMtx& H); - - std::vector bpDecode(span codeword, u64 maxIter = 1000); - std::vector logbpDecode2(span codeword, u64 maxIter = 1000); - std::vector altDecode(span codeword, bool minSum, u64 maxIter = 1000); - std::vector minSumDecode(span codeword, u64 maxIter = 1000); - - std::vector bpDecode(span codeword, u64 maxIter = 1000); - std::vector logbpDecode2(span codeword, u64 maxIter = 1000); - std::vector altDecode(span codeword, bool minSum, u64 maxIter = 1000); - - std::vector decode(span codeword, u64 maxIter = 1000) - { - return logbpDecode2(codeword, maxIter); - } - - bool check(const span& data); - static bool isZero(span data) - { - for (auto d : data) - if (d) return false; - return true; - } - - - inline static double LLR(double d) - { - assert(d > -1 && d < 1); - return std::log(d / (1 - d)); - } - inline static double LR(double d) - { - assert(d > -1 && d < 1); - return (d / (1 - d)); - } - - - inline static double encodeLLR(double p, bool bit) - { - assert(p >= 0.5); - assert(p < 1); - - p = bit ? (1 - p) : p; - - return LLR(p); - } - - inline static double encodeLR(double p, bool bit) - { - assert(p > 0.5); - assert(p < 1); - - p = bit ? (1 - p) : p; - - return LR(p); - } - - inline static u32 decodeLLR(double l) - { - return (l >= 0 ? 0 : 1); - } - - }; - - - - namespace tests - { - void LdpcDecode_pb_test(const oc::CLP& cmd); - - } - -} -#endif \ No newline at end of file diff --git a/libOTe/Tools/LDPC/LdpcEncoder.cpp b/libOTe/Tools/LDPC/LdpcEncoder.cpp deleted file mode 100644 index fb570cbc..00000000 --- a/libOTe/Tools/LDPC/LdpcEncoder.cpp +++ /dev/null @@ -1,975 +0,0 @@ -#include "LdpcEncoder.h" -#ifdef ENABLE_INSECURE_SILVER - -//#include -#include -#include "cryptoTools/Crypto/PRNG.h" -#include "cryptoTools/Common/Timer.h" -#include "LdpcSampler.h" -#include "libOTe/Tools/Tools.h" -namespace osuCrypto -{ - namespace details - { - constexpr std::array, 16> SilverRightEncoder::diagMtx_g16_w5_seed1_t36; - constexpr std::array, 32> SilverRightEncoder::diagMtx_g32_w11_seed2_t36; - constexpr std::array SilverRightEncoder::mOffsets; - } - - bool LdpcEncoder::init(SparseMtx H, u64 gap) - { - -#ifndef NDEBUG - for (u64 i = H.cols() - H.rows() + gap, j = 0; i < H.cols(); ++i, ++j) - { - auto row = H.row(j); - assert(row[row.size() - 1] == i); - } -#endif - auto c0 = H.cols() - H.rows(); - auto c1 = c0 + gap; - auto r0 = H.rows() - gap; - - mN = H.cols(); - mM = H.rows(); - mGap = gap; - - - mA = H.subMatrix(0, 0, r0, c0); - mB = H.subMatrix(0, c0, r0, gap); - mC = H.subMatrix(0, c1, r0, H.rows() - gap); - mD = H.subMatrix(r0, 0, gap, c0); - mE = H.subMatrix(r0, c0, gap, gap); - mF = H.subMatrix(r0, c1, gap, H.rows() - gap); - mH = std::move(H); - - mCInv.init(mC); - - if (mGap) - { - SparseMtx CB; - - // CB = C^-1 B - mCInv.mult(mB, CB); - - //assert(mC.invert().mult(mB) == CB); - // Ep = F C^-1 B - mEp = mF.mult(CB); - //// Ep = F C^-1 B + E - mEp += mE; - mEp = mEp.invert(); - - return (mEp.rows() != 0); - } - - return true; - } - - void LdpcEncoder::encode(span c, span mm) - { - assert(mm.size() == mM); - assert(c.size() == mN); - - auto s = mM - mGap; - auto iter = c.begin() + mM; - span m(c.begin(), iter); - span p(iter, iter + mGap); iter += mGap; - span pp(iter, c.end()); - - - // m = mm - std::copy(mm.begin(), mm.end(), m.begin()); - std::fill(c.begin() + mM, c.end(), 0); - - // pp = A * m - mA.multAdd(m, pp); - - if (mGap) - { - std::vector t(s); - - // t = C^-1 pp = C^-1 A m - mCInv.mult(pp, t); - - // p = - F t + D m = -F C^-1 A m + D m - mF.multAdd(t, p); - mD.multAdd(m, p); - - // p = - Ep p = -Ep (-F C^-1 A m + D m) - t = mEp.mult(p); - std::copy(t.begin(), t.end(), p.begin()); - - // pp = pp + B p - mB.multAdd(p, pp); - } - - // pp = C^-1 pp - mCInv.mult(pp, pp); - } - - namespace details - { - - void DiagInverter::init(const SparseMtx& c) - { - mC = (&c); - assert(mC->rows() == mC->cols()); - -#ifndef NDEBUG - for (u64 i = 0; i < mC->rows(); ++i) - { - auto row = mC->row(i); - assert(row.size() && row[row.size() - 1] == i); - - for (u64 j = 0; j < row.size() - 1; ++j) - { - assert(row[j] < row[j + 1]); - } - } -#endif - } - - - std::vector DiagInverter::getSteps() - { - std::vector steps; - - u64 n = mC->cols(); - u64 nn = mC->cols() * 2; - - for (u64 i = 0; i < mC->rows(); ++i) - { - auto row = mC->row(i); - PointList points(nn, nn); - - points.push_back({ i, n + i }); - assert(row[row.size() - 1] == i); - for (u64 j = 0; j < (u64)row.size() - 1; ++j) - { - points.push_back({ i,row[j] }); - } - - for (u64 j = 0; j < i; ++j) - { - points.push_back({ j,j }); - } - - for (u64 j = 0; j < n; ++j) - { - points.push_back({ n + j, n + j }); - } - steps.emplace_back(nn, nn, points); - - } - - return steps; - } - - // computes x = mC^-1 * y - void DiagInverter::mult(span y, span x) - { - // solves for x such that y = M x, ie x := H^-1 y - assert(mC); - assert(mC->rows() == y.size()); - assert(mC->cols() == x.size()); - - for (u64 i = 0; i < mC->rows(); ++i) - { - auto row = mC->row(i); - x[i] = y[i]; - - assert(row[row.size() - 1] == i); - for (u64 j = 0; j < (u64)row.size() - 1; ++j) - { - x[i] ^= x[row[j]]; - } - } - } - void DiagInverter::mult(const SparseMtx& y, SparseMtx& x) - { - auto n = mC->rows(); - assert(n == y.rows()); - //assert(n == x.rows()); - //assert(y.cols() == x.cols()); - - auto xNumRows = n; - auto xNumCols = y.cols(); - - std::vector& xCol = x.mDataCol; xCol.reserve(y.mDataCol.size()); - std::vector - colSizes(xNumCols), - rowSizes(xNumRows); - - for (u64 c = 0; c < y.cols(); ++c) - { - auto cc = y.col(c); - auto yIter = cc.begin(); - auto yEnd = cc.end(); - - auto xColBegin = xCol.size(); - for (u64 i = 0; i < n; ++i) - { - u8 bit = 0; - if (yIter != yEnd && *yIter == i) - { - bit = 1; - ++yIter; - } - - auto rr = mC->row(i); - auto mIter = rr.begin(); - auto mEnd = rr.end() - 1; - - auto xIter = xCol.begin() + xColBegin; - auto xEnd = xCol.end(); - - while (mIter != mEnd && xIter != xEnd) - { - if (*mIter < *xIter) - ++mIter; - else if (*xIter < *mIter) - ++xIter; - else - { - bit ^= 1; - ++xIter; - ++mIter; - } - } - - if (bit) - { - xCol.push_back(i); - ++rowSizes[i]; - } - } - colSizes[c] = xCol.size(); - } - - x.mCols.resize(colSizes.size()); - auto iter = xCol.begin(); - for (u64 i = 0; i < colSizes.size(); ++i) - { - auto end = xCol.begin() + colSizes[i]; - x.mCols[i] = end - iter ? SparseMtx::Col(span(iter, end)) : SparseMtx::Col{}; - iter = end; - } - - x.mRows.resize(rowSizes.size()); - x.mDataRow.resize(x.mDataCol.size()); - iter = x.mDataRow.begin(); - //auto prevSize = 0ull; - for (u64 i = 0; i < rowSizes.size(); ++i) - { - auto end = iter + rowSizes[i]; - - rowSizes[i] = 0; - //auto ss = rowSizes[i]; - //rowSizes[i] = rowSizes[i] - prevSize; - //prevSize = ss; - - x.mRows[i] = SparseMtx::Row(iter != end ? span(iter, end) : span{}); - iter = end; - } - - iter = xCol.begin(); - for (u64 i = 0; i < x.cols(); ++i) - { - for (u64 j : x.col(i)) - { - x.mRows[j][rowSizes[j]++] = i; - } - } - - } - - - - void SilverLeftEncoder::init(u64 rows, std::vector rs, u64 rep) - { - mRows = rows; - mWeight = rs.size(); - assert(mWeight > 4); - - mRs = rs; - mYs.resize(rs.size()); - std::set s; - std::vector reps(rep); - - u64 trials = 0; - for (u64 i = 0; i < mWeight; ++i) - { - mYs[i] = u64(rows * rs[i]) % rows; - while ( - //(rep && reps[mYs[i] % rep]) || - (rep && mYs[i] % rep) || - s.insert(mYs[i]).second == false) - { - mYs[i] = (mYs[i] + 1) % rows; - - if (++trials > 1000) - { - std::cout << "these ratios resulted in too many collisions. " LOCATION << std::endl; - throw std::runtime_error("these ratios resulted in too many collisions. " LOCATION); - } - } - - if(rep) - reps[mYs[i] % rep] = 1; - } - } - - void SilverLeftEncoder::init(u64 rows, SilverCode code, u64 rep) - { - auto weight = code.weight(); - switch (weight) - { - case 5: - init(rows, { { 0, 0.372071, 0.576568, 0.608917, 0.854475} }, rep); - - // 0 0.0494143 0.437702 0.603978 0.731941 - - // yset 3,785 - // 0 0.372071 0.576568 0.608917 0.854475 - break; - case 11: - init(rows, { { 0, 0.00278835, 0.0883852, 0.238023, 0.240532, 0.274624, 0.390639, 0.531551, 0.637619, 0.945265, 0.965874} }, rep); - // 0 0.00278835 0.0883852 0.238023 0.240532 0.274624 0.390639 0.531551 0.637619 0.945265 0.965874 - break; - default: - // no preset parameters - throw RTE_LOC; - } - } - - void SilverLeftEncoder::encode(span pp, span m) - { - auto cols = mRows; - assert(pp.size() == mRows); - assert(m.size() == cols); - - // pp = pp + A * m - auto v = mYs; - for (u64 i = 0; i < cols; ++i) - { - for (u64 j = 0; j < mWeight; ++j) - { - auto row = v[j]; - pp[row] ^= m[i]; - - ++v[j]; - if (v[j] == mRows) - v[j] = 0; - } - } - } - - - void SilverLeftEncoder::getPoints(PointList& points) - { - auto cols = mRows; - auto v = mYs; - - for (u64 i = 0; i < cols; ++i) - { - for (u64 j = 0; j < mWeight; ++j) - { - auto row = v[j]; - - points.push_back({ row, i }); - - ++v[j]; - if (v[j] == mRows) - v[j] = 0; - } - } - } - - SparseMtx SilverLeftEncoder::getMatrix() - { - PointList points(mRows, mRows); - getPoints(points); - return SparseMtx(mRows, mRows, points); - } - - void SilverLeftEncoder::getTransPoints(PointList& points) - { - - auto cols = mRows; - auto v = mYs; - - for (u64 i = 0; i < cols; ) - { - auto end = cols; - for (u64 j = 0; j < mWeight; ++j) - { - if (v[j] == mRows) - v[j] = 0; - - auto jEnd = cols - v[j] + i; - end = std::min(end, jEnd); - } - //T* __restrict P = &pp[i]; - //T* __restrict PE = &pp[end]; - - while (i != end) - { - points.push_back({ i,i }); - - for (u64 j = 0; j < mWeight; ++j) - { - auto row = v[j]; - points.push_back({ i, row + cols }); - ++v[j]; - } - ++i; - } - } - - } - - SparseMtx SilverLeftEncoder::getTransMatrix() - { - PointList points(mRows, 2 * mRows); - getTransPoints(points); - return points; - } - - void SilverRightEncoder::init(u64 rows, SilverCode c, bool extend) - { - mGap = c.gap(); - assert(mGap < rows); - mCode = c; - mRows = rows; - mExtend = extend; - mCols = extend ? rows : rows - mGap; - } - - void SilverRightEncoder::encode(span x, span y) - { - assert(mExtend); - for (u64 i = 0; i < mRows; ++i) - { - x[i] = y[i]; - if (mCode == SilverCode::Weight5) - { - for (u64 j = 0; j < 4; ++j) - { - auto col = i - 16 + diagMtx_g16_w5_seed1_t36[i & 15][j]; - if (col < mRows) - x[i] = x[i] ^ x[col]; - } - } - - if (mCode == SilverCode::Weight11) - { - for (u64 j = 0; j < 10; ++j) - { - auto col = i - 32 + diagMtx_g32_w11_seed2_t36[i & 31][j]; - if (col < mRows) - x[i] = x[i] ^ x[col]; - } - } - - for (u64 j = 0; j < mOffsets.size(); ++j) - { - auto p = i - mOffsets[j] - mGap; - if (p >= mRows) - break; - x[i] = x[i] ^ x[p]; - } - } - } - bool gVerbose = false; - void SilverRightEncoder::getPoints(PointList& points, u64 colOffset) - { - auto rr = mRows; - - for (u64 i = 0; i < rr; ++i) - { - if (i < mCols) - points.push_back({ i, i + colOffset }); - - switch (mCode) - { - case SilverCode::Weight5: - - for (u64 j = 0; j < 4; ++j) - { - auto col = i - 16 + diagMtx_g16_w5_seed1_t36[i & 15][j]; - if (col < mCols) - points.push_back({ i, col + colOffset }); - } - - break; - case SilverCode::Weight11: - for (u64 j = 0; j < 10; ++j) - { - auto col = i - 32 + diagMtx_g32_w11_seed2_t36[i & 31][j]; - if (col < mCols) - points.push_back({ i, col + colOffset }); - } - - break; - default: - break; - } - - for (u64 j = 0; j < mOffsets.size(); ++j) - { - auto col = i - mOffsets[j] - mGap; - if (col < mRows) - points.push_back({ i, col + colOffset }); - } - - if(gVerbose) - std::cout << i << "\n" << SparseMtx(points) << std::endl; - - } - - if (mExtend) - { - for (u64 i = rr; i < mRows; ++i) - points.push_back({ i, i + colOffset }); - } - } - - SparseMtx SilverRightEncoder::getMatrix() - { - PointList points(mRows, cols()); - getPoints(points, 0); - return SparseMtx(mRows, cols(), points); - } - std::vector SilverRightEncoder::getTransMatrices() - { - std::vector ret;ret.reserve(mRows); - auto colOffset = 0; - auto rr = mRows; - - for (u64 i = 0; i < rr; ++i) - { - PointList points(mRows, mRows); - for (u64 j = 0; j < mRows; ++j) - points.push_back({ j,j }); - - switch (mCode) - { - case SilverCode::Weight5: - - for (u64 j = 0; j < 4; ++j) - { - auto col = i - 16 + diagMtx_g16_w5_seed1_t36[i & 15][j]; - if (col < mCols) - points.push_back({ i, col + colOffset }); - } - - break; - case SilverCode::Weight11: - for (u64 j = 0; j < 10; ++j) - { - auto col = i - 32 + diagMtx_g32_w11_seed2_t36[i & 31][j]; - if (col < mCols) - points.push_back({ i, col + colOffset }); - } - - break; - default: - break; - } - - for (u64 j = 0; j < mOffsets.size(); ++j) - { - auto col = i - mOffsets[j] - mGap; - if (col < mRows) - points.push_back({ i, col + colOffset }); - } - - if (gVerbose) - std::cout << i << "\n" << SparseMtx(points) << std::endl; - - ret.emplace_back(points); - } - - //if (mExtend) - //{ - // for (u64 i = rr; i < mRows; ++i) - // points.push_back({ i, i + colOffset }); - //} - return ret; - } - SparseMtx SilverRightEncoder::getTransMatrix() - { - auto Es = getTransMatrices(); - SparseMtx E = Es.back(); - //std::cout << "Eb" << E << std::endl; - for (u64 i = Es.size() - 2; i < Es.size(); --i) - { - E = E * Es[i]; - } - - return E; - } - } - - - void tests::LdpcEncoder_diagonalSolver_test() - { - u64 n = 10; - u64 w = 4; - u64 t = 10; - - oc::PRNG prng(block(0, 0)); - std::vector x(n), y(n); - for (u64 tt = 0; tt < t; ++tt) - { - SparseMtx H = sampleTriangular(n, 0.5, prng); - - //std::cout << H << std::endl; - - for (auto& yy : y) - yy = prng.getBit(); - - details::DiagInverter HInv(H); - - HInv.mult(y, x); - - auto z = H.mult(x); - - assert(z == y); - - auto Y = sampleFixedColWeight(n, w, 3, prng, false); - - SparseMtx X; - - HInv.mult(Y, X); - - auto Z = H * X; - - assert(Z == Y); - - } - - - - - return; - } - - void tests::LdpcEncoder_encode_test() - { - - u64 rows = 16; - u64 cols = rows * 2; - u64 colWeight = 4; - u64 dWeight = 3; - u64 gap = 6; - - auto k = cols - rows; - - assert(gap >= dWeight); - - oc::PRNG prng(block(0, 2)); - - - SparseMtx H; - LdpcEncoder E; - - - //while (b) - for (u64 i = 0; i < 40; ++i) - { - bool b = true; - //std::cout << " +====================" << std::endl; - while (b) - { - H = sampleTriangularBand( - rows, cols, - colWeight, gap, - dWeight, false, prng); - //H = sampleTriangular(rows, cols, colWeight, gap, prng); - b = !E.init(H, gap); - } - - //std::cout << H << std::endl; - - std::vector m(k), c(cols); - - for (auto& mm : m) - mm = prng.getBit(); - - - E.encode(c, m); - - auto ss = H.mult(c); - - //for (auto sss : ss) - // std::cout << int(sss) << " "; - //std::cout << std::endl; - assert(ss == std::vector(H.rows(), 0)); - - } - return; - - } - - void tests::LdpcEncoder_encode_g0_test() - { - - u64 rows = 17; - u64 cols = rows * 2; - u64 colWeight = 4; - - auto k = cols - rows; - - oc::PRNG prng(block(0, 2)); - - - SparseMtx H; - LdpcEncoder E; - - - //while (b) - for (u64 i = 0; i < 40; ++i) - { - bool b = true; - //std::cout << " +====================" << std::endl; - while (b) - { - //H = sampleTriangularBand( - // rows, cols, - // colWeight, 0, - // 1, false, prng); - // - // - - - H = sampleTriangularBand( - rows, cols, - colWeight, 8, - colWeight, colWeight, 0, 0, { 5,31 }, true, true, true, prng, prng); - //H = sampleTriangular(rows, cols, colWeight, gap, prng); - b = !E.init(H, 0); - } - - //std::cout << H << std::endl; - - std::vector m(k), c(cols); - - for (auto& mm : m) - mm = prng.getBit(); - - - E.encode(c, m); - - auto ss = H.mult(c); - - assert(ss == std::vector(H.rows(), 0)); - - } - return; - } - - - void tests::LdpcS1Encoder_encode_test() - { - u64 rows = 100; - SilverCode weight = SilverCode::Weight5; - - details::SilverLeftEncoder zz; - zz.init(rows, weight); - - std::vector m(rows), pp(rows); - - PRNG prng(ZeroBlock); - - for (u64 i = 0; i < rows; ++i) - m[i] = prng.getBit(); - - zz.encode(pp, m); - - auto p2 = zz.getMatrix().mult(m); - - if (p2 != pp) - { - throw RTE_LOC; - } - - } - - - - void tests::LdpcS1Encoder_encode_Trans_test() - { - u64 rows = 100; - SilverCode weight = SilverCode::Weight5; - - - details::SilverLeftEncoder zz; - zz.init(rows, weight); - - std::vector m(rows), pp(rows); - - PRNG prng(ZeroBlock); - - for (u64 i = 0; i < rows; ++i) - m[i] = prng.getBit(); - - zz.dualEncode(pp, m); - - auto At = zz.getMatrix().dense().transpose().sparse(); - auto p2 = At.mult(m); - - //std::cout << "At\n" << At << std::endl; - //std::cout << "M\n" << zz.getTransMatrix() << std::endl; - - if (p2 != pp) - { - throw RTE_LOC; - } - } - - - void tests::LdpcComposit_RegRepDiagBand_encode_test() - { - u64 rows = 100; - - - PRNG prng(ZeroBlock); - using namespace details; - using Encoder = SilverEncoder; - - for (auto code : { SilverCode::Weight5 , SilverCode::Weight11 }) - { - - //{ - // gVerbose = false; - // details::SilverRightEncoder rr; - // rr.init(rows, code, true); - // auto RR = rr.getMatrix(); - // std::cout << "R\n" << RR << std::endl; - // gVerbose = false; - //} - - Encoder enc; - enc.mL.init(rows, code); - enc.mR.init(rows, code, true); - - auto H = enc.getMatrix(); - - LdpcEncoder enc2; - enc2.init(H, 0); - - auto cols = enc.cols(); - auto k = cols - rows; - std::vector m(k), c(cols), c2(cols); - - for (auto& mm : m) - mm = prng.getBit(); - - enc.encode(c, m); - enc2.encode(c2, m); - - auto ss = H.mult(c); - - if (ss != std::vector(H.rows(), 0)) - throw RTE_LOC; - if (c2 != c) - throw RTE_LOC; - - auto R = enc.mR.getMatrix(); - auto M = enc.mR.getTransMatrix(); - //std::cout << "R\n" << R << std::endl; - //std::cout << "M\n" << M << std::endl; - - auto g = SilverCode::gap(code); - auto d1 = enc.mR.mOffsets[0] + g; - auto d2 = enc.mR.mOffsets[1] + g; - - - for (u64 i = 0; i < R.cols() - g; ++i) - { - std::set ss; - - auto col = R.col(i); - for (auto cc : col) - { - ss.insert(cc); - } - - auto expSize = SilverCode::weight(code); - - if (d1 < R.rows()) - { - ++expSize; - if (ss.find(d1) == ss.end()) - throw RTE_LOC; - } - - if (d2 < R.rows()) - { - ++expSize; - if (ss.find(d2) == ss.end()) - throw RTE_LOC; - } - - if (col.size() != expSize) - throw RTE_LOC; - - ++d1; - ++d2; - - - } - } - } - - - void tests::LdpcComposit_RegRepDiagBand_Trans_test() - { - - u64 rows = 101; - - using namespace details; - SilverCode code = SilverCode::Weight5; - - - using Encoder = LdpcCompositEncoder; - PRNG prng(ZeroBlock); - - Encoder enc; - enc.mL.init(rows, code); - enc.mR.init(rows, code, true); - - auto H = enc.getMatrix(); - auto HD = H.dense(); - auto Gt = computeGen(HD).transpose(); - - - LdpcEncoder enc2; - enc2.init(H, 0); - - - auto cols = enc.cols(); - auto k = cols - rows; - - std::vector c(cols); - - for (auto& cc : c) - cc = prng.getBit(); - //std::cout << "\n"; - - auto mOld = c; - enc2.dualEncode(mOld); - mOld.resize(k); - - ////std::cout << "R\n" << enc.mR.getMatrix() << std::endl << std::endl; - //std::cout << "L\n" << enc.mL.getTransMatrix() << std::endl << std::endl; - - auto mCur = c; - enc.dualEncode(mCur); - mCur.resize(k); - } - - -} -#endif \ No newline at end of file diff --git a/libOTe/Tools/LDPC/LdpcEncoder.h b/libOTe/Tools/LDPC/LdpcEncoder.h deleted file mode 100644 index e1a53413..00000000 --- a/libOTe/Tools/LDPC/LdpcEncoder.h +++ /dev/null @@ -1,1175 +0,0 @@ -#pragma once -// © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). -#include "libOTe/config.h" -#ifdef ENABLE_INSECURE_SILVER - -#include "Mtx.h" -#include "cryptoTools/Crypto/PRNG.h" -#include "cryptoTools/Common/Timer.h" -#include - -namespace osuCrypto -{ - namespace details - { - class DiagInverter - { - public: - - const SparseMtx* mC = nullptr; - - DiagInverter() = default; - DiagInverter(const DiagInverter&) = default; - DiagInverter& operator=(const DiagInverter&) = default; - - DiagInverter(const SparseMtx& c) { init(c); } - - void init(const SparseMtx& c); - - // returns a list of matrices to multiply with to encode. - std::vector getSteps(); - - // computes x = mC^-1 * y - void mult(span y, span x); - - // computes x = mC^-1 * y - void mult(const SparseMtx& y, SparseMtx& x); - - template - void cirTransMult(span x, span y) - { - // solves for x such that y = M x, ie x := H^-1 y - assert(mC); - assert(mC->cols() == x.size()); - - for (u64 i = mC->rows() - 1; i != ~u64(0); --i) - { - auto row = mC->row(i); - assert(row[row.size() - 1] == i); - for (u64 j = 0; j < (u64)row.size() - 1; ++j) - { - auto col = row[j]; - assert(col < i); - x[col] = x[col] ^ x[i]; - } - } - } - }; - } - - // a generic encoder for any g-ALT LDPC code - class LdpcEncoder - { - public: - - LdpcEncoder() = default; - LdpcEncoder(const LdpcEncoder&) = default; - LdpcEncoder(LdpcEncoder&&) = default; - - - u64 mN, mM, mGap; - SparseMtx mA, mH; - SparseMtx mB; - SparseMtx mC; - SparseMtx mD; - SparseMtx mE, mEp; - SparseMtx mF; - details::DiagInverter mCInv; - - // initialize the encoder with the given matrix which is - // in gap-ALT form. - bool init(SparseMtx mtx, u64 gap); - - // encode the given message m and populate c with - // the resulting codeword. - void encode(span c, span m); - - // perform the circuit transpose of the encoding algorithm. - // the inputs and output is c. - template - void dualEncode(span c) - { - if (mGap) - throw std::runtime_error(LOCATION); - assert(c.size() == mN); - - auto k = mN - mM; - span pp(c.subspan(k, mM)); - span mm(c.subspan(0, k)); - - mCInv.cirTransMult(pp, mm); - - for (u64 i = 0; i < k; ++i) - { - for (auto row : mA.col(i)) - { - c[i] = c[i] ^ pp[row]; - } - } - } - }; - - // INSECURE - // enum struct to specify which silver code variant to use. - // https://eprint.iacr.org/2021/1150 - // see also https://eprint.iacr.org/2023/882 - struct SilverCode - { - enum code - { - Weight5 = 5, - Weight11 = 11, - }; - code mCode; - - SilverCode() = default; - SilverCode(const SilverCode&) = default; - SilverCode& operator=(const SilverCode&) = default; - SilverCode(const code& c) : mCode(c) {} - - bool operator==(code c) { return mCode == c; } - bool operator!=(code c) { return mCode != c; } - operator code() - { - return mCode; - } - - u64 weight() { return weight(mCode); } - u64 gap() { - return gap(mCode); - } - static u64 weight(code c) - { - return (u64)c; - } - static u64 gap(code c) - { - switch (c) - { - case Weight5: - return 16; - break; - case Weight11: - return 32; - break; - default: - throw RTE_LOC; - break; - } - } - }; - - namespace details - { - - // the silver encoder for the left half of the matrix. - // This part of the code constists of mWeight diagonals. - // The positions of these diagonals is determined by - class SilverLeftEncoder - { - public: - - u64 mRows, mWeight; - std::vector mYs; - std::vector mRs; - - // a custom initialization which the specifed diagonal - // factional positions. - void init(u64 rows, std::vector rs, u64 rep = 0); - - // initialize the left half with the given silver presets. - void init(u64 rows, SilverCode code, u64 rep = 0); - - // encode the given message m and populate c with - // the resulting codeword. - void encode(span pp, span m); - - u64 cols() { return mRows; } - - u64 rows() { return mRows; } - - // populates points with the matrix representation of this - // encoder. - void getPoints(PointList& points); - - // return parity check matrix representation of this encoder. - SparseMtx getMatrix(); - - - // return generator matrix representation of this encoder. - void getTransPoints(PointList& points); - - // return generator matrix representation of this encoder. - SparseMtx getTransMatrix(); - - // perform the circuit transpose of the encoding algorithm. - // the output it written to ppp. - template - void dualEncode(span ppp, span mm); - - // perform the circuit transpose of the encoding algorithm twice. - // the output it written to ppp0 and ppp1. - template - void dualEncode2( - span ppp0, span ppp1, - span mm0, span mm1); - }; - - - class SilverRightEncoder - { - public: - - static constexpr std::array, 16> diagMtx_g16_w5_seed1_t36 - { { - { {0, 4, 11, 15 }}, - { {0, 8, 9, 10 } }, - { {1, 2, 10, 14 } }, - { {0, 5, 8, 15 } }, - { {3, 13, 14, 15 } }, - { {2, 4, 7, 8 } }, - { {0, 9, 12, 15 } }, - { {1, 6, 8, 14 } }, - { {4, 5, 6, 14 } }, - { {1, 3, 8, 13 } }, - { {3, 4, 7, 8 } }, - { {3, 5, 9, 13 } }, - { {8, 11, 12, 14 } }, - { {6, 10, 12, 13 } }, - { {2, 7, 8, 13 } }, - { {0, 6, 10, 15 } } - } }; - - - static constexpr std::array, 32> diagMtx_g32_w11_seed2_t36 - { { - { { 6, 7, 8, 12, 16, 17, 20, 22, 24, 25 } }, - { { 0, 1, 6, 10, 12, 13, 17, 19, 30, 31 } }, - { { 1, 4, 7, 10, 12, 16, 21, 22, 30, 31 } }, - { { 3, 5, 9, 13, 15, 21, 23, 25, 26, 27 } }, - { { 3, 8, 9, 14, 17, 19, 24, 25, 26, 28 } }, - { { 3, 11, 12, 13, 14, 16, 17, 21, 22, 30 } }, - { { 2, 4, 5, 11, 12, 17, 22, 24, 30, 31 } }, - { { 5, 8, 11, 12, 13, 17, 18, 20, 27, 29 } }, - { {13, 16, 17, 18, 19, 20, 21, 22, 26, 30 } }, - { { 3, 8, 13, 15, 16, 17, 19, 20, 21, 27 } }, - { { 0, 2, 4, 5, 6, 21, 23, 26, 28, 30 } }, - { { 2, 4, 6, 8, 10, 11, 22, 26, 28, 30 } }, - { { 7, 9, 11, 14, 15, 16, 17, 18, 24, 30 } }, - { { 0, 3, 7, 12, 13, 18, 20, 24, 25, 28 } }, - { { 1, 5, 7, 8, 12, 13, 21, 24, 26, 27 } }, - { { 0, 16, 17, 19, 22, 24, 25, 27, 28, 31 } }, - { { 0, 6, 7, 15, 16, 18, 22, 24, 29, 30 } }, - { { 2, 3, 4, 7, 15, 17, 18, 20, 22, 26 } }, - { { 2, 3, 9, 16, 17, 19, 24, 27, 29, 31 } }, - { { 1, 3, 5, 7, 13, 14, 20, 23, 24, 27 } }, - { { 0, 2, 3, 9, 10, 14, 19, 20, 21, 25 } }, - { { 4, 13, 16, 20, 21, 23, 25, 27, 28, 31 } }, - { { 1, 2, 5, 6, 9, 13, 15, 17, 20, 24 } }, - { { 0, 4, 7, 8, 12, 13, 20, 23, 28, 30 } }, - { { 0, 3, 4, 5, 8, 9, 23, 25, 26, 28 } }, - { { 0, 3, 4, 7, 8, 10, 11, 15, 21, 26 } }, - { { 5, 6, 7, 8, 10, 11, 15, 21, 22, 25 } }, - { { 0, 1, 2, 3, 8, 9, 22, 24, 27, 28 } }, - { { 1, 2, 13, 14, 15, 16, 19, 22, 29, 30 } }, - { { 2, 14, 15, 16, 19, 20, 25, 26, 28, 29 } }, - { { 8, 9, 11, 12, 13, 15, 17, 18, 23, 27 } }, - { { 0, 2, 4, 5, 6, 7, 10, 12, 14, 19 } } - } }; - - static constexpr std::array mOffsets{ {5,31} }; - - u64 mGap; - - u64 mRows, mCols; - SilverCode mCode; - bool mExtend; - - // initialize the right half of the silver encoder - // with the given preset. extend should be true if - // this is used to encode and false to determine the - // effective minimum distance. - void init(u64 rows, SilverCode c, bool extend); - - u64 cols() { return mCols; } - - u64 rows() { return mRows; } - - // encode the given message m and populate c with - // the resulting codeword. - void encode(span x, span y); - - // populates points with the matrix representation of this - // encoder. - void getPoints(PointList& points, u64 colOffset); - - // return matrix representation of this encoder. - SparseMtx getMatrix(); - - - // return generator matrix representation of this encoder. - std::vector getTransMatrices(); - SparseMtx getTransMatrix(); - - - // perform the circuit transpose of the encoding algorithm. - // the inputs and output is x. - template - void dualEncode(span x); - - // perform the circuit transpose of the encoding algorithm twice. - // the inputs and output is x0 and x1. - template - void dualEncode2(span x0, span x1); - }; - - // a full encoder expressed and the left and right encoder. - template - class LdpcCompositEncoder : public TimerAdapter - { - public: - - LEncoder mL; - REncoder mR; - - template - void encode(spanc, span mm) - { - assert(mm.size() == cols() - rows()); - assert(c.size() == cols()); - - auto s = rows(); - auto iter = c.begin() + s; - span m(c.begin(), iter); - span pp(iter, c.end()); - - // m = mm - std::copy(mm.begin(), mm.end(), m.begin()); - std::fill(c.begin() + s, c.end(), 0); - - // pp = A * m - mL.encode(pp, mm); - - // pp = C^-1 pp - mR.encode(pp, pp); - } - - template - void dualEncode(span c) - { - auto k = cols() - rows(); - assert(c.size() == cols()); - setTimePoint("encode_begin"); - span pp(c.subspan(k, rows())); - - mR.template dualEncode(pp); - setTimePoint("diag"); - mL.template dualEncode(c.subspan(0, k), pp); - setTimePoint("L"); - - } - - - template - void dualEncode2(span c0, span c1) - { - auto k = cols() - rows(); - assert(c0.size() == cols()); - - setTimePoint("encode_begin"); - span pp0(c0.subspan(k, rows())); - span pp1(c1.subspan(k, rows())); - - mR.template dualEncode2(pp0, pp1); - - setTimePoint("diag"); - mL.template dualEncode2(c0.subspan(0, k), c1.subspan(0, k), pp0, pp1); - setTimePoint("L"); - } - - - u64 cols() { return mL.cols() + mR.cols(); } - - u64 rows() { return mR.rows(); } - - void getPoints(PointList& points) - { - mL.getPoints(points); - mR.getPoints(points, mL.cols()); - } - - SparseMtx getMatrix() - { - PointList points(rows(), cols()); - getPoints(points); - return SparseMtx(rows(), cols(), points); - } - - }; - - } - - // the full silver encoder which is composed - // of the left and right sub-encoders. - struct SilverEncoder : public details::LdpcCompositEncoder - { - void init(u64 rows, SilverCode code) - { - mL.init(rows, code); - mR.init(rows, code, true); - } - }; - - - namespace tests - { - - void LdpcEncoder_diagonalSolver_test(); - void LdpcEncoder_encode_test(); - void LdpcEncoder_encode_g0_test(); - - void LdpcS1Encoder_encode_test(); - void LdpcS1Encoder_encode_Trans_test(); - - void LdpcComposit_RegRepDiagBand_encode_test(); - void LdpcComposit_RegRepDiagBand_Trans_test(); - - } - - - - - - - - - - - - - - - - - - - - - - - - // perform the circuit transpose of the encoding algorithm. - // the output it written to ppp. - template - void details::SilverLeftEncoder::dualEncode(span ppp, span mm) - { - auto cols = mRows; - assert(ppp.size() == mRows); - assert(mm.size() == cols); - - auto v = mYs; - T* __restrict pp = ppp.data(); - const T* __restrict m = mm.data(); - - for (u64 i = 0; i < cols; ) - { - auto end = cols; - for (u64 j = 0; j < mWeight; ++j) - { - if (v[j] == mRows) - v[j] = 0; - - auto jEnd = cols - v[j] + i; - end = std::min(end, jEnd); - } - T* __restrict P = &pp[i]; - T* __restrict PE = &pp[end]; - - switch (mWeight) - { - case 5: - { - const T* __restrict M0 = &m[v[0]]; - const T* __restrict M1 = &m[v[1]]; - const T* __restrict M2 = &m[v[2]]; - const T* __restrict M3 = &m[v[3]]; - const T* __restrict M4 = &m[v[4]]; - - v[0] += end - i; - v[1] += end - i; - v[2] += end - i; - v[3] += end - i; - v[4] += end - i; - i = end; - - while (P != PE) - { - *P = *P - ^ *M0 - ^ *M1 - ^ *M2 - ^ *M3 - ^ *M4 - ; - - ++M0; - ++M1; - ++M2; - ++M3; - ++M4; - ++P; - } - - - break; - } - case 11: - { - - const T* __restrict M0 = &m[v[0]]; - const T* __restrict M1 = &m[v[1]]; - const T* __restrict M2 = &m[v[2]]; - const T* __restrict M3 = &m[v[3]]; - const T* __restrict M4 = &m[v[4]]; - const T* __restrict M5 = &m[v[5]]; - const T* __restrict M6 = &m[v[6]]; - const T* __restrict M7 = &m[v[7]]; - const T* __restrict M8 = &m[v[8]]; - const T* __restrict M9 = &m[v[9]]; - const T* __restrict M10 = &m[v[10]]; - - v[0] += end - i; - v[1] += end - i; - v[2] += end - i; - v[3] += end - i; - v[4] += end - i; - v[5] += end - i; - v[6] += end - i; - v[7] += end - i; - v[8] += end - i; - v[9] += end - i; - v[10] += end - i; - i = end; - - while (P != PE) - { - *P = *P - ^ *M0 - ^ *M1 - ^ *M2 - ^ *M3 - ^ *M4 - ^ *M5 - ^ *M6 - ^ *M7 - ^ *M8 - ^ *M9 - ^ *M10 - ; - - ++M0; - ++M1; - ++M2; - ++M3; - ++M4; - ++M5; - ++M6; - ++M7; - ++M8; - ++M9; - ++M10; - ++P; - } - - break; - } - default: - while (i != end) - { - for (u64 j = 0; j < mWeight; ++j) - { - auto row = v[j]; - pp[i] = pp[i] ^ m[row]; - ++v[j]; - } - ++i; - } - break; - } - - } - } - - // perform the circuit transpose of the encoding algorithm twice. - // the output it written to ppp0 and ppp1. - template - void details::SilverLeftEncoder::dualEncode2( - span ppp0, span ppp1, - span mm0, span mm1) - { - auto cols = mRows; - // pp = pp + m * A - auto v = mYs; - T0* __restrict pp0 = ppp0.data(); - T1* __restrict pp1 = ppp1.data(); - const T0* __restrict m0 = mm0.data(); - const T1* __restrict m1 = mm1.data(); - - for (u64 i = 0; i < cols; ) - { - auto end = cols; - for (u64 j = 0; j < mWeight; ++j) - { - if (v[j] == mRows) - v[j] = 0; - - auto jEnd = cols - v[j] + i; - end = std::min(end, jEnd); - } - switch (mWeight) - { - case 5: - while (i != end) - { - auto& r0 = v[0]; - auto& r1 = v[1]; - auto& r2 = v[2]; - auto& r3 = v[3]; - auto& r4 = v[4]; - - pp0[i] = pp0[i] - ^ m0[r0] - ^ m0[r1] - ^ m0[r2] - ^ m0[r3] - ^ m0[r4]; - - pp1[i] = pp1[i] - ^ m1[r0] - ^ m1[r1] - ^ m1[r2] - ^ m1[r3] - ^ m1[r4]; - - ++r0; - ++r1; - ++r2; - ++r3; - ++r4; - ++i; - } - break; - case 11: - while (i != end) - { - auto& r0 = v[0]; - auto& r1 = v[1]; - auto& r2 = v[2]; - auto& r3 = v[3]; - auto& r4 = v[4]; - auto& r5 = v[5]; - auto& r6 = v[6]; - auto& r7 = v[7]; - auto& r8 = v[8]; - auto& r9 = v[9]; - auto& r10 = v[10]; - - pp0[i] = pp0[i] - ^ m0[r0] - ^ m0[r1] - ^ m0[r2] - ^ m0[r3] - ^ m0[r4] - ^ m0[r5] - ^ m0[r6] - ^ m0[r7] - ^ m0[r8] - ^ m0[r9] - ^ m0[r10] - ; - - pp1[i] = pp1[i] - ^ m1[r0] - ^ m1[r1] - ^ m1[r2] - ^ m1[r3] - ^ m1[r4] - ^ m1[r5] - ^ m1[r6] - ^ m1[r7] - ^ m1[r8] - ^ m1[r9] - ^ m1[r10] - ; - - ++r0; - ++r1; - ++r2; - ++r3; - ++r4; - ++r5; - ++r6; - ++r7; - ++r8; - ++r9; - ++r10; - ++i; - } - - break; - default: - while (i != end) - { - for (u64 j = 0; j < mWeight; ++j) - { - auto row = v[j]; - pp0[i] = pp0[i] ^ m0[row]; - pp1[i] = pp1[i] ^ m1[row]; - ++v[j]; - } - ++i; - } - break; - } - } - } - - - template - void details::SilverRightEncoder::dualEncode(span x) - { - // solves for x such that y = M x, ie x := H^-1 y - assert(mExtend); - assert(cols() == x.size()); - - constexpr int FIXED_OFFSET_SIZE = 2; - if (mOffsets.size() != FIXED_OFFSET_SIZE) - throw RTE_LOC; - - std::vector offsets(mOffsets.size()); - for (u64 j = 0; j < offsets.size(); ++j) - { - offsets[j] = mRows - 1 - mOffsets[j] - mGap; - } - - u64 i = mRows - 1; - T* __restrict ofCol0 = &x[offsets[0]]; - T* __restrict ofCol1 = &x[offsets[1]]; - T* __restrict xi = &x[i]; - - switch (mCode) - { - case osuCrypto::SilverCode::Weight5: - { - auto mainEnd = - roundUpTo( - *std::max_element(mOffsets.begin(), mOffsets.end()) - + mGap, - 16); - - T* __restrict xx = xi - 16; - - for (; i > mainEnd;) - { - for (u64 jj = 0; jj < 16; ++jj) - { - - auto col0 = diagMtx_g16_w5_seed1_t36[i & 15][0]; - auto col1 = diagMtx_g16_w5_seed1_t36[i & 15][1]; - auto col2 = diagMtx_g16_w5_seed1_t36[i & 15][2]; - auto col3 = diagMtx_g16_w5_seed1_t36[i & 15][3]; - - T* __restrict xc0 = xx + col0; - T* __restrict xc1 = xx + col1; - T* __restrict xc2 = xx + col2; - T* __restrict xc3 = xx + col3; - - *xc0 = *xc0 ^ *xi; - *xc1 = *xc1 ^ *xi; - *xc2 = *xc2 ^ *xi; - *xc3 = *xc3 ^ *xi; - - *ofCol0 = *ofCol0 ^ *xi; - *ofCol1 = *ofCol1 ^ *xi; - - - --ofCol0; - --ofCol1; - - --xx; - --xi; - --i; - } - } - - break; - } - case osuCrypto::SilverCode::Weight11: - { - - - auto mainEnd = - roundUpTo( - *std::max_element(mOffsets.begin(), mOffsets.end()) - + mGap, - 32); - - T* __restrict xx = xi - 32; - - for (; i > mainEnd;) - { - for (u64 jj = 0; jj < 32; ++jj) - { - - auto col0 = diagMtx_g32_w11_seed2_t36[i & 31][0]; - auto col1 = diagMtx_g32_w11_seed2_t36[i & 31][1]; - auto col2 = diagMtx_g32_w11_seed2_t36[i & 31][2]; - auto col3 = diagMtx_g32_w11_seed2_t36[i & 31][3]; - auto col4 = diagMtx_g32_w11_seed2_t36[i & 31][4]; - auto col5 = diagMtx_g32_w11_seed2_t36[i & 31][5]; - auto col6 = diagMtx_g32_w11_seed2_t36[i & 31][6]; - auto col7 = diagMtx_g32_w11_seed2_t36[i & 31][7]; - auto col8 = diagMtx_g32_w11_seed2_t36[i & 31][8]; - auto col9 = diagMtx_g32_w11_seed2_t36[i & 31][9]; - - T* __restrict xc0 = xx + col0; - T* __restrict xc1 = xx + col1; - T* __restrict xc2 = xx + col2; - T* __restrict xc3 = xx + col3; - T* __restrict xc4 = xx + col4; - T* __restrict xc5 = xx + col5; - T* __restrict xc6 = xx + col6; - T* __restrict xc7 = xx + col7; - T* __restrict xc8 = xx + col8; - T* __restrict xc9 = xx + col9; - - *xc0 = *xc0 ^ *xi; - *xc1 = *xc1 ^ *xi; - *xc2 = *xc2 ^ *xi; - *xc3 = *xc3 ^ *xi; - *xc4 = *xc4 ^ *xi; - *xc5 = *xc5 ^ *xi; - *xc6 = *xc6 ^ *xi; - *xc7 = *xc7 ^ *xi; - *xc8 = *xc8 ^ *xi; - *xc9 = *xc9 ^ *xi; - - *ofCol0 = *ofCol0 ^ *xi; - *ofCol1 = *ofCol1 ^ *xi; - - - --ofCol0; - --ofCol1; - - --xx; - --xi; - --i; - } - } - - break; - } - default: - throw RTE_LOC; - break; - } - - offsets[0] = ofCol0 - x.data(); - offsets[1] = ofCol1 - x.data(); - - for (; i != ~u64(0); --i) - { - - switch (mCode) - { - case osuCrypto::SilverCode::Weight5: - - for (u64 j = 0; j < 4; ++j) - { - auto col = diagMtx_g16_w5_seed1_t36[i & 15][j] + i - 16; - if (col < mRows) - x[col] = x[col] ^ x[i]; - } - break; - case osuCrypto::SilverCode::Weight11: - - for (u64 j = 0; j < 10; ++j) - { - auto col = diagMtx_g32_w11_seed2_t36[i & 31][j] + i - 32; - if (col < mRows) - x[col] = x[col] ^ x[i]; - } - break; - default: - break; - } - - for (u64 j = 0; j < FIXED_OFFSET_SIZE; ++j) - { - auto& col = offsets[j]; - - if (col >= mRows) - break; - assert(i - mOffsets[j] - mGap == col); - - x[col] = x[col] ^ x[i]; - --col; - } - } - } - - template - void details::SilverRightEncoder::dualEncode2(span x0, span x1) - { - // solves for x such that y = M x, ie x := H^-1 y - assert(mExtend); - assert(cols() == x0.size()); - assert(cols() == x1.size()); - - constexpr int FIXED_OFFSET_SIZE = 2; - if (mOffsets.size() != FIXED_OFFSET_SIZE) - throw RTE_LOC; - - std::vector offsets(mOffsets.size()); - for (u64 j = 0; j < offsets.size(); ++j) - { - offsets[j] = mRows - 1 - mOffsets[j] - mGap; - } - - u64 i = mRows - 1; - T0* __restrict ofCol00 = &x0[offsets[0]]; - T0* __restrict ofCol10 = &x0[offsets[1]]; - T1* __restrict ofCol01 = &x1[offsets[0]]; - T1* __restrict ofCol11 = &x1[offsets[1]]; - T0* __restrict xi0 = &x0[i]; - T1* __restrict xi1 = &x1[i]; - - switch (mCode) - { - case osuCrypto::SilverCode::Weight5: - { - - auto mainEnd = - roundUpTo( - *std::max_element(mOffsets.begin(), mOffsets.end()) - + mGap, - 16); - - T0* __restrict xx0 = xi0 - 16; - T1* __restrict xx1 = xi1 - 16; - - for (; i > mainEnd;) - { - for (u64 jj = 0; jj < 16; ++jj) - { - - auto col0 = diagMtx_g16_w5_seed1_t36[i & 15][0]; - auto col1 = diagMtx_g16_w5_seed1_t36[i & 15][1]; - auto col2 = diagMtx_g16_w5_seed1_t36[i & 15][2]; - auto col3 = diagMtx_g16_w5_seed1_t36[i & 15][3]; - - T0* __restrict xc00 = xx0 + col0; - T0* __restrict xc10 = xx0 + col1; - T0* __restrict xc20 = xx0 + col2; - T0* __restrict xc30 = xx0 + col3; - T1* __restrict xc01 = xx1 + col0; - T1* __restrict xc11 = xx1 + col1; - T1* __restrict xc21 = xx1 + col2; - T1* __restrict xc31 = xx1 + col3; - - *xc00 = *xc00 ^ *xi0; - *xc10 = *xc10 ^ *xi0; - *xc20 = *xc20 ^ *xi0; - *xc30 = *xc30 ^ *xi0; - - *xc01 = *xc01 ^ *xi1; - *xc11 = *xc11 ^ *xi1; - *xc21 = *xc21 ^ *xi1; - *xc31 = *xc31 ^ *xi1; - - *ofCol00 = *ofCol00 ^ *xi0; - *ofCol10 = *ofCol10 ^ *xi0; - *ofCol01 = *ofCol01 ^ *xi1; - *ofCol11 = *ofCol11 ^ *xi1; - - - --ofCol00; - --ofCol10; - --ofCol01; - --ofCol11; - - --xx0; - --xx1; - --xi0; - --xi1; - --i; - } - } - - break; - } - case osuCrypto::SilverCode::Weight11: - { - - - auto mainEnd = - roundUpTo( - *std::max_element(mOffsets.begin(), mOffsets.end()) - + mGap, - 32); - - T0* __restrict xx0 = xi0 - 32; - T1* __restrict xx1 = xi1 - 32; - - for (; i > mainEnd;) - { - for (u64 jj = 0; jj < 32; ++jj) - { - - auto col0 = diagMtx_g32_w11_seed2_t36[i & 31][0]; - auto col1 = diagMtx_g32_w11_seed2_t36[i & 31][1]; - auto col2 = diagMtx_g32_w11_seed2_t36[i & 31][2]; - auto col3 = diagMtx_g32_w11_seed2_t36[i & 31][3]; - auto col4 = diagMtx_g32_w11_seed2_t36[i & 31][4]; - auto col5 = diagMtx_g32_w11_seed2_t36[i & 31][5]; - auto col6 = diagMtx_g32_w11_seed2_t36[i & 31][6]; - auto col7 = diagMtx_g32_w11_seed2_t36[i & 31][7]; - auto col8 = diagMtx_g32_w11_seed2_t36[i & 31][8]; - auto col9 = diagMtx_g32_w11_seed2_t36[i & 31][9]; - - T0* __restrict xc00 = xx0 + col0; - T0* __restrict xc10 = xx0 + col1; - T0* __restrict xc20 = xx0 + col2; - T0* __restrict xc30 = xx0 + col3; - T0* __restrict xc40 = xx0 + col4; - T0* __restrict xc50 = xx0 + col5; - T0* __restrict xc60 = xx0 + col6; - T0* __restrict xc70 = xx0 + col7; - T0* __restrict xc80 = xx0 + col8; - T0* __restrict xc90 = xx0 + col9; - - T1* __restrict xc01 = xx1 + col0; - T1* __restrict xc11 = xx1 + col1; - T1* __restrict xc21 = xx1 + col2; - T1* __restrict xc31 = xx1 + col3; - T1* __restrict xc41 = xx1 + col4; - T1* __restrict xc51 = xx1 + col5; - T1* __restrict xc61 = xx1 + col6; - T1* __restrict xc71 = xx1 + col7; - T1* __restrict xc81 = xx1 + col8; - T1* __restrict xc91 = xx1 + col9; - - *xc00 = *xc00 ^ *xi0; - *xc10 = *xc10 ^ *xi0; - *xc20 = *xc20 ^ *xi0; - *xc30 = *xc30 ^ *xi0; - *xc40 = *xc40 ^ *xi0; - *xc50 = *xc50 ^ *xi0; - *xc60 = *xc60 ^ *xi0; - *xc70 = *xc70 ^ *xi0; - *xc80 = *xc80 ^ *xi0; - *xc90 = *xc90 ^ *xi0; - - *xc01 = *xc01 ^ *xi1; - *xc11 = *xc11 ^ *xi1; - *xc21 = *xc21 ^ *xi1; - *xc31 = *xc31 ^ *xi1; - *xc41 = *xc41 ^ *xi1; - *xc51 = *xc51 ^ *xi1; - *xc61 = *xc61 ^ *xi1; - *xc71 = *xc71 ^ *xi1; - *xc81 = *xc81 ^ *xi1; - *xc91 = *xc91 ^ *xi1; - - *ofCol00 = *ofCol00 ^ *xi0; - *ofCol10 = *ofCol10 ^ *xi0; - - *ofCol01 = *ofCol01 ^ *xi1; - *ofCol11 = *ofCol11 ^ *xi1; - - - --ofCol00; - --ofCol10; - --ofCol01; - --ofCol11; - - --xx0; - --xx1; - - --xi0; - --xi1; - --i; - } - } - - break; - } - default: - throw RTE_LOC; - break; - } - - offsets[0] = ofCol00 - x0.data(); - offsets[1] = ofCol10 - x0.data(); - - for (; i != ~u64(0); --i) - { - - switch (mCode) - { - case osuCrypto::SilverCode::Weight5: - - for (u64 j = 0; j < 4; ++j) - { - auto col = diagMtx_g16_w5_seed1_t36[i & 15][j] + i - 16; - if (col < mRows) - { - x0[col] = x0[col] ^ x0[i]; - x1[col] = x1[col] ^ x1[i]; - } - } - break; - case osuCrypto::SilverCode::Weight11: - - for (u64 j = 0; j < 10; ++j) - { - auto col = diagMtx_g32_w11_seed2_t36[i & 31][j] + i - 32; - if (col < mRows) - { - x0[col] = x0[col] ^ x0[i]; - x1[col] = x1[col] ^ x1[i]; - } - } - break; - default: - break; - } - - for (u64 j = 0; j < FIXED_OFFSET_SIZE; ++j) - { - auto& col = offsets[j]; - - if (col >= mRows) - break; - assert(i - mOffsets[j] - mGap == col); - - x0[col] = x0[col] ^ x0[i]; - x1[col] = x1[col] ^ x1[i]; - --col; - } - } - } - - - - -} -#endif // ENABLE_INSECURE_SILVER - diff --git a/libOTe/Tools/LDPC/LdpcImpulseDist.cpp b/libOTe/Tools/LDPC/LdpcImpulseDist.cpp deleted file mode 100644 index 325b5595..00000000 --- a/libOTe/Tools/LDPC/LdpcImpulseDist.cpp +++ /dev/null @@ -1,1132 +0,0 @@ - - -#define _CRT_SECURE_NO_WARNINGS -#include "LdpcImpulseDist.h" - -#ifdef ENABLE_LDPC - -#include "LdpcDecoder.h" -#include "Util.h" -#include -#include -#include "LdpcSampler.h" -#include "cryptoTools/Common/Timer.h" -#include "cryptoTools/Crypto/PRNG.h" -#include "libOTe/Tools/LDPC/LdpcEncoder.h" -#include "libOTe/Tools/Tools.h" - -#ifdef ENABLE_ALGO994 -extern "C" { -#include "libOTe/Tools/LDPC/Algo994/data_defs.h" -} -#endif - -#include -#include - -#include // put_time - -namespace osuCrypto -{ - - - struct ListIter - { - std::vector set; - - ListDecoder mType; - u64 mTotalI = 0, mTotalEnd; - u64 mI = 0, mEnd = 1, mN = 0, mCurWeight = 0; - - u64 mD = 1; - void initChase(u64 d) - { - assert(d < 24); - mType = ListDecoder::Chase; - mD = d; - mI = 0; - ++(*this); - } - - void initOsd(u64 d, u64 n, bool startAtZero) - { - assert(d < 24); - mType = ListDecoder::OSD; - - mN = n; - mTotalI = 0; - mTotalEnd = 1ull << d; - - mCurWeight = startAtZero ? 0 : 1; - mI = 0; - mEnd = choose(n, mCurWeight); - - set = ithCombination(mI, n, mCurWeight); - } - - - void operator++() - { - assert(*this); - if (mType == ListDecoder::Chase) - { - set.clear(); - ++mI; - oc::BitIterator ii((u8*)&mI); - for (u64 i = 0; i < mD; ++i) - { - if (*ii) - set.push_back(i); - ++ii; - } - } - else - { - - ++mI; - ++mTotalI; - - if (mI == mEnd) - { - mI = 0; - ++mCurWeight; - mEnd = choose(mN, mCurWeight); - } - - if (mN >= mCurWeight) - { - set.resize(mCurWeight); - ithCombination(mI, mN, set); - } - else - set.clear(); - } - } - - std::vector& operator*() - { - return set; - } - - - operator bool() const - { - if (mType == ListDecoder::Chase) - return mI != (1ull << mD); - else - { - return mTotalI != mTotalEnd && mCurWeight <= mN; - } - } - }; - - - template - void sort_indexes(span v, span idx) { - - // initialize original index locations - assert(v.size() == idx.size()); - std::iota(idx.begin(), idx.end(), 0); - - //std::partial_sort() - std::stable_sort(idx.begin(), idx.end(), - [&v](size_t i1, size_t i2) {return v[i1] < v[i2]; }); - } - - std::mutex minWeightMtx; - u32 minWeight(0); - std::vector minCW; - - std::unordered_set heatSet; - std::vector heatMap; - std::vector heatMapCount; - u64 nextTimeoutIdx; - - struct Worker - { - std::vector llr; - std::vector y; - std::vector llrSetList; - std::vector codeword; - std::vector sortIdx, permute, weights; - std::vector> backProps; - LdpcDecoder D; - DynSparseMtx H; - DenseMtx DH; - std::vector dSet, eSet; - std::unordered_set eeSet; - bool verbose = false; - BPAlgo algo = BPAlgo::LogBP; - ListDecoder listDecoder = ListDecoder::OSD; - u64 Ng = 3; - oc::PRNG prng; - bool abs = false; - double timeout; - - void impulseDist(u64 i, u64 k, u64 Nd, u64 maxIter, bool randImpulse) - { - if (prng.mBufferByteCapacity == 0) - prng.SetSeed(oc::sysRandomSeed()); - - auto n = D.mH.cols(); - auto m = D.mH.rows(); - - llr.resize(n);// , lr.resize(n); - y.resize(m); - codeword.resize(n); - sortIdx.resize(n); - backProps.resize(m); - weights.resize(n); - - //auto p = 0.9999; - auto llr0 = LdpcDecoder::encodeLLR(0.501, 0); - auto llr1 = LdpcDecoder::encodeLLR(0.999, 1); - //auto lr0 = encodeLR(0.501, 0); - //auto lr1 = encodeLR(0.9999999, 1); - - std::fill(llr.begin(), llr.end(), llr0); - std::vector impulse; - if (randImpulse) - { - std::set set; - //u64 w = prng.get(); - //w = (w % (k)) + 1; - //assert(k + 1 < n); - while (set.size() != k) - { - auto i = prng.get() % n; - set.insert(i); - } - impulse.insert(impulse.end(), set.begin(), set.end()); - } - else - { - impulse = ithCombination(i, n, k); - - } - - for (auto i : impulse) - llr[i] = llr1; - - switch (algo) - { - case BPAlgo::LogBP: - D.logbpDecode2(llr, maxIter); - break; - case BPAlgo::AltLogBP: - D.altDecode(llr, false, maxIter); - break; - case BPAlgo::MinSum: - D.altDecode(llr, true, maxIter); - break; - default: - std::cout << "bad algo " << (int)algo << std::endl; - std::abort(); - break; - } - //bpDecode(lr, maxIter); - //for (auto& l : mL) - for (u64 i = 0; i < n; ++i) - { - if (abs) - llr[i] = std::abs(D.mL[i]); - else - llr[i] = (D.mL[i]); - } - - sort_indexes(llr, sortIdx); - - - u64 ii = 0; - dSet.clear(); - eSet.clear(); - - bool sparse = false; - if (sparse) - H = D.mH; - else - DH = D.mH.dense(); - - VecSortSet col; - - while (ii < n && eSet.size() < m) - { - auto c = sortIdx[ii++]; - - if (sparse) - col = H.col(c); - else - DH.colIndexSet(c, col.mData); - - bool set = false; - - for (auto r : col) - { - if (r >= eSet.size()) - { - if (set) - { - if (sparse) - H.rowAdd(r, eSet.size()); - else - DH.row(r) ^= DH.row(eSet.size()); - //assert(H(r, c) == 0); - } - else - { - set = true; - if (sparse) - H.rowSwap(eSet.size(), r); - else - DH.row(eSet.size()).swap(DH.row(r)); - } - } - } - - if (set == false) - { - dSet.push_back(c); - } - else - { - eSet.push_back(c); - } - } - - if (!sparse) - H = DH; - - - //auto HH1 = H.sparse().dense().gausianElimination(); - //auto HH2 = mH.dense().gausianElimination(); - //assert(HH1 == HH2); - - - if (eSet.size() != m) - { - std::cout << "bad eSet size " << LOCATION << std::endl; - abort(); - } - - //auto gap = dSet.size(); - - while (dSet.size() < Nd) - { - auto col = sortIdx[ii++]; - dSet.push_back(col); - } - - permute = eSet; - permute.insert(permute.end(), dSet.begin(), dSet.end()); - permute.insert(permute.end(), sortIdx.begin() + permute.size(), sortIdx.end()); - //if (v) - //{ - - // auto H2 = H.selectColumns(permute); - - // std::cout << " " << gap << "\n" << H2 << std::endl; - // //for (auto l : mL) - - // for (u64 i = 0; i < n; ++i) - // { - // std::cout << decodeLLR(D.mL[permute[i]]) << " "; - // } - // std::cout << std::endl; - //} - - std::fill(y.begin(), y.end(), 0); - - llrSetList.clear(); - - for (u64 i = m; i < n; ++i) - { - auto col = permute[i]; - codeword[col] = LdpcDecoder::decodeLLR(D.mL[col]); - - - if (codeword[col]) - { - llrSetList.push_back(col); - for (auto row : H.col(col)) - { - y[row] ^= 1; - } - } - } - - eeSet.clear(); - eeSet.insert(eSet.begin(), eSet.end()); - for (u64 i = 0; i < m - 1; ++i) - { - backProps[i].clear(); - for (auto c : H.row(i)) - { - if (c != permute[i] && - eeSet.find(c) != eeSet.end()) - { - backProps[i].push_back(c); - } - } - } - - - - ListIter cIter; - - if (listDecoder == ListDecoder::Chase) - cIter.initChase(Nd); - else - cIter.initOsd(Nd, std::min(Ng, n - m), llrSetList.size() > 0); - - //u32 s = 0; - //u32 e = 1 << Nd; - while (cIter) - { - - //for (u64 i = m + Nd; i < n; ++i) - //{ - // auto col = permute[i]; - // codeword[col] = y[i]; - //} - - ////yp = y; - std::fill(codeword.begin(), codeword.end(), 0); - - - // check if BP resulted in any - // of the top bits being set to 1. - // if so, preload the partially - // solved solution for this. - if (llrSetList.size()) - { - for (u64 i = 0; i < m; ++i) - codeword[permute[i]] = y[i]; - for (auto i : llrSetList) - codeword[i] = 1; - } - - // next, iterate over setting some - // of the remaining bits. - auto& oneSet = *cIter; - //for (u64 i = m; i < m + Nd; ++i, ++iter) - for (auto i : oneSet) - { - - auto col = permute[i + m]; - - // check if this was not already set to - // 1 by the BP. - if (codeword[col] == 0) - { - codeword[col] = 1; - for (auto row : H.col(col)) - { - codeword[permute[row]] ^= 1; - } - } - } - - ++cIter; - - // now perform back prop on the remaining - // postions. - for (u64 i = m - 1; i != ~0ull; --i) - { - for (auto c : backProps[i]) - { - if (codeword[c]) - codeword[permute[i]] ^= 1; - } - } - - - // check if its a codework (should always be one) - if (D.check(codeword) == false) { - std::cout << "bad codeword " << LOCATION << std::endl; - abort(); - } - - // record the weight. - auto w = std::accumulate(codeword.begin(), codeword.end(), 0ull); - ++weights[w]; - - if (w && w < minWeight + 10) - { - - oc::RandomOracle ro(sizeof(u64)); - ro.Update(codeword.data(), codeword.size()); - u64 h; - ro.Final(h); - - std::lock_guard lock(minWeightMtx); - - if (w < minWeight) - { - if (verbose) - std::cout << " w=" << w << std::flush; - - minWeight = (u32) w; - minCW.clear(); - minCW.insert(minCW.end(), codeword.begin(), codeword.end()); - - if (timeout > 0) - { - u64 nn = static_cast((i + 1) * timeout); - nextTimeoutIdx = std::max(nextTimeoutIdx, nn); - } - } - - - - if (heatSet.find(h) == heatSet.end()) - { - heatSet.insert(h); - //heatMap[w].resize(n); - for (u64 i = 0; i < n; ++i) - { - if (codeword[i]) - ++heatMap[i]; - ++heatMapCount[w]; - } - } - - } - } - } - - }; - - - - - u64 impulseDist( - SparseMtx& mH, - u64 Nd, u64 Ng, - u64 w, - u64 maxIter, u64 nt, bool randImpulse, u64 trials, BPAlgo algo, - ListDecoder listDecoder, bool verbose, double timeout) - { - assert(Nd < 32); - auto n = mH.cols(); - //auto m = mH.rows(); - - LdpcDecoder D; - D.init(mH); - minWeight = 999999999; - minCW.clear(); - - - nt = nt ? nt : 1; - std::vector wrks(nt); - - for (auto& ww : wrks) - { - ww.algo = algo; - ww.D.init(mH); - ww.verbose = verbose; - - ww.Ng = Ng; - ww.listDecoder = listDecoder; - ww.timeout = timeout; - } - nextTimeoutIdx = 0; - if (randImpulse) - { - std::vector thrds(nt); - - for (u64 t = 0; t < nt; ++t) - { - thrds[t] = std::thread([&, t]() { - - for (u64 i = t; i < trials; i += nt) - { - wrks[t].impulseDist(i, w, Nd, maxIter, randImpulse); - } - }); - } - - for (u64 t = 0; t < nt; ++t) - thrds[t].join(); - } - else - { - - bool timedOut = false; - std::vector thrds(nt); - - for (u64 t = 0; t < nt; ++t) - { - thrds[t] = std::thread([&, t]() { - - u64 ii = t; - for (u64 k = 0; k < w + 1; ++k) - { - - auto f = choose(n, k); - for (; ii < f && timedOut == false; ii += nt) - { - - wrks[t].impulseDist(ii, k, Nd, maxIter, randImpulse); - - - if (k == w && (ii % 100) == 0) - { - std::lock_guard lock(minWeightMtx); - - if (nextTimeoutIdx && nextTimeoutIdx < ii) - timedOut = true; - } - } - - ii -= f; - } - } - ); - } - - for (u64 t = 0; t < nt; ++t) - thrds[t].join(); - - } - - std::vector weights(n); - for (u64 t = 0; t < nt; ++t) - { - for (u64 i = 0; i < wrks[t].weights.size(); ++i) - weights[i] += wrks[t].weights[i]; - } - - auto i = 1; - while (weights[i] == 0) ++i; - auto ret = i; - - if (verbose) - { - std::cout << std::endl; - auto j = 0; - while (j++ < 10) - { - std::cout << i << " " << weights[i] << std::endl; - ++i; - } - } - - return ret; - } - - std::string return_current_time_and_date() - { - auto now = std::chrono::system_clock::now(); - auto in_time_t = std::chrono::system_clock::to_time_t(now); - - std::stringstream ss; - ss << std::put_time(std::localtime(&in_time_t), "%c %X"); - return ss.str(); - } - - void LdpcDecode_impulse(const oc::CLP& cmd) - { - // general parameter - // the number of rows of H, "m" - auto rowVec = cmd.getManyOr("r", { 50 }); - - // the expansion ratio, e="n/m" - double e = cmd.getOr("e", 2.0); - - // the number of trials. - u64 trial = cmd.getOr("trials", 1); - - // which trial to start at. - u64 tStart = cmd.getOr("tStart", 0); - - // a specific set of trials to run. - auto tSet = cmd.getManyOr("tSet", {}); - - // the need used to generate the samples - u64 seed = cmd.getOr("seed", 0); - - // the seed(s) to generate the left half. One for each trial. - auto lSeed = cmd.getManyOr("lSeed", {}); - - // verbose flag. - bool verbose = cmd.isSet("v"); - - // the number of threads - u64 nt = cmd.getOr("nt", cmd.isSet("nt") ? std::thread::hardware_concurrency() : 1); - - // A flag to just test the uniform matrix. - bool uniform = cmd.isSet("u"); - - - // silver parameters - // ================================ - - // use the silver preset for the given column weight. - bool silver = cmd.isSet("slv"); - - // The column weight for the left half - u64 colWeight = cmd.getOr("cw", 5); - - // the column weight for the right half. Counts the - // main diagonal. - u64 dWeight = cmd.getOr("dw", 2); - - // The size of the gap. - u64 gap = cmd.getOr("g", 1); - - // the number of diagonals on the left half. - u64 diag = cmd.getOr("diag", 0); - - // the number of diagonals on the right half. - u64 dDiag = cmd.getOr("dDiag", 0); - - // the extra diagonal bands below the main diagonal (right half) - // The values denote how far below the diagonal they should be. - // requires -trim - auto doubleBand = cmd.getMany("db"); - - // delete the first g columns of the right half - bool trim = cmd.isSet("trim"); - - // extend the right half to be square. - bool extend = cmd.isSet("extend"); - - // how often the right half should repeat. - u64 period = cmd.getOr("period", 0); - - // a flag to randomly sample the position of the left diagonals. - // requires -diag > 0. - bool randY = cmd.isSet("randY"); - - // the slopes of the left diagonals, default 1. - // requires -diag > 0. - slopes_ = cmd.getManyOr("slope", {}); - - // the fixed indexes of the left diagononals - // requires -diag > 0. - ys_ = cmd.getManyOr("ys", {}); - - // the fractional positions of the left diagononals. - // requires -diag > 0. - yr_ = cmd.getManyOr("yr", {}); - - // print the factional positions of the left diagonals. - bool printYs = cmd.isSet("py"); - - // sample the right half to be regular (the same number of - // ones in the rows as the columns). - bool reg = cmd.isSet("reg"); - - // print the heat map for where we found minimum codewords - bool hm = cmd.isSet("hm"); - - // when sample, dont check that H corresponds to a value - // LDPC code - bool noCheck = cmd.isSet("noCheck"); - - // the path to the log file. - std::string logPath = cmd.getOr("log", ""); - - // the amount of "time" that can pass in the impluse technique - // without finding a new minimum. the value denote the faction - // of the total number of impluses which should be performed. - double timeout = cmd.getOr("to", 0.0); - - - // estimator parameters - // ============================== - - // the weight of the noise impulse. - u64 w = cmd.getOr("w", 1); - - // once partial gaussian elemination is performed, - // we will consider all codewords which have the - // next Nd out-of Ng bits set to one. e.g. if we have - // Nd=2, Ng=4, the the codewords - // - // xx...xx 1100 0000... - // xx...xx 1010 0000... - // ... - // xx...xx 0011 0000... - // - // will be tried where the x bits are solved for. - - u64 Nd = cmd.getOr("Nd", 10); - u64 Ng = cmd.getOr("Ng", 50); - - // the number of BD iterations that should be performed. - u64 iter = cmd.getOr("iter", 10); - - // should random noise impulses of the given weight be tried. - bool rand = cmd.isSet("rand"); - - // how many random noise impulse should be tried. - u64 n = cmd.getOr("rand", 100); - - // the type of list decoder. - ListDecoder listDecoder = (ListDecoder)cmd.getOr("ld", 1); - - // print the regular diagonal so it can be used to generate a - // actual code, e.g. diagMtx_g32_w11_seed2_t36. - printDiag = cmd.isSet("printDiag"); - - auto algo = (BPAlgo)cmd.getOr("bp", 2); - - // algo994 parameters - auto trueDist = cmd.isSet("true"); -#ifdef ENABLE_ALGO994 - alg994 = cmd.getOr("algo994", ALG_SAVED_UNROLLED); - num_saved_generators = cmd.getOr("numGen", 5); - num_cores = (int)nt; - num_permutations = cmd.getOr("numPerm", 10); - print_matrices = 0; -#endif - - SparseMtx H; - LdpcDecoder D; - std::stringstream label; - - - label << return_current_time_and_date() << "\n"; - - if (e != 2) - label << "-ldpc -e " << e << " "; - - if (silver) - label << " -slv "; - - label << "-r "; - for (auto rows : rowVec) - label << rows << " "; - - if (trueDist) - label << "-true "; - else - { - label << " -ld " << (int)listDecoder << " -bp " << int(algo) << "-Nd " << Nd << " -Ng " << Ng << " -w " << w; - if (timeout) - label << " -to " << timeout; - } - label << " -nt " << nt; - - if (tSet.size()) - { - label << " -tSet "; - for (auto t : tSet) - label << t << " "; - - } - else - { - label << " -trials " << trial; - if (tStart) - label << " -tStart " << tStart; - } - - label << " -seed " << seed; - - if (uniform) - { - label << " -u "; - if (cmd.isSet("cw")) - label << " -cw " << colWeight; - } - else - { - if (cmd.isSet("lb")) - label << "-ld"; - - label << " -cw " << colWeight << " -g " << gap << " -dw " << dWeight - << " -diag " << diag << " -dDiag " << dDiag; - - if (doubleBand.size()) - { - label << " -db "; - for (auto db : doubleBand) - label << db << " "; - } - - if (trim) - label << " -trim "; - - if (extend) - label << " -extend "; - - label << " -period " << period; - //if(zp) - // label << " -zp "; - - if (reg) - label << " -reg"; - - } - - std::ofstream log; - if (logPath.size()) - { - log.open(logPath, std::ios::out | std::ios::app); - } - - if (log.is_open()) - log << "\n" << label.str() << std::endl; - - if (tSet.size() == 0) - { - for (u64 i = tStart; i < trial; ++i) - tSet.push_back(i); - } - - for (auto rows : rowVec) - { - - //if (zp) - //{ - // if (isPrime(rows + 1) == false) - // rows = nextPrime(rows + 1) - 1; - - //} - u64 cols = static_cast(rows * e); - - if (uniform && trim) - { - cols -= gap; - } - - std::vector dd; - - heatMap.clear(); - heatMap.resize(cols); - heatMapCount.clear(); - heatMapCount.resize(cols); - heatSet.clear(); - - if (log.is_open()) - log << rows << ": "; - - auto lIter = lSeed.begin(); - for (auto i : tSet) - { - oc::PRNG lPrng, rPrng(block(seed, i)); - - if (lIter != lSeed.end()) - { - lPrng.SetSeed(block(3, *lIter++)); - - if (lIter == lSeed.end()) - lIter = lSeed.begin(); - } - else - lPrng.SetSeed(block(seed, i)); - - - if (uniform) - { - if (cmd.isSet("cw")) - { - - - H = sampleFixedColWeight(rows, cols, colWeight, rPrng, !noCheck); - } - else - H = sampleUniformSystematic(rows, cols, rPrng); - } - else if (silver) - { - - SilverEncoder enc; - SilverCode code; - - if (colWeight == 5) - { - code = SilverCode::Weight5; - } - else if (colWeight == 11) - code = SilverCode::Weight11; - else - { - std::cout << "-slv can only be used with -cw 5 or -cw 11" << std::endl; - throw RTE_LOC; - } - - enc.mR.init(rows, code, extend); - enc.mL.init(rows, code, cmd.isSet("rep") ? SilverCode::gap(code) : 0); - H = enc.getMatrix(); - } - else if (reg) - { - H = sampleRegTriangularBand( - rows, cols, - colWeight, gap, - dWeight, diag, dDiag, period, - doubleBand, trim, extend, randY, rPrng); - //std::cout << H << std::endl; - } - else - { - H = sampleTriangularBand( - rows, cols, - colWeight, gap, - dWeight, diag, dDiag, period, - doubleBand, trim, extend, randY, lPrng,rPrng); - } - - //impulseDist(5, 5000); - //oc::Timer timer; - //timer.setTimePoint(""); - - //timer.setTimePoint("e"); - - if(verbose) - std::cout << "\n" << H << std::endl; - - if (trueDist) - { - auto d = minDist2(H.dense(), nt, false); - dd.push_back(d); - } - else - { - auto d = impulseDist(H, Nd, Ng, w, iter, nt, rand, n, algo, listDecoder, verbose, timeout); - dd.push_back(d); - } - - if (log.is_open()) - log << " " << dd.back() << std::flush; - - if (verbose) - { - std::cout << dd.back(); - - for (auto c : minCW) - { - if (c) - std::cout << oc::Color::Green << int(c) << " " << oc::Color::Default; - else - std::cout << int(c) << " "; - } - std::cout << std::endl; - - if (hm || verbose) - { - u64 max = 0ull; - for (u64 i = 0; i < heatMap.size(); ++i) - { - max = std::max(max, heatMap[i]); - } - - double tick = max / 10.0; - - - for (u64 j = 1; j <= 10; ++j) - { - for (u64 i = 0; i < heatMap.size(); ++i) - { - if (heatMap[i] >= j * tick) - std::cout << "* "; - else - std::cout << " "; - } - std::cout << "|\n"; - } - std::cout << std::flush; - - heatMap.clear(); - heatMap.resize(cols); - heatMapCount.clear(); - heatMapCount.resize(cols); - heatSet.clear(); - } - - } - else if (!cmd.isSet("silent")) - { - std::cout << dd.back() << " " << std::flush; - - if (printYs) - { - std::cout << "~ "; - for (auto y : lastYs_ ) - std::cout << y << " "; - - std::cout << "~ "; - for (auto y : yr_) - std::cout << y << " "; - - std::cout << std::endl; - } - } - //std::cout << timer << std::endl;; - - } - - if (log.is_open()) - log << std::endl; - - auto tt = tSet.size(); - auto min = *std::min_element(dd.begin(), dd.end()); - auto max = *std::max_element(dd.begin(), dd.end()); - auto avg = std::accumulate(dd.begin(), dd.end(), 0ull) / double(tt); - //avg = avg / tt; - - //std::cout << "\r"; - //auto str = ss.str(); - //for (u64 i = 0; i < str.size(); ++i) - // std::cout << " "; - //std::cout << ; - - { - std::cout << oc::Color::Green << "\r" << rows << ": "; - std::cout << min << " " << avg << " " << max << " ~ " << oc::Color::Default; - for (auto d : dd) - std::cout << d << " "; - - std::cout << std::endl; - } - - - - if (hm && !verbose) - { - u64 max = 0ull; - for (u64 i = 0; i < heatMap.size(); ++i) - { - max = std::max(max, heatMap[i]); - } - - double tick = max / 10.0; - - - for (u64 j = 1; j <= 10; ++j) - { - for (u64 i = 0; i < heatMap.size(); ++i) - { - if (heatMap[i] >= j * tick) - std::cout << "* "; - else - std::cout << " "; - } - std::cout << "|\n"; - } - std::cout << std::flush; - - heatMap.clear(); - heatMap.resize(cols); - heatMapCount.clear(); - heatMapCount.resize(cols); - heatSet.clear(); - } - - } - - - - - return; - - } - - - -} - -#endif diff --git a/libOTe/Tools/LDPC/LdpcSampler.cpp b/libOTe/Tools/LDPC/LdpcSampler.cpp deleted file mode 100644 index 0693da62..00000000 --- a/libOTe/Tools/LDPC/LdpcSampler.cpp +++ /dev/null @@ -1,256 +0,0 @@ -#include "libOTe/Tools/LDPC/LdpcSampler.h" -#ifdef ENABLE_LDPC - -#include "libOTe/Tools/LDPC/LdpcEncoder.h" -#include -#include "libOTe/Tools/LDPC/Util.h" -#include - -#ifdef ENABLE_ALGO994 -extern "C" { -#include "libOTe/Tools/LDPC/Algo994/data_defs.h" -} -#endif - -namespace osuCrypto -{ - - std::vector slopes_, ys_, lastYs_; - std::vector yr_; - bool printDiag = false; - - - void sampleRegDiag(u64 rows, u64 gap, u64 weight, oc::PRNG& prng, PointList& points) - { - - if (rows < gap * 2) - throw RTE_LOC; - - auto cols = rows - gap; - std::vector rowWeights(cols); - std::vector> rowSets(rows - gap); - - - for (u64 c = 0; c < cols; ++c) - { - std::set s; - //auto remCols = cols - c; - - for (u64 j = 0; j < gap; ++j) - { - auto rowIdx = c + j; - - // how many remaining slots are there to the left (speial case at the start) - u64 remA = std::max(gap - rowIdx, 0); - - // how many remaining slots there are to the right. - u64 remB = std::min(j, cols - c); - - u64 rem = remA + remB; - - - auto& w = rowWeights[rowIdx % cols]; - auto needed = (weight - w); - - if (needed > rem) - throw RTE_LOC; - - if (needed && needed == rem) - { - s.insert(rowIdx); - points.push_back({ rowIdx, c }); - ++w; - } - } - - if (s.size() > weight) - throw RTE_LOC; - - while (s.size() != weight) - { - auto j = (prng.get() % gap); - auto r = c + j; - - auto& w = rowWeights[r % cols]; - - if (w != weight && s.insert(r).second) - { - ++w; - points.push_back({ r, c }); - } - } - - for (auto ss : s) - { - rowSets[ss % cols].insert(c); - if (rowSets[ss % cols].size() > weight) - throw RTE_LOC; - } - - if (c > gap && rowSets[c].size() != weight) - { - SparseMtx H(c + gap + 1, c + 1, points); - std::cout << H << std::endl << std::endl; - throw RTE_LOC; - } - - - } - - if (printDiag) - { - - //std::vector> hh; - auto pp = points; - - for (u64 i = 0; i < rows - 1; ++i) - pp.push_back({ i, i + 1 }); - - SparseMtx H(rows, rows, points); - //std::cout << (SparseMtx(pp)) << std::endl << std::endl; - - std::cout << "{{\n"; - - - for (u64 i = 0; i < cols; ++i) - { - std::cout << "{{ "; - bool first = true; - //hh.emplace_back(); - - for (u64 j = 0; j < (u64)H.col(i).size(); ++j) - { - auto c = H.col(i)[j]; - c = (c - i) % cols; - - if (!first) - std::cout << ", "; - std::cout << c; - - //hh[i].push_back(H.row(i)[j]); - first = false; - } - - //if (i + cols < H.rows() && 0) - //{ - // for (u64 j = 0; j < (u64)H.row(i + cols).size(); ++j) - // { - - // auto c = H.row(i + cols)[j]; - // c = (c + cols - 1 - i) % cols; - - // if (!first) - // std::cout << ", "; - // std::cout << c; - // //hh[i].push_back(H.row(i+cols)[j]); - // first = false; - // } - //} - - - - std::cout << "}},\n"; - } - std::cout << "}}" << std::endl; - - - //{ - // u64 rowIdx = 0; - // for (auto row : hh) - // { - // std::set s; - // std::cout << "("; - // for (auto c : row) - // { - // std::cout << int(c) << " "; - // s.insert(); - // } - // std::cout << std::endl << "{"; - - // for (auto c : s) - // std::cout << c << " "; - // std::cout << std::endl; - // ++rowIdx; - // } - //} - } - } - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - void sampleRegTriangularBand(u64 rows, u64 cols, - u64 weight, u64 gap, u64 dWeight, - u64 diag, u64 dDiag, u64 period, - std::vector doubleBand, - bool trim, bool extend, bool randY, - oc::PRNG& prng, PointList& points) - { - //auto dHeight =; - - assert(extend == false || trim == true); - assert(gap < rows); - assert(dWeight > 0); - assert(dWeight <= gap + 1); - - if (trim == false) - throw RTE_LOC; - - - if (period == 0 || period > rows) - period = rows; - - if (extend) - { - for (u64 i = 0; i < gap; ++i) - { - points.push_back({ rows - gap + i, cols - gap + i }); - } - } - - //auto b = trim ? cols - rows + gap : cols - rows; - auto b = cols - rows; - auto diagOffset = sampleFixedColWeight(rows, b, weight, diag, randY, prng, points); - u64 e = rows - gap; - auto e2 = cols - gap; - - - PointList diagPoints(period + gap, period); - sampleRegDiag(period + gap, gap, dWeight - 1, prng, diagPoints); - //std::cout << "cols " << cols << std::endl; - - std::set> ss; - for (u64 i = 0; i < e; ++i) - { - - points.push_back({ i, b + i }); - if (b + i >= cols) - throw RTE_LOC; - - //if (ss.insert({ i, b + i })); - for (auto db : doubleBand) - { - assert(db >= 1); - u64 j = db + gap + i; - - if (j < rows) - points.push_back({ j, b + i }); - } - } - - auto blks = (e + period - 1) / (period); - for (u64 i = 0; i < blks; ++i) - { - auto ii = i * period; - for (auto p : diagPoints) - { - auto r = ii + p.mRow + 1; - auto c = ii + p.mCol + b; - if (r < rows && c < e2) - points.push_back({ r, c }); - } - } - - } -} -#endif diff --git a/libOTe/Tools/LDPC/LdpcSampler.h b/libOTe/Tools/LDPC/LdpcSampler.h deleted file mode 100644 index 4bcb8848..00000000 --- a/libOTe/Tools/LDPC/LdpcSampler.h +++ /dev/null @@ -1,751 +0,0 @@ -#pragma once -// © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). -#include "libOTe/config.h" -#ifdef ENABLE_LDPC -#include "Mtx.h" -#include "cryptoTools/Crypto/PRNG.h" -#include -#include -#include "cryptoTools/Common/CLP.h" -#include "cryptoTools/Common/BitVector.h" -#include "Util.h" -#include "libOTe/Tools/Tools.h" -#include - -namespace osuCrypto -{ - - //inline void push(PointList& p, Point x) - //{ - // //for (u64 i = 0; i < p.size(); ++i) - // //{ - // // if (p[i].mCol == x.mCol && p[i].mRow == x.mRow) - // // { - // // assert(0); - // // } - // //} - - // //std::cout << "{" << x.mRow << ", " << x.mCol << " } " << std::endl; - // p.push_back(x); - //} - - - extern std::vector slopes_, ys_, lastYs_; - extern std::vector yr_; - extern bool printDiag; - // samples a uniform partiy check matrix with - // each column having weight w. - inline std::vector sampleFixedColWeight( - u64 rows, u64 cols, - u64 w, u64 diag, bool randY, - oc::PRNG& prng, PointList& points) - { - std::vector& diagOffsets = lastYs_; - diagOffsets.clear(); - - diag = std::min(diag, w); - - if (slopes_.size() == 0) - slopes_ = { {1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1} }; - - if (slopes_.size() < diag) - throw RTE_LOC; - - if (diag) - { - if (randY) - { - yr_.clear(); - std::uniform_real_distribution<> dist(0, 1); - yr_.push_back(0); - //s.insert(0); - - while (yr_.size() != diag) - { - auto v = dist(prng); - yr_.push_back(v); - } - - std::sort(yr_.begin(), yr_.end()); - //diagOffsets.insert(diagOffsets.end(), s.begin(), s.end()); - } - - if (yr_.size()) - { - if (yr_.size() < diag) - { - std::cout << "yr.size() < diag" << std::endl; - throw RTE_LOC; - } - std::set ss; - //ss.insert(0); - //diagOffsets.resize(diag); - - for (u64 i = 0; ss.size() < diag; ++i) - { - auto p = u64(rows * yr_[i]) % rows; - while (ss.insert(p).second == false) - p = u64(p + 1) % rows; - } - - //for (u64 i = 0; i < diag; ++i) - //{ - //} - diagOffsets.clear(); - for (auto s : ss) - diagOffsets.push_back(s); - } - else if (ys_.size()) - { - - if (ys_.size() < diag) - throw RTE_LOC; - - diagOffsets = ys_; - diagOffsets.resize(diag); - - } - else - { - diagOffsets.resize(diag); - - for (u64 i = 0; i < diag; ++i) - diagOffsets[i] = rows / 2; - } - - } - std::set set; - for (u64 i = 0; i < cols; ++i) - { - set.clear(); - if (diag && i < rows) - { - //assert(diag <= minWeight); - for (u64 j = 0; j < diag; ++j) - { - i64& r = diagOffsets[j]; - - - - r = (slopes_[j] + r); - - if (r >= i64(rows)) - r -= rows; - - if (r < 0) - r += rows; - - if (r >= i64(rows) || r < 0) - { - //std::cout << i << " " << r << " " << rows << std::endl; - throw RTE_LOC; - } - - - set.insert(r); - //auto& pat = patterns[i % patterns.size()]; - //for (u64 k = 0; k < pat.size(); ++k) - //{ - - // auto nn = set.insert((r + pat[k]) % rows); - // if (!nn.second) - // set.erase(nn.first); - //} - - //auto r2 = (r + 1 + (i & 15)) % rows; - //nn = set.insert(r2).second; - //if (nn) - // points.push_back({ r2, i }); - } - } - - for (auto ss : set) - { - points.push_back({ ss, i }); - } - - while (set.size() < w) - { - auto j = prng.get() % rows; - if (set.insert(j).second) - points.push_back({ j, i }); - } - } - - return diagOffsets; - } - - DenseMtx computeGen(DenseMtx& H); - - // samples a uniform partiy check matrix with - // each column having weight w. - inline SparseMtx sampleFixedColWeight(u64 rows, u64 cols, u64 w, oc::PRNG& prng, bool checked) - { - PointList points(rows, cols); - sampleFixedColWeight(rows, cols, w, false, false, prng, points); - - if (checked) - { - u64 i = 1; - SparseMtx H(rows, cols, points); - - auto D = H.dense(); - auto g = computeGen(D); - - while(g.rows() == 0) - { - ++i; - points.mPoints.clear(); - sampleFixedColWeight(rows, cols, w, false, false, prng, points); - - H = SparseMtx(rows, cols, points); - D = H.dense(); - g = computeGen(D); - } - - //std::cout << "("< - inline void shuffle(Iter begin, Iter end, oc::PRNG& prng) - { - u64 n = u64(end - begin); - - for (u64 i = 0; i < n; ++i) - { - auto j = prng.get() % (n - i); - - std::swap(*begin, *(begin + j)); - - ++begin; - } - } - - // samples a uniform set of size weight in the - // inteveral [begin, end). If diag, then begin - // will always be in the set. - inline std::set sampleCol(u64 begin, u64 end, u64 mod, u64 weight, bool diag, oc::PRNG& prng) - { - //std::cout << "sample " << prng.get() << std::endl; - - std::set idxs; - - auto n = end - begin; - if (n < weight) - { - std::cout << "n < weight " << LOCATION << std::endl; - abort(); - } - - if (diag) - { - idxs.insert(begin % mod); - ++begin; - --n; - --weight; - } - - if (n < 3 * weight) - { - auto nn = std::min(3 * weight, n); - std::vector set(nn); - std::iota(set.begin(), set.end(), begin); - - shuffle(set.begin(), set.end(), prng); - - //for (u64 i = 0; i < weight; ++i) - //{ - // std::cout << set[i] << std::endl; - //} - //for(auto s : set) - //auto iter = set.beg - for (u64 i = 0; i < weight; ++i) - idxs.insert((set[i]) % mod); - } - else - { - while (idxs.size() < weight) - { - auto x = prng.get() % n; - //std::cout << x << std::endl; - idxs.insert((x + begin) % mod); - } - } - - return idxs; - } - - - inline std::set sampleCol(u64 begin, u64 end, u64 weight, bool diag, oc::PRNG& prng) - { - std::set idxs; - - auto n = end - begin; - if (n < weight) - { - std::cout << "n < weight " << LOCATION << std::endl; - abort(); - } - - if (diag) - { - idxs.insert(begin); - ++begin; - --n; - --weight; - } - - if (n < 3 * weight) - { - auto nn = std::min(3 * weight, n); - std::vector set(nn); - std::iota(set.begin(), set.end(), begin); - - shuffle(set.begin(), set.end(), prng); - - //for (u64 i = 0; i < weight; ++i) - //{ - // std::cout << set[i] << std::endl; - //} - - idxs.insert(set.begin(), set.begin() + weight); - } - else - { - while (idxs.size() < weight) - { - auto x = prng.get() % n; - //std::cout << x << std::endl; - idxs.insert(x + begin); - } - } - return idxs; - } - - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline void sampleTriangular(u64 rows, u64 cols, u64 weight, u64 gap, oc::PRNG& prng, PointList& points) - { - auto b = cols - rows + gap; - sampleFixedColWeight(rows, b, weight, 0, false, prng, points); - - for (u64 i = 0; i < rows - gap; ++i) - { - auto w = std::min(weight - 1, (rows - i) / 2); - auto s = sampleCol(i + 1, rows, w, false, prng); - - points.push_back({ i, b + i }); - for (auto ss : s) - points.push_back({ ss, b + i }); - - } - } - - - inline void sampleUniformSystematic(u64 rows, u64 cols, oc::PRNG& prng, PointList& points) - { - - for (u64 i = 0; i < rows; ++i) - { - points.push_back({ i, cols - rows + i }); - - for (u64 j = 0; j < cols - rows; ++j) - { - if (prng.get()) - { - points.push_back({ i,j }); - } - } - } - - - } - - inline SparseMtx sampleUniformSystematic(u64 rows, u64 cols, oc::PRNG& prng) - { - PointList points(rows, cols); - sampleUniformSystematic(rows, cols, prng, points); - return SparseMtx(rows, cols, points); - } - - void sampleRegDiag( - u64 rows, u64 gap, u64 weight, - oc::PRNG& prng, PointList& points - ); - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - void sampleRegTriangularBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, u64 dDiag, u64 period, - std::vector doubleBand, - bool trim, bool extend, bool randY, - oc::PRNG& prng, PointList& points); - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline void sampleTriangularBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, u64 dDiag, u64 period, - std::vector doubleBand, - bool trim, bool extend, bool randY, - oc::PRNG& lPrng, PRNG& rPrng, PointList& points) - { - auto dHeight = gap + 1; - - assert(extend == false || trim == true); - assert(gap < rows); - assert(dWeight > 0); - assert(dWeight <= dHeight); - - if (extend) - { - for (u64 i = 0; i < gap; ++i) - { - points.push_back({ rows - gap + i, cols - gap + i }); - } - } - - //auto b = trim ? cols - rows + gap : cols - rows; - auto b = cols - rows; - - auto diagOffset = sampleFixedColWeight(rows, b, weight, diag, randY, lPrng, points); - - u64 ii = trim ? 0 : rows - gap; - u64 e = trim ? rows - gap : rows; - - - //if (doubleBand.size()) - //{ - // if (dDiag || !trim) - // { - // std::cout << "assumed no dDiag and assumed trim" << std::endl; - // abort(); - // } - - // //for (auto db : doubleBand) - // //{ - // // assert(db >= 1); - - // // for (u64 j = db + gap, c = b; j < rows; ++j, ++c) - // // { - // // points.push_back({ j, c }); - // // } - // //} - //} - - if (period && dDiag) - throw RTE_LOC; - - if (period) - { - if (trim == false) - throw RTE_LOC; - - for (u64 p = 0; p < period; ++p) - { - std::set s; - - auto ww = dWeight - 1; - - assert(ww < dHeight); - - s = sampleCol(1, dHeight, ww, false, rPrng); - - for (auto db : doubleBand) - { - assert(db >= 1); - u64 j = db + gap; - s.insert(j); - - } - - for (u64 i = p; i < e; i += period) - { - - points.push_back({ i % rows, b + i }); - for (auto ss : s) - { - if(i + ss < rows) - points.push_back({ (i+ss), b + i }); - } - } - - } - } - else - { - - for (u64 i = 0; i < e; ++i, ++ii) - { - auto ww = dWeight - 1; - for (auto db : doubleBand) - { - assert(db >= 1); - u64 j = db + gap + ii; - - if (j >= rows) - { - if (dDiag) - ++ww; - } - else - points.push_back({ j, b + i }); - - } - assert(ww < dHeight); - - auto s = sampleCol(ii + 1, ii + dHeight, ww, false, rPrng); - - points.push_back({ ii % rows, b + i }); - for (auto ss : s) - points.push_back({ ss % rows, b + i }); - - } - } - } - - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline SparseMtx sampleTriangularBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, oc::PRNG& prng) - { - PointList points(rows, cols); - sampleTriangularBand(rows, cols, weight, - gap, dWeight, diag, 0, 0, - {}, false, false, false, prng, prng, points); - - return SparseMtx(rows, cols, points); - } - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline SparseMtx sampleTriangularBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, u64 dDiag, u64 period, std::vector doubleBand, - bool trim, bool extend, bool randY, - oc::PRNG& lPrng, - oc::PRNG& rPrng) - { - PointList points(rows, cols); - sampleTriangularBand( - rows, cols, - weight, gap, - dWeight, diag, dDiag, period, doubleBand, trim, extend, randY, - lPrng, rPrng, points); - - auto cc = (trim && !extend) ? cols - gap : cols; - - return SparseMtx(rows, cc, points); - } - - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline SparseMtx sampleRegTriangularBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, u64 dDiag, u64 period, std::vector doubleBand, - bool trim, bool extend, bool randY, - oc::PRNG& prng) - { - auto cc = (trim && !extend) ? cols - gap : cols; - PointList points(rows, cc); - sampleRegTriangularBand( - rows, cols, - weight, gap, - dWeight, diag, dDiag, period, doubleBand, trim, extend, randY, - prng, points); - - - return SparseMtx(rows, cc, points); - } - - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline void sampleTriangularLongBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, bool doubleBand, - oc::PRNG& prng, PointList& points) - { - auto dHeight = gap + 1; - assert(gap < rows); - assert(dWeight < weight); - assert(dWeight <= dHeight); - - //sampleFixedColWeight(rows, cols - rows, weight, diag, prng, points); - - std::set s; - for (u64 i = 0, ii = rows - gap; i < cols; ++i, ++ii) - { - if (doubleBand) - { - assert(dWeight >= 2); - s = sampleCol(ii + 1, ii + dHeight - 1, rows, dWeight - 2, false, prng); - s.insert((ii + dHeight) % rows); - //points.push_back({ () % rows, i }); - } - else - s = sampleCol(ii + 1, ii + dHeight, rows, dWeight - 1, false, prng); - - - s.insert(ii % rows); - - if (i < rows) - { - while (s.size() != weight) - { - auto j = prng.get() % rows; - s.insert(j); - } - } - - //points.push_back({ ii % rows, i }); - for (auto ss : s) - points.push_back({ ss % rows, i }); - - - - } - } - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline SparseMtx sampleTriangularLongBand( - u64 rows, u64 cols, - u64 weight, u64 gap, - u64 dWeight, u64 diag, bool doubleBand, - oc::PRNG& prng) - { - PointList points(rows, cols); - sampleTriangularLongBand( - rows, cols, - weight, gap, - dWeight, diag, doubleBand, - prng, points); - - return SparseMtx(rows, cols, points); - } - - //// sample a parity check which is approx triangular with. - //// The diagonal will have fixed weight = dWeight. - //// The other columns will have weight = weight. - //inline void sampleTriangularBand2(u64 rows, u64 cols, u64 weight, u64 gap, u64 dWeight, oc::PRNG& prng, PointList& points) - //{ - // auto dHeight = gap + 1; - // assert(dWeight > 0); - // assert(dWeight <= dHeight); - - // sampleFixedColWeight(rows, cols - rows, weight, false, prng, points); - - // auto b = cols - rows; - // for (u64 i = 0, ii = rows - gap; i < rows; ++i, ++ii) - // { - // if (ii >= rows) - // { - // auto s = sampleCol(ii + 1, ii + dHeight, dWeight - 1, false, prng); - // points.push_back({ ii % rows, b + i }); - // for (auto ss : s) - // points.push_back({ ss % rows, b + i }); - // } - // else - // { - // auto s = sampleCol(ii, ii + dHeight, dWeight, false, prng); - // for (auto ss : s) - // points.push_back({ ss % rows, b + i }); - // } - - // } - //} - - //// sample a parity check which is approx triangular with. - //// The diagonal will have fixed weight = dWeight. - //// The other columns will have weight = weight. - //inline SparseMtx sampleTriangularBand2(u64 rows, u64 cols, u64 weight, u64 gap, u64 dWeight, oc::PRNG& prng) - //{ - // PointList points; - // sampleTriangularBand2(rows, cols, weight, gap, dWeight, prng, points); - // return SparseMtx(rows, cols, points); - //} - - - // sample a parity check which is approx triangular with. - // The other columns will have weight = weight. - inline void sampleTriangular(u64 rows, double density, oc::PRNG& prng, PointList& points) - { - assert(density > 0); - - u64 t = static_cast(~u64{ 0 } *density); - - for (u64 i = 0; i < rows; ++i) - { - points.push_back({ i, i }); - - for (u64 j = 0; j < i; ++j) - { - if (prng.get() < t) - { - points.push_back({ i, j }); - } - } - } - } - - // sample a parity check which is approx triangular with. - // The diagonal will have fixed weight = dWeight. - // The other columns will have weight = weight. - inline SparseMtx sampleTriangular(u64 rows, double density, oc::PRNG& prng) - { - PointList points(rows, rows); - sampleTriangular(rows, density, prng, points); - return SparseMtx(rows, rows, points); - } - - - inline SparseMtx sampleTriangular(u64 rows, u64 cols, u64 weight, u64 gap, oc::PRNG& prng) - { - PointList points(rows, cols); - sampleTriangular(rows, cols, weight, gap, prng, points); - return SparseMtx(rows, cols, points); - } - - - -} -#endif \ No newline at end of file diff --git a/libOTe/Tools/LDPC/Mtx.cpp b/libOTe/Tools/LDPC/Mtx.cpp index 6430a207..1354ba1f 100644 --- a/libOTe/Tools/LDPC/Mtx.cpp +++ b/libOTe/Tools/LDPC/Mtx.cpp @@ -2,7 +2,6 @@ #include "cryptoTools/Crypto/PRNG.h" #include "Util.h" #include "cryptoTools/Common/Matrix.h" -#include "LdpcSampler.h" namespace osuCrypto { diff --git a/libOTe/Tools/LDPC/Util.h b/libOTe/Tools/LDPC/Util.h index a05844ce..783a1b36 100644 --- a/libOTe/Tools/LDPC/Util.h +++ b/libOTe/Tools/LDPC/Util.h @@ -61,13 +61,6 @@ namespace osuCrypto while (i < mK - 1 && mSet[i] + 1 == mSet[i + 1]) ++i; - //if (i == mK - 1 && mSet.back() == mN - 1) - //{ - // mSet.clear(); - // return; - // //assert(mSet.back() != mN - 1); - //} - ++mSet[i]; for (u64 j = 0; j < i; ++j) mSet[j] = j; @@ -79,13 +72,7 @@ namespace osuCrypto } }; -#ifdef ENABLE_ALGO994 - extern int alg994; - extern int num_saved_generators; - extern int num_cores; - extern int num_permutations; - extern int print_matrices; -#endif + } diff --git a/libOTe/Tools/Pprf/PprfUtil.h b/libOTe/Tools/Pprf/PprfUtil.h new file mode 100644 index 00000000..4d294a15 --- /dev/null +++ b/libOTe/Tools/Pprf/PprfUtil.h @@ -0,0 +1,266 @@ +#pragma once +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/Aligned.h" +#include +#include + +namespace osuCrypto +{ + + + // the various formats that the output of the + // Pprf can be generated. + enum class PprfOutputFormat + { + // The i'th row holds the i'th leaf for all trees. + // The j'th tree is in the j'th column. + ByLeafIndex, + + // The i'th row holds the i'th tree. + // The j'th leaf is in the j'th column. + ByTreeIndex, + + // The native output mode. The output will be + // a single row with all leaf values. + // Every 8 trees are mixed together where the + // i'th leaf for each of the 8 tree will be next + // to each other. For example, let tij be the j'th + // leaf of the i'th tree. If we have m leaves, then + // + // t00 t10 ... t70 t01 t11 ... t71 ... t0m t1m ... t7m + // t80 t90 ... t_{15,0} t81 t91 ... t_{15,1} ... t8m t9m ... t_{15,m} + // ... + // + // These are all flattened into a single row. + Interleaved, + + // call the user's callback. The leaves will be in + // Interleaved format. + Callback + }; + + + + namespace pprf + { + + template + void copyOut( + VecF& leaf, + VecF& output, + u64 totalTrees, + u64 treeIndex, + PprfOutputFormat oFormat, + std::function& callback) + { + auto curSize = std::min(totalTrees - treeIndex, 8); + auto domain = leaf.size() / 8; + if (oFormat == PprfOutputFormat::ByLeafIndex) + { + if (curSize == 8) + { + for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) + { + auto oIdx = totalTrees * leafIndex + treeIndex; + auto iIdx = leafIndex * 8; + output[oIdx + 0] = leaf[iIdx + 0]; + output[oIdx + 1] = leaf[iIdx + 1]; + output[oIdx + 2] = leaf[iIdx + 2]; + output[oIdx + 3] = leaf[iIdx + 3]; + output[oIdx + 4] = leaf[iIdx + 4]; + output[oIdx + 5] = leaf[iIdx + 5]; + output[oIdx + 6] = leaf[iIdx + 6]; + output[oIdx + 7] = leaf[iIdx + 7]; + } + } + else + { + for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) + { + //auto oi = output[leafIndex].subspan(treeIndex, curSize); + //auto& ii = leaf[leafIndex]; + auto oIdx = totalTrees * leafIndex + treeIndex; + auto iIdx = leafIndex * 8; + for (u64 j = 0; j < curSize; ++j) + output[oIdx + j] = leaf[iIdx + j]; + } + } + + } + else if (oFormat == PprfOutputFormat::ByTreeIndex) + { + + if (curSize == 8) + { + for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) + { + auto iIdx = leafIndex * 8; + + output[(treeIndex + 0) * domain + leafIndex] = leaf[iIdx + 0]; + output[(treeIndex + 1) * domain + leafIndex] = leaf[iIdx + 1]; + output[(treeIndex + 2) * domain + leafIndex] = leaf[iIdx + 2]; + output[(treeIndex + 3) * domain + leafIndex] = leaf[iIdx + 3]; + output[(treeIndex + 4) * domain + leafIndex] = leaf[iIdx + 4]; + output[(treeIndex + 5) * domain + leafIndex] = leaf[iIdx + 5]; + output[(treeIndex + 6) * domain + leafIndex] = leaf[iIdx + 6]; + output[(treeIndex + 7) * domain + leafIndex] = leaf[iIdx + 7]; + } + } + else + { + for (u64 leafIndex = 0; leafIndex < domain; ++leafIndex) + { + auto iIdx = leafIndex * 8; + for (u64 j = 0; j < curSize; ++j) + output[(treeIndex + j) * domain + leafIndex] = leaf[iIdx + j]; + } + } + + } + else if (oFormat == PprfOutputFormat::Callback) + callback(treeIndex, leaf); + else + throw RTE_LOC; + } + + template + void allocateExpandBuffer( + u64 depth, + u64 numTrees, + bool programPuncturedPoint, + std::vector& buff, + span>& sums, + span& leaf, + CoeffCtx& ctx) + { + + u64 elementSize = ctx.template byteSize(); + + // num of bytes they will take up. + u64 numBytes = + depth * numTrees * sizeof(std::array) + // each internal level of the tree has two sums + elementSize * numTrees * 2 + // we must program numTrees inactive F leaves + elementSize * numTrees * 2 * programPuncturedPoint; // if we are programing the active lead, then we have numTrees more. + + // allocate the buffer and partition them. + buff.resize(numBytes); + sums = span>((std::array*)buff.data(), depth * numTrees); + leaf = span((u8*)(sums.data() + sums.size()), + elementSize * numTrees * 2 + + elementSize * numTrees * 2 * programPuncturedPoint + ); + + void* sEnd = sums.data() + sums.size(); + void* lEnd = leaf.data() + leaf.size(); + void* end = buff.data() + buff.size(); + if (sEnd > end || lEnd != end) + throw RTE_LOC; + } + + template + void validateExpandFormat( + PprfOutputFormat oFormat, + VecF& output, + u64 domain, + u64 pntCount) + { + if (oFormat == PprfOutputFormat::Interleaved && pntCount % 8) + throw std::runtime_error("For Interleaved output format, pointCount must be a multiple of 8 (general case not impl). " LOCATION); + + + switch (oFormat) + { + case osuCrypto::PprfOutputFormat::ByLeafIndex: + case osuCrypto::PprfOutputFormat::ByTreeIndex: + case osuCrypto::PprfOutputFormat::Interleaved: + if (output.size() != domain * pntCount) + throw RTE_LOC; + break; + case osuCrypto::PprfOutputFormat::Callback: + if (output.size()) + throw RTE_LOC; + break; + default: + throw RTE_LOC; + break; + } + + } + + + struct TreeAllocator + { + TreeAllocator() = default; + TreeAllocator(const TreeAllocator&) = delete; + TreeAllocator(TreeAllocator&&) = default; + + using ValueType = AlignedArray; + std::list> mTrees; + std::vector> mFreeTrees; + //std::mutex mMutex; + u64 mTreeSize = 0, mNumTrees = 0; + + void reserve(u64 num, u64 size) + { + //std::lock_guard lock(mMutex); + mTreeSize = size; + mNumTrees += num; + mTrees.clear(); + mFreeTrees.clear(); + mTrees.emplace_back(num * size); + auto iter = mTrees.back().data(); + for (u64 i = 0; i < num; ++i) + { + mFreeTrees.push_back(span(iter, size)); + assert((u64)mFreeTrees.back().data() % 32 == 0); + iter += size; + } + } + + span get() + { + //std::lock_guard lock(mMutex); + if (mFreeTrees.size() == 0) + { + assert(mTreeSize); + mTrees.emplace_back(mTreeSize); + mFreeTrees.push_back(span(mTrees.back().data(), mTreeSize)); + assert((u64)mFreeTrees.back().data() % 32 == 0); + ++mNumTrees; + } + + auto ret = mFreeTrees.back(); + mFreeTrees.pop_back(); + return ret; + } + + void clear() + { + mTrees = {}; + mFreeTrees = {}; + mTreeSize = 0; + mNumTrees = 0; + } + }; + + + inline void allocateExpandTree( + TreeAllocator& alloc, + std::vector>>& levels) + { + span> tree = alloc.get(); + assert((u64)tree.data() % 32 == 0); + levels[0] = tree.subspan(0, 1); + auto rem = tree.subspan(2); + for (auto i = 1ull; i < levels.size(); ++i) + { + levels[i] = rem.subspan(0, levels[i - 1].size() * 2); + assert((u64)levels[i].data() % 32 == 0); + rem = rem.subspan(levels[i].size()); + } + } + + + } + +} \ No newline at end of file diff --git a/libOTe/Tools/Pprf/RegularPprf.cpp b/libOTe/Tools/Pprf/RegularPprf.cpp new file mode 100644 index 00000000..0a6c4455 --- /dev/null +++ b/libOTe/Tools/Pprf/RegularPprf.cpp @@ -0,0 +1,11 @@ +#include "cryptoTools/Crypto/AES.h" + +namespace osuCrypto { + // A public PRF/PRG that we will use for deriving the GGM tree. + extern const std::array gGgmAes = []() { + std::array aes; + aes[0].setKey(toBlock(3242342)); + aes[1].setKey(toBlock(8993849)); + return aes; + }(); +} diff --git a/libOTe/Tools/Pprf/RegularPprf.h b/libOTe/Tools/Pprf/RegularPprf.h new file mode 100644 index 00000000..439150af --- /dev/null +++ b/libOTe/Tools/Pprf/RegularPprf.h @@ -0,0 +1,1039 @@ +#pragma once +#include "libOTe/config.h" + +#ifdef ENABLE_PPRF +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/BitVector.h" +#include "cryptoTools/Common/Matrix.h" +#include "cryptoTools/Common/Timer.h" +#include "cryptoTools/Common/Aligned.h" +#include "cryptoTools/Common/Range.h" +#include "cryptoTools/Crypto/PRNG.h" +#include "libOTe/Tools/Coproto.h" +#include +#include "libOTe/Tools/CoeffCtx.h" +#include "PprfUtil.h" + +namespace osuCrypto +{ + + extern const std::array gGgmAes; + + + template< + typename F, + typename G = F, + typename CoeffCtx = DefaultCoeffCtx + > + class RegularPprfSender : public TimerAdapter { + public: + u64 mDomain = 0, mDepth = 0, mPntCount = 0; + std::vector mValue; + Matrix> mBaseOTs; + + using VecF = typename CoeffCtx::template Vec; + using VecG = typename CoeffCtx::template Vec; + + std::function mOutputFn; + + + RegularPprfSender() = default; + + RegularPprfSender(const RegularPprfSender&) = delete; + + RegularPprfSender(RegularPprfSender&&) = delete; + + RegularPprfSender(u64 domainSize, u64 pointCount) { + configure(domainSize, pointCount); + } + + void configure(u64 domainSize, u64 pointCount) + { + if (domainSize & 1) + throw std::runtime_error("Pprf domain must be even. " LOCATION); + if (domainSize < 2) + throw std::runtime_error("Pprf domain must must be at least 2. " LOCATION); + + mDomain = domainSize; + mDepth = log2ceil(mDomain); + mPntCount = pointCount; + + mBaseOTs.resize(0, 0); + } + + + // the number of base OTs that should be set. + u64 baseOtCount() const { + return mDepth * mPntCount; + } + + // returns true if the base OTs are currently set. + bool hasBaseOts() const { + return mBaseOTs.size(); + } + + + void setBase(span> baseMessages) { + if (baseOtCount() != static_cast(baseMessages.size())) + throw RTE_LOC; + + mBaseOTs.resize(mPntCount, mDepth); + for (u64 i = 0; i < static_cast(mBaseOTs.size()); ++i) + mBaseOTs(i) = baseMessages[i]; + } + + task<> expand( + Socket& chl, + const VecF& value, + block seed, + VecF& output, + PprfOutputFormat oFormat, + bool programPuncturedPoint, + u64 numThreads, + CoeffCtx ctx = {}) + { + if (programPuncturedPoint) + setValue(value); + + setTimePoint("SilentMultiPprfSender.start"); + + pprf::validateExpandFormat(oFormat, output, mDomain, mPntCount); + + MC_BEGIN(task<>, this, numThreads, oFormat, &output, seed, &chl, programPuncturedPoint, ctx, + treeIndex = u64{}, + tree = span>{}, + levels = std::vector> >{}, + leafIndex = u64{}, + leafLevelPtr = (VecF*)nullptr, + leafLevel = VecF{}, + buff = std::vector{}, + encSums = span>{}, + leafMsgs = span{}, + mTreeAlloc = pprf::TreeAllocator{} + ); + + mTreeAlloc.reserve(numThreads, (1ull << mDepth) + 2); + setTimePoint("SilentMultiPprfSender.reserve"); + + levels.resize(mDepth); + pprf::allocateExpandTree(mTreeAlloc, levels); + + for (treeIndex = 0; treeIndex < mPntCount; treeIndex += 8) + { + // for interleaved format, the leaf level of the tree + // is simply the output. + if (oFormat == PprfOutputFormat::Interleaved) + { + leafIndex = treeIndex * mDomain; + leafLevelPtr = &output; + } + else + { + // we will use leaf level as a buffer before + // copying the result to the output. + leafIndex = 0; + ctx.resize(leafLevel, mDomain * 8); + leafLevelPtr = &leafLevel; + } + + // allocate the send buffer and partition it. + pprf::allocateExpandBuffer( + mDepth - 1, std::min(8, mPntCount - treeIndex), programPuncturedPoint, buff, encSums, leafMsgs, ctx); + + // exapnd the tree + expandOne(seed, treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs, ctx); + + MC_AWAIT(chl.send(std::move(buff))); + + // if we aren't interleaved, we need to copy the + // leaf layer to the output. + if (oFormat != PprfOutputFormat::Interleaved) + pprf::copyOut(leafLevel, output, mPntCount, treeIndex, oFormat, mOutputFn); + + } + + mBaseOTs = {}; + //mTreeAlloc.del(tree); + mTreeAlloc.clear(); + + setTimePoint("SilentMultiPprfSender.de-alloc"); + + MC_END(); + } + + void setValue(span value) { + + mValue.resize(mPntCount); + + if (value.size() == 1) { + std::fill(mValue.begin(), mValue.end(), value[0]); + } + else { + if ((u64)value.size() != mPntCount) + throw RTE_LOC; + + std::copy(value.begin(), value.end(), mValue.begin()); + } + } + + void clear() { + mBaseOTs.resize(0, 0); + mDomain = 0; + mDepth = 0; + mPntCount = 0; + } + + void expandOne( + block aesSeed, + u64 treeIdx, + bool programPuncturedPoint, + span>> levels, + VecF& leafLevel, + const u64 leafOffset, + span> encSums, + span leafMsgs, + CoeffCtx ctx) + { + auto remTrees = std::min(8, mPntCount - treeIdx); + + // the first level should be size 1, the root of the tree. + // we will populate it with random seeds using aesSeed in counter mode + // based on the tree index. + assert(levels[0].size() == 1); + mAesFixedKey.ecbEncCounterMode(aesSeed ^ block(treeIdx), levels[0][0]); + + assert(encSums.size() == (mDepth - 1) * remTrees); + auto encSumIter = encSums.begin(); + + // space for our sums of each level. Should always be less then + // 24 levels... If not increase the limit or make it a vector. + std::array, 2> sums; + + // use the optimized approach for intern nodes of the tree + // For each level perform the following. + for (u64 d = 0; d < mDepth - 1; ++d) + { + // clear the sums + memset(&sums, 0, sizeof(sums)); + + // The total number of parents in this level. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); + + // The previous level of the GGM tree. + auto parents = levels[d]; + + // The next level of theGGM tree that we are populating. + auto children = levels[d + 1]; + + // For each child, populate the child by expanding the parent. + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) + { + // The value of the parent. + auto& parent = parents.data()[parentIdx]; + + auto& child0 = children.data()[childIdx]; + auto& child1 = children.data()[childIdx + 1]; + mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); + + // inspired by the Expand Accumualte idea to + // use + // + // child0 = AES(parent) ^ parent + // child1 = AES(parent) + parent + // + // but instead we are a bit more conservative and + // compute + // + // child0 = AES:Round(AES(parent), parent) + // = AES:Round(AES(parent), 0) ^ parent + // child1 = AES(parent) + parent + // + // That is, we applies an additional AES round function + // to the first child before XORing it with parent. + child0[0] = AES::roundEnc(child1[0], parent[0]); + child0[1] = AES::roundEnc(child1[1], parent[1]); + child0[2] = AES::roundEnc(child1[2], parent[2]); + child0[3] = AES::roundEnc(child1[3], parent[3]); + child0[4] = AES::roundEnc(child1[4], parent[4]); + child0[5] = AES::roundEnc(child1[5], parent[5]); + child0[6] = AES::roundEnc(child1[6], parent[6]); + child0[7] = AES::roundEnc(child1[7], parent[7]); + + // Update the running sums for this level. We keep + // a left and right totals for each level. + sums[0][0] = sums[0][0] ^ child0[0]; + sums[0][1] = sums[0][1] ^ child0[1]; + sums[0][2] = sums[0][2] ^ child0[2]; + sums[0][3] = sums[0][3] ^ child0[3]; + sums[0][4] = sums[0][4] ^ child0[4]; + sums[0][5] = sums[0][5] ^ child0[5]; + sums[0][6] = sums[0][6] ^ child0[6]; + sums[0][7] = sums[0][7] ^ child0[7]; + + // child1 = AES(parent) + parent + child1[0] = child1[0] + parent[0]; + child1[1] = child1[1] + parent[1]; + child1[2] = child1[2] + parent[2]; + child1[3] = child1[3] + parent[3]; + child1[4] = child1[4] + parent[4]; + child1[5] = child1[5] + parent[5]; + child1[6] = child1[6] + parent[6]; + child1[7] = child1[7] + parent[7]; + + sums[1][0] = sums[1][0] ^ child1[0]; + sums[1][1] = sums[1][1] ^ child1[1]; + sums[1][2] = sums[1][2] ^ child1[2]; + sums[1][3] = sums[1][3] ^ child1[3]; + sums[1][4] = sums[1][4] ^ child1[4]; + sums[1][5] = sums[1][5] ^ child1[5]; + sums[1][6] = sums[1][6] ^ child1[6]; + sums[1][7] = sums[1][7] ^ child1[7]; + + } + + // encrypt the sums and write them to the output. + for (u64 j = 0; j < remTrees; ++j) + { + (*encSumIter)[0] = sums[0][j] ^ mBaseOTs[treeIdx + j][mDepth - 1 - d][1]; + (*encSumIter)[1] = sums[1][j] ^ mBaseOTs[treeIdx + j][mDepth - 1 - d][0]; + ++encSumIter; + } + } + assert(encSumIter == encSums.end()); + + auto d = mDepth - 1; + + // The previous level of the GGM tree. + auto level0 = levels[d]; + + // The total number of parents in this level. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); + + // The next level of theGGM tree that we are populating. + std::array child; + + // clear the sums + std::array leafSums; + ctx.resize(leafSums[0], 8); + ctx.resize(leafSums[1], 8); + ctx.zero(leafSums[0].begin(), leafSums[0].end()); + ctx.zero(leafSums[1].begin(), leafSums[1].end()); + + // for the leaf nodes we need to hash both children. + for (u64 parentIdx = 0, outIdx = leafOffset, childIdx = 0; parentIdx < width; ++parentIdx) + { + // The value of the parent. + auto& parent = level0.data()[parentIdx]; + + // The bit that indicates if we are on the left child (0) + // or on the right child (1). + for (u64 keep = 0; keep < 2; ++keep, ++childIdx, outIdx += 8) + { + // The child that we will write in this iteration. + + // Each parent is expanded into the left and right children + // using a different AES fixed-key. Therefore our OWF is: + // + // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); + // + // where each half defines one of the children. + gGgmAes[keep].hashBlocks<8>(parent.data(), child.data()); + + ctx.fromBlock(leafLevel[outIdx + 0], child[0]); + ctx.fromBlock(leafLevel[outIdx + 1], child[1]); + ctx.fromBlock(leafLevel[outIdx + 2], child[2]); + ctx.fromBlock(leafLevel[outIdx + 3], child[3]); + ctx.fromBlock(leafLevel[outIdx + 4], child[4]); + ctx.fromBlock(leafLevel[outIdx + 5], child[5]); + ctx.fromBlock(leafLevel[outIdx + 6], child[6]); + ctx.fromBlock(leafLevel[outIdx + 7], child[7]); + + // leafSum += child + auto& leafSum = leafSums[keep]; + ctx.plus(leafSum[0], leafSum[0], leafLevel[outIdx + 0]); + ctx.plus(leafSum[1], leafSum[1], leafLevel[outIdx + 1]); + ctx.plus(leafSum[2], leafSum[2], leafLevel[outIdx + 2]); + ctx.plus(leafSum[3], leafSum[3], leafLevel[outIdx + 3]); + ctx.plus(leafSum[4], leafSum[4], leafLevel[outIdx + 4]); + ctx.plus(leafSum[5], leafSum[5], leafLevel[outIdx + 5]); + ctx.plus(leafSum[6], leafSum[6], leafLevel[outIdx + 6]); + ctx.plus(leafSum[7], leafSum[7], leafLevel[outIdx + 7]); + } + + } + + if (programPuncturedPoint) + { + // For the leaf level, we are going to do something special. + // The other party is currently missing both leaf children of + // the active parent. Since this is the leaf level, we want + // the inactive child to just be the normal value but the + // active child should be the correct value XOR the delta. + // This will be done by sending the sums and the sums plus + // delta and ensure that they can only decrypt the correct ones. + VecF leafOts; + ctx.resize(leafOts, 2); + PRNG otMasker; + + for (u64 j = 0; j < remTrees; ++j) + { + // we will construct two OT strings. Let + // s0, s1 be the left and right child sums. + // + // m0 = (s0 , s1 + val) + // m1 = (s0 + val, s1 ) + // + // these will be encrypted by the OT keys + for (u64 k = 0; k < 2; ++k) + { + if (k == 0) + { + // m0 = (s0, s1 + val) + ctx.copy(leafOts[0], leafSums[0][j]); + ctx.plus(leafOts[1], leafSums[1][j], mValue[treeIdx + j]); + } + else + { + // m1 = (s0+val, s1) + ctx.plus(leafOts[0], leafSums[0][j], mValue[treeIdx + j]); + ctx.copy(leafOts[1], leafSums[1][j]); + } + + // copy m0 into the output buffer. + span buff = leafMsgs.subspan(0, 2 * ctx.template byteSize()); + leafMsgs = leafMsgs.subspan(buff.size()); + ctx.serialize(leafOts.begin(), leafOts.end(), buff.begin()); + + // encrypt the output buffer. + otMasker.SetSeed(mBaseOTs[treeIdx + j][0][1 ^ k], divCeil(buff.size(), sizeof(block))); + for (u64 i = 0; i < buff.size(); ++i) + buff[i] ^= otMasker.get(); + + } + } + } + else + { + VecF leafOts; + ctx.resize(leafOts, 1); + PRNG otMasker; + + for (u64 j = 0; j < remTrees; ++j) + { + for (u64 k = 0; k < 2; ++k) + { + // copy the sum k into the output buffer. + ctx.copy(leafOts[0], leafSums[k][j]); + span buff = leafMsgs.subspan(0, ctx.template byteSize()); + leafMsgs = leafMsgs.subspan(buff.size()); + ctx.serialize(leafOts.begin(), leafOts.end(), buff.begin()); + + // encrypt the output buffer. + otMasker.SetSeed(mBaseOTs[treeIdx + j][0][1 ^ k], divCeil(buff.size(), sizeof(block))); + for (u64 i = 0; i < buff.size(); ++i) + buff[i] ^= otMasker.get(); + + } + } + } + + assert(leafMsgs.size() == 0); + } + + + }; + + + template< + typename F, + typename G = F, + typename CoeffCtx = DefaultCoeffCtx + > + class RegularPprfReceiver : public TimerAdapter + { + public: + u64 mDomain = 0, mDepth = 0, mPntCount = 0; + using VecF = typename CoeffCtx::template Vec; + using VecG = typename CoeffCtx::template Vec; + + //std::vector mPoints; + + Matrix mBaseOTs; + + Matrix mBaseChoices; + + std::function mOutputFn; + + RegularPprfReceiver() = default; + RegularPprfReceiver(const RegularPprfReceiver&) = delete; + RegularPprfReceiver(RegularPprfReceiver&&) = delete; + + void configure(u64 domainSize, u64 pointCount) + { + if (domainSize & 1) + throw std::runtime_error("Pprf domain must be even. " LOCATION); + if (domainSize < 2) + throw std::runtime_error("Pprf domain must must be at least 2. " LOCATION); + + mDomain = domainSize; + mDepth = log2ceil(mDomain); + mPntCount = pointCount; + + mBaseOTs.resize(0, 0); + } + + + // this function sample mPntCount integers in the range + // [0,domain) and returns these as the choice bits. + BitVector sampleChoiceBits(PRNG& prng) + { + BitVector choices(mPntCount * mDepth); + + // The points are read in blocks of 8, so make sure that there is a + // whole number of blocks. + mBaseChoices.resize(mPntCount, mDepth); + for (u64 i = 0; i < mPntCount; ++i) + { + u64 idx = prng.get() % mDomain; + for (u64 j = 0; j < mDepth; ++j) + mBaseChoices(i, j) = *BitIterator((u8*)&idx, j); + } + + for (u64 i = 0; i < mBaseChoices.size(); ++i) + { + choices[i] = mBaseChoices(i); + } + + return choices; + } + + // choices is in the same format as the output from sampleChoiceBits. + void setChoiceBits(const BitVector& choices) + { + // Make sure we're given the right number of OTs. + if (choices.size() != baseOtCount()) + throw RTE_LOC; + + mBaseChoices.resize(mPntCount, mDepth); + for (u64 i = 0; i < mPntCount; ++i) + { + u64 idx = 0; + for (u64 j = 0; j < mDepth; ++j) + { + mBaseChoices(i, j) = choices[mDepth * i + j]; + idx |= u64(choices[mDepth * i + j]) << j; + } + + if (idx >= mDomain) + throw std::runtime_error("provided choice bits index outside of the domain." LOCATION); + } + } + + + // the number of base OTs that should be set. + u64 baseOtCount() const + { + return mDepth * mPntCount; + } + + // returns true if the base OTs are currently set. + bool hasBaseOts() const + { + return mBaseOTs.size(); + } + + + void setBase(span baseMessages) + { + if (baseOtCount() != static_cast(baseMessages.size())) + throw RTE_LOC; + + // The OTs are used in blocks of 8, so make sure that there is a whole + // number of blocks. + mBaseOTs.resize(roundUpTo(mPntCount, 8), mDepth); + memcpy(mBaseOTs.data(), baseMessages.data(), baseMessages.size() * sizeof(block)); + } + + std::vector getPoints(PprfOutputFormat format) + { + std::vector pnts(mPntCount); + getPoints(pnts, format); + return pnts; + } + void getPoints(span points, PprfOutputFormat format) + { + if ((u64)points.size() != mPntCount) + throw RTE_LOC; + + switch (format) + { + case PprfOutputFormat::ByLeafIndex: + case PprfOutputFormat::ByTreeIndex: + + memset(points.data(), 0, points.size() * sizeof(u64)); + for (u64 j = 0; j < mPntCount; ++j) + { + for (u64 k = 0; k < mDepth; ++k) + points[j] |= u64(mBaseChoices(j, k)) << k; + + assert(points[j] < mDomain); + } + + + break; + case PprfOutputFormat::Interleaved: + case PprfOutputFormat::Callback: + + getPoints(points, PprfOutputFormat::ByLeafIndex); + + // in interleaved mode we generate 8 trees in a batch. + // the i'th leaf of these 8 trees are next to eachother. + for (u64 j = 0; j < points.size(); ++j) + { + auto subTree = j % 8; + auto batch = j / 8; + points[j] = (batch * mDomain + points[j]) * 8 + subTree; + } + + //interleavedPoints(points, mDomain, format); + + break; + default: + throw RTE_LOC; + break; + } + } + + // programPuncturedPoint says whether the sender is trying to program the + // active child to be its correct value XOR delta. If it is not, the + // active child will just take a random value. + task<> expand( + Socket& chl, + VecF& output, + PprfOutputFormat oFormat, + bool programPuncturedPoint, + u64 numThreads, + CoeffCtx ctx = {}) + { + pprf::validateExpandFormat(oFormat, output, mDomain, mPntCount); + + MC_BEGIN(task<>, this, oFormat, &output, &chl, programPuncturedPoint, ctx, + treeIndex = u64{}, + levels = std::vector>>{}, + leafIndex = u64{}, + leafLevelPtr = (VecF*)nullptr, + leafLevel = VecF{}, + buff = std::vector{}, + encSums = span>{}, + leafMsgs = span{}, + mTreeAlloc = pprf::TreeAllocator{}, + points = std::vector{} + ); + + setTimePoint("SilentMultiPprfReceiver.start"); + points.resize(mPntCount); + getPoints(points, PprfOutputFormat::ByLeafIndex); + + mTreeAlloc.reserve(1, (1ull << mDepth) + 2); + setTimePoint("SilentMultiPprfSender.reserve"); + + levels.resize(mDepth); + pprf::allocateExpandTree(mTreeAlloc, levels); + + for (treeIndex = 0; treeIndex < mPntCount; treeIndex += 8) + { + // for interleaved format, the leaf level of the tree + // is simply the output. + if (oFormat == PprfOutputFormat::Interleaved) + { + leafIndex = treeIndex * mDomain; + leafLevelPtr = &output; + } + else + { + // we will use leaf level as a buffer before + // copying the result to the output. + leafIndex = 0; + ctx.resize(leafLevel, mDomain * 8); + leafLevelPtr = &leafLevel; + } + + // allocate the send buffer and partition it. + pprf::allocateExpandBuffer(mDepth - 1, std::min(8, mPntCount - treeIndex), + programPuncturedPoint, buff, encSums, leafMsgs, ctx); + + MC_AWAIT(chl.recv(buff)); + + // exapnd the tree + expandOne(treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs, points, ctx); + + // if we aren't interleaved, we need to copy the + // leaf layer to the output. + if (oFormat != PprfOutputFormat::Interleaved) + pprf::copyOut(leafLevel, output, mPntCount, treeIndex, oFormat, mOutputFn); + } + + setTimePoint("SilentMultiPprfReceiver.join"); + + mBaseOTs = {}; + //mTreeAlloc.del(tree); + mTreeAlloc.clear(); + + setTimePoint("SilentMultiPprfReceiver.de-alloc"); + + MC_END(); + } + + void clear() + { + mBaseOTs.resize(0, 0); + mBaseChoices.resize(0, 0); + mDomain = 0; + mDepth = 0; + mPntCount = 0; + } + + void expandOneInternal( + u64 treeIdx, + span>> levels, + span, 2>> theirSums, + CoeffCtx& ctx) + { + } + + //treeIndex, programPuncturedPoint, levels, *leafLevelPtr, leafIndex, encSums, leafMsgs + void expandOne( + u64 treeIdx, + bool programPuncturedPoint, + span>> levels, + VecF& leafLevel, + const u64 outputOffset, + span> theirSums, + span leafMsg, + span points, + CoeffCtx& ctx) + { + auto remTrees = std::min(8, mPntCount - treeIdx); + assert(theirSums.size() == remTrees * (mDepth - 1)); + + // We change the hash function for the leaf so lets update + // inactiveChildValues to use the new hash and subtract + // these from the leafSums + std::array leafSums; + if (mDepth > 1) + { + auto theirSumsIter = theirSums.begin(); + + // special case for the first level. + auto l1 = levels[1]; + for (u64 i = 0; i < remTrees; ++i) + { + // For the non-active path, set the child of the root node + // as the OT message XOR'ed with the correction sum. + + int active = mBaseChoices[i + treeIdx].back(); + l1[active ^ 1][i] = mBaseOTs[i + treeIdx].back() ^ (*theirSumsIter)[active ^ 1]; + l1[active][i] = ZeroBlock; + ++theirSumsIter; + //if (!i) + // std::cout << " unmask " + // << mBaseOTs[i + treeIdx].back() << " ^ " + // << theirSums[0][active ^ 1][i] << " = " + // << l1[active ^ 1][i] << std::endl; + + } + + // space for our sums of each level. + std::array, 2> mySums; + + // this will be the value of both children of active an parent + // before the active child is updated. We will need to subtract + // this value as the main loop does not distinguish active parents. + std::array inactiveChildValues; + inactiveChildValues[0] = AES::roundEnc(mAesFixedKey.ecbEncBlock(ZeroBlock), ZeroBlock); + inactiveChildValues[1] = mAesFixedKey.ecbEncBlock(ZeroBlock); + + // For all other levels, expand the GGM tree and add in + // the correction along the active path. + for (u64 d = 1; d < mDepth - 1; ++d) + { + // initialized the sums with inactiveChildValue so that + // it will cancel when we expand the actual inactive child. + std::fill(mySums[0].begin(), mySums[0].end(), inactiveChildValues[0]); + std::fill(mySums[1].begin(), mySums[1].end(), inactiveChildValues[1]); + + // We will iterate over each node on this level and + // expand it into it's two children. Note that the + // active node will also be expanded. Later we will just + // overwrite whatever the value was. This is an optimization. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); + + // The already constructed level. Only missing the + // GGM tree node value along the active path. + auto level0 = levels[d]; + + // The next level that we want to construct. + auto level1 = levels[d + 1]; + + for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) + { + // The value of the parent. + auto parent = level0[parentIdx]; + + auto& child0 = level1.data()[childIdx]; + auto& child1 = level1.data()[childIdx + 1]; + mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); + + // inspired by the Expand Accumualte idea to + // use + // + // child0 = AES(parent) ^ parent + // child1 = AES(parent) + parent + // + // but instead we are a bit more conservative and + // compute + // + // child0 = AES:Round(AES(parent), parent) + // = AES:Round(AES(parent), 0) ^ parent + // child1 = AES(parent) + parent + // + // That is, we applies an additional AES round function + // to the first child before XORing it with parent. + child0[0] = AES::roundEnc(child1[0], parent[0]); + child0[1] = AES::roundEnc(child1[1], parent[1]); + child0[2] = AES::roundEnc(child1[2], parent[2]); + child0[3] = AES::roundEnc(child1[3], parent[3]); + child0[4] = AES::roundEnc(child1[4], parent[4]); + child0[5] = AES::roundEnc(child1[5], parent[5]); + child0[6] = AES::roundEnc(child1[6], parent[6]); + child0[7] = AES::roundEnc(child1[7], parent[7]); + + // Update the running sums for this level. We keep + // a left and right totals for each level. Note that + // we are actually XOR in the incorrect value of the + // children of the active parent but this will cancel + // with inactiveChildValue thats already there. + mySums[0][0] = mySums[0][0] ^ child0[0]; + mySums[0][1] = mySums[0][1] ^ child0[1]; + mySums[0][2] = mySums[0][2] ^ child0[2]; + mySums[0][3] = mySums[0][3] ^ child0[3]; + mySums[0][4] = mySums[0][4] ^ child0[4]; + mySums[0][5] = mySums[0][5] ^ child0[5]; + mySums[0][6] = mySums[0][6] ^ child0[6]; + mySums[0][7] = mySums[0][7] ^ child0[7]; + + // child1 = AES(parent) + parent + child1[0] = child1[0] + parent[0]; + child1[1] = child1[1] + parent[1]; + child1[2] = child1[2] + parent[2]; + child1[3] = child1[3] + parent[3]; + child1[4] = child1[4] + parent[4]; + child1[5] = child1[5] + parent[5]; + child1[6] = child1[6] + parent[6]; + child1[7] = child1[7] + parent[7]; + + mySums[1][0] = mySums[1][0] ^ child1[0]; + mySums[1][1] = mySums[1][1] ^ child1[1]; + mySums[1][2] = mySums[1][2] ^ child1[2]; + mySums[1][3] = mySums[1][3] ^ child1[3]; + mySums[1][4] = mySums[1][4] ^ child1[4]; + mySums[1][5] = mySums[1][5] ^ child1[5]; + mySums[1][6] = mySums[1][6] ^ child1[6]; + mySums[1][7] = mySums[1][7] ^ child1[7]; + + } + + + // we have to update the non-active child of the active parent. + for (u64 i = 0; i < remTrees; ++i) + { + // the index of the leaf node that is active. + auto leafIdx = points[i + treeIdx]; + + // The index of the active (missing) child node. + auto missingChildIdx = leafIdx >> (mDepth - 1 - d); + + // The index of the active child node sibling. + auto siblingIdx = missingChildIdx ^ 1; + + // The indicator as to the left or right child is inactive + auto notAi = siblingIdx & 1; + + // our sums & OTs cancel and we are leaf with the + // correct value for the inactive child. + level1[siblingIdx][i] = + (*theirSumsIter)[notAi] ^ + mySums[notAi][i] ^ + mBaseOTs[i + treeIdx][mDepth - 1 - d]; + + ++theirSumsIter; + + // we have to set the active child to zero so + // the next children are predictable. + level1[missingChildIdx][i] = ZeroBlock; + } + } + + auto d = mDepth - 1; + // The already constructed level. Only missing the + // GGM tree node value along the active path. + auto level0 = levels[d]; + + // The next level of theGGM tree that we are populating. + std::array child; + + // We will iterate over each node on this level and + // expand it into it's two children. Note that the + // active node will also be expanded. Later we will just + // overwrite whatever the value was. This is an optimization. + auto width = divCeil(mDomain, 1ull << (mDepth - d)); + + VecF temp; + ctx.resize(temp, 2); + for (u64 k = 0; k < 2; ++k) + { + ctx.resize(leafSums[k], 8); + ctx.zero(leafSums[k].begin(), leafSums[k].end()); + ctx.fromBlock(temp[k], gGgmAes[k].hashBlock(ZeroBlock)); + ctx.minus(leafSums[k][0], leafSums[k][0], temp[k]); + for (u64 i = 1; i < 8; ++i) + ctx.copy(leafSums[k][i], leafSums[k][0]); + } + + // for leaf nodes both children should be hashed. + for (u64 parentIdx = 0, childIdx = 0, outputIdx = outputOffset; parentIdx < width; ++parentIdx) + { + // The value of the parent. + auto parent = level0[parentIdx]; + + for (u64 keep = 0; keep < 2; ++keep, ++childIdx, outputIdx += 8) + { + // Each parent is expanded into the left and right children + // using a different AES fixed-key. Therefore our OWF is: + // + // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); + // + // where each half defines one of the children. + gGgmAes[keep].hashBlocks<8>(parent.data(), child.data()); + + ctx.fromBlock(leafLevel[outputIdx + 0], child[0]); + ctx.fromBlock(leafLevel[outputIdx + 1], child[1]); + ctx.fromBlock(leafLevel[outputIdx + 2], child[2]); + ctx.fromBlock(leafLevel[outputIdx + 3], child[3]); + ctx.fromBlock(leafLevel[outputIdx + 4], child[4]); + ctx.fromBlock(leafLevel[outputIdx + 5], child[5]); + ctx.fromBlock(leafLevel[outputIdx + 6], child[6]); + ctx.fromBlock(leafLevel[outputIdx + 7], child[7]); + + auto& leafSum = leafSums[keep]; + ctx.plus(leafSum[0], leafSum[0], leafLevel[outputIdx + 0]); + ctx.plus(leafSum[1], leafSum[1], leafLevel[outputIdx + 1]); + ctx.plus(leafSum[2], leafSum[2], leafLevel[outputIdx + 2]); + ctx.plus(leafSum[3], leafSum[3], leafLevel[outputIdx + 3]); + ctx.plus(leafSum[4], leafSum[4], leafLevel[outputIdx + 4]); + ctx.plus(leafSum[5], leafSum[5], leafLevel[outputIdx + 5]); + ctx.plus(leafSum[6], leafSum[6], leafLevel[outputIdx + 6]); + ctx.plus(leafSum[7], leafSum[7], leafLevel[outputIdx + 7]); + } + } + } + else + { + for (u64 k = 0; k < 2; ++k) + { + ctx.resize(leafSums[k], 8); + ctx.zero(leafSums[k].begin(), leafSums[k].end()); + } + } + + // leaf level. + if (programPuncturedPoint) + { + // Now processes the leaf level. This one is special + // because we must XOR in the correction value as + // before but we must also fixed the child value for + // the active child. To do this, we will receive 4 + // values. Two for each case (left active or right active). + //timer.setTimePoint("recv.recvleaf"); + VecF leafOts; + ctx.resize(leafOts, 2); + PRNG otMasker; + + for (u64 j = 0; j < remTrees; ++j) + { + + // The index of the child on the active path. + auto activeChildIdx = points[j + treeIdx]; + + // The index of the other (inactive) child. + auto inactiveChildIdx = activeChildIdx ^ 1; + + // The indicator as to the left or right child is inactive + auto notAi = inactiveChildIdx & 1; + + // offset to the first or second ot message, based on the one we want + auto offset = ctx.template byteSize() * 2 * notAi; + + + // decrypt the ot string + span buff = leafMsg.subspan(offset, ctx.template byteSize() * 2); + leafMsg = leafMsg.subspan(buff.size() * 2); + otMasker.SetSeed(mBaseOTs[j + treeIdx][0], divCeil(buff.size(), sizeof(block))); + for (u64 i = 0; i < buff.size(); ++i) + buff[i] ^= otMasker.get(); + + ctx.deserialize(buff.begin(), buff.end(), leafOts.begin()); + + auto out0 = (activeChildIdx & ~1ull) * 8 + j + outputOffset; + auto out1 = (activeChildIdx | 1ull) * 8 + j + outputOffset; + + ctx.minus(leafLevel[out0], leafOts[0], leafSums[0][j]); + ctx.minus(leafLevel[out1], leafOts[1], leafSums[1][j]); + } + } + else + { + VecF leafOts; + ctx.resize(leafOts, 1); + PRNG otMasker; + + for (u64 j = 0; j < remTrees; ++j) + { + // The index of the child on the active path. + auto activeChildIdx = points[j + treeIdx]; + + // The index of the other (inactive) child. + auto inactiveChildIdx = activeChildIdx ^ 1; + + // The indicator as to the left or right child is inactive + auto notAi = inactiveChildIdx & 1; + + // offset to the first or second ot message, based on the one we want + auto offset = ctx.template byteSize() * notAi; + + // decrypt the ot string + span buff = leafMsg.subspan(offset, ctx.template byteSize()); + leafMsg = leafMsg.subspan(buff.size() * 2); + otMasker.SetSeed(mBaseOTs[j + treeIdx][0], divCeil(buff.size(), sizeof(block))); + for (u64 i = 0; i < buff.size(); ++i) + buff[i] ^= otMasker.get(); + + ctx.deserialize(buff.begin(), buff.end(), leafOts.begin()); + + std::array out{ + (activeChildIdx & ~1ull) * 8 + j + outputOffset, + (activeChildIdx | 1ull) * 8 + j + outputOffset + }; + + auto keep = leafLevel.begin() + out[notAi]; + auto zero = leafLevel.begin() + out[notAi ^ 1]; + + ctx.minus(*keep, leafOts[0], leafSums[notAi][j]); + ctx.zero(zero, zero + 1); + } + } + } + }; +} + +#endif \ No newline at end of file diff --git a/libOTe/Tools/QuasiCyclicCode.h b/libOTe/Tools/QuasiCyclicCode.h index d1429fc9..bcf95183 100644 --- a/libOTe/Tools/QuasiCyclicCode.h +++ b/libOTe/Tools/QuasiCyclicCode.h @@ -1,5 +1,5 @@ #pragma once -// © 2022 Visaß. +// © 2022 Visa. // Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. @@ -25,384 +25,333 @@ namespace osuCrypto { - // https://eprint.iacr.org/2019/1159.pdf - struct QuasiCyclicCode : public TimerAdapter - { - private: - u64 mScaler = 0; - u64 mNumThreads = 1; + // https://eprint.iacr.org/2019/1159.pdf + struct QuasiCyclicCode : public TimerAdapter + { + private: + + // the length of the encoding + u64 mMessageSize = 0; + + + // the length of the input. mCodeSize = mMessageSize * mScaler; + u64 mCodeSize = 0; + + // the next prime starting at mMessageSize. The + // real code size will in fact be of size mScaler * mPrimeMosulus + u64 mPrimeModulus = 0; + public: - // the length of the encoding - u64 mP = 0; + //size of the input + u64 size() { return mCodeSize; } + + // initialize the compressing matrix that maps a + // vector of size n * scaler to a vector of size n. + void init(u64 n, u64 scaler = 2) + { + if (scaler <= 1) + throw RTE_LOC; + + mMessageSize = n; + mPrimeModulus = nextPrime(n); + mCodeSize = mMessageSize * scaler; + } + void init2(u64 messageSize, u64 codeSize) + { + auto scaler = divCeil(codeSize, messageSize); + + mMessageSize = messageSize; + mPrimeModulus = nextPrime(messageSize); + mCodeSize = codeSize; + } + + static void bitShiftXor(span dest, span in, u8 bitShift) + { + if (bitShift > 127) + throw RTE_LOC; + if (u64(in.data()) % 16) + throw RTE_LOC; + + if (bitShift >= 64) + { + bitShift -= 64; + const int bitShift2 = 64 - bitShift; + u8* inPtr = ((u8*)in.data()) + sizeof(u64); + + auto end = std::min(dest.size(), in.size() - 1); + for (u64 i = 0; i < end; ++i, inPtr += sizeof(block)) + { + block + b0 = toBlock(inPtr), + b1 = toBlock(inPtr + sizeof(u64)); - // the length of the input. mM = mP * mScaler; - u64 mM = 0; + b0 = (b0 >> bitShift); + b1 = (b1 << bitShift2); - public: - - //size of the input - u64 size() { return mM; } - - // initialize the compressing matrix that maps a - // vector of size n * scaler to a vector of size n. - void init(u64 n, u64 scaler = 2) - { - if (scaler <= 1) - throw RTE_LOC; - - if (isPrime(n) == false) - throw RTE_LOC; - - mP = n;// nextPrime(n); - //mN = n; - mM = mP * scaler; - mScaler = scaler; - } + dest[i] = dest[i] ^ b0 ^ b1; + } - static void bitShiftXor(span dest, span in, u8 bitShift) - { - if (bitShift > 127) - throw RTE_LOC; - if (u64(in.data()) % 16) - throw RTE_LOC; + if (end != static_cast(dest.size())) + { + u64 b0 = *(u64*)inPtr; + b0 = (b0 >> bitShift); - if (bitShift >= 64) - { - bitShift -= 64; - const int bitShift2 = 64 - bitShift; - u8* inPtr = ((u8*)in.data()) + sizeof(u64); - - auto end = std::min(dest.size(), in.size() - 1); - for (u64 i = 0; i < end; ++i, inPtr += sizeof(block)) - { - block - b0 = toBlock(inPtr), - b1 = toBlock(inPtr + sizeof(u64)); + *(u64*)(&dest[end]) ^= b0; + } + } + else if (bitShift) + { + const int bitShift2 = 64 - bitShift; + u8* inPtr = (u8*)in.data(); + + auto end = std::min(dest.size(), in.size() - 1); + for (u64 i = 0; i < end; ++i, inPtr += sizeof(block)) + { + block + b0 = toBlock(inPtr), + b1 = toBlock(inPtr + sizeof(u64)); + + b0 = (b0 >> bitShift); + b1 = (b1 << bitShift2); + + //bv0.append((u8*)&b0, 128); + //bv1.append((u8*)&b1, 128); + + dest[i] = dest[i] ^ b0 ^ b1; + } + + if (end != static_cast(dest.size())) + { + block b0 = toBlock(inPtr); + b0 = (b0 >> bitShift); + + //bv0.append((u8*)&b0, 128); + + dest[end] = dest[end] ^ b0; + + u64 b1 = *(u64*)(inPtr + sizeof(u64)); + b1 = (b1 << bitShift2); + + //bv1.append((u8*)&b1, 64); + + *(u64*)&dest[end] ^= b1; + } + + + + //std::cout << " b0 " << bv0 << std::endl; + //std::cout << " b1 " << bv1 << std::endl; + } + else + { + auto end = std::min(dest.size(), in.size()); + for (u64 i = 0; i < end; ++i) + { + dest[i] = dest[i] ^ in[i]; + } + } + } - b0 = (b0 >> bitShift); - b1 = (b1 << bitShift2); + static void modp(span dest, span in, u64 p) + { + auto pBlocks = (p + 127) / 128; + auto pBytes = (p + 7) / 8; - dest[i] = dest[i] ^ b0 ^ b1; - } + if (static_cast(dest.size()) < pBlocks) + throw RTE_LOC; - if (end != static_cast(dest.size())) - { - u64 b0 = *(u64*)inPtr; - b0 = (b0 >> bitShift); + if (static_cast(in.size()) < pBlocks) + throw RTE_LOC; - *(u64*)(&dest[end]) ^= b0; - } - } - else if (bitShift) - { - const int bitShift2 = 64 - bitShift; - u8* inPtr = (u8*)in.data(); + auto count = (in.size() * 128 + p - 1) / p; - auto end = std::min(dest.size(), in.size() - 1); - for (u64 i = 0; i < end; ++i, inPtr += sizeof(block)) - { - block - b0 = toBlock(inPtr), - b1 = toBlock(inPtr + sizeof(u64)); - - b0 = (b0 >> bitShift); - b1 = (b1 << bitShift2); - - //bv0.append((u8*)&b0, 128); - //bv1.append((u8*)&b1, 128); - - dest[i] = dest[i] ^ b0 ^ b1; - } - - if (end != static_cast(dest.size())) - { - block b0 = toBlock(inPtr); - b0 = (b0 >> bitShift); - - //bv0.append((u8*)&b0, 128); - - dest[end] = dest[end] ^ b0; - - u64 b1 = *(u64*)(inPtr + sizeof(u64)); - b1 = (b1 << bitShift2); - - //bv1.append((u8*)&b1, 64); - - *(u64*)&dest[end] ^= b1; - } - - - - //std::cout << " b0 " << bv0 << std::endl; - //std::cout << " b1 " << bv1 << std::endl; - } - else - { - auto end = std::min(dest.size(), in.size()); - for (u64 i = 0; i < end; ++i) - { - dest[i] = dest[i] ^ in[i]; - } - } - } - - static void modp(span dest, span in, u64 p) - { - auto pBlocks = (p + 127) / 128; - auto pBytes = (p + 7) / 8; - - if (static_cast(dest.size()) < pBlocks) - throw RTE_LOC; - - if (static_cast(in.size()) < pBlocks) - throw RTE_LOC; - - auto count = (in.size() * 128 + p - 1) / p; + memcpy(dest.data(), in.data(), pBytes); - memcpy(dest.data(), in.data(), pBytes); - - for (u64 i = 1; i < count; ++i) - { - auto begin = i * p; - auto end = std::min(i * p + p, in.size() * 128); + for (u64 i = 1; i < count; ++i) + { + auto begin = i * p; + auto end = std::min(i * p + p, in.size() * 128); - auto shift = begin & 127; - auto beginBlock = in.data() + (begin / 128); - auto endBlock = in.data() + ((end + 127) / 128); + auto shift = begin & 127; + auto beginBlock = in.data() + (begin / 128); + auto endBlock = in.data() + ((end + 127) / 128); - if (endBlock > in.data() + in.size()) - throw RTE_LOC; + if (endBlock > in.data() + in.size()) + throw RTE_LOC; - auto in_i = span(beginBlock, endBlock); + auto in_i = span(beginBlock, endBlock); - bitShiftXor(dest, in_i, static_cast(shift)); - } + bitShiftXor(dest, in_i, static_cast(shift)); + } - auto offset = (p & 7); - if (offset) - { - u8 mask = (1 << offset) - 1; - auto idx = p / 8; - ((u8*)dest.data())[idx] &= mask; - } + auto offset = (p & 7); + if (offset) + { + u8 mask = (1 << offset) - 1; + auto idx = p / 8; + ((u8*)dest.data())[idx] &= mask; + } - auto rem = dest.size() * 16 - pBytes; - if (rem) - memset(((u8*)dest.data()) + pBytes, 0, rem); - } + auto rem = dest.size() * 16 - pBytes; + if (rem) + memset(((u8*)dest.data()) + pBytes, 0, rem); + } - void dualEncode(span X) - { - std::vector XX(X.size()); - for (auto i : rng(X.size())) - { - if (X[i] > 1) - throw RTE_LOC; + void dualEncode(span X) + { + std::vector XX(X.size()); + for (auto i : rng(X.size())) + { + if (X[i] > 1) + throw RTE_LOC; - XX[i] = block(X[i], X[i]); - } - dualEncode(XX); - for (auto i : rng(X.size())) - { - X[i] = XX[i] == ZeroBlock ? 0 : 1; - } - } + XX[i] = block(X[i], X[i]); + } + dualEncode(XX); + for (auto i : rng(X.size())) + { + X[i] = XX[i] == ZeroBlock ? 0 : 1; + } + } - inline void transpose(span s, MatrixView r) - { - MatrixView ss((u8*)s.data(), s.size(), sizeof(block)); - MatrixView rr((u8*)r.data(), r.rows(), r.cols() * sizeof(block)); - ::oc::transpose(ss, rr); - } + inline void transpose(span s, MatrixView r) + { + MatrixView ss((u8*)s.data(), s.size(), sizeof(block)); + auto colLen = r.cols() * sizeof(block); + if (colLen < r.size() / 8) + throw RTE_LOC; - void dualEncode(span X) - { - if(X.size() != mM) - throw RTE_LOC; - const u64 rows(128); + MatrixView rr((u8*)r.data(), r.rows(), colLen); + ::oc::transpose(ss, rr); + } - auto nBlocks = (mP + rows-1) / rows; - auto n2Blocks = ((mM-mP) + rows -1) / rows; - Matrix XT(rows, n2Blocks); - transpose(X.subspan(mP), XT); + void dualEncode(span X) + { + if (X.size() != size()) + throw RTE_LOC; + const u64 rows(128); - auto n64 = i64(nBlocks * 2); - - std::vector a(mScaler - 1); + //auto scaler = divCeil(mCodeSize, mPrimeModulus); + auto remSize = X.size() - mMessageSize; + auto scalerMinusOne = divCeil(remSize, mPrimeModulus); + + // the number of blocks required to represent a single poly + auto polyBlockSize = divCeil(mPrimeModulus, 128); + + // the number of blocks required to represent scalerMinusOne poly's + auto multiPolyBlockSize = polyBlockSize * scalerMinusOne; + + Matrix XT(rows, multiPolyBlockSize); + transpose(X.subspan(mMessageSize), XT); + + auto polyU64Size = i64(polyBlockSize * 2); - MatrixcModP1(128, nBlocks, AllocType::Uninitialized); + std::vector a(scalerMinusOne); - //std::unique_ptr brs(new ThreadBarrier[mScaler + 1]); - //for (u64 i = 0; i <= mScaler; ++i) - //brs[i].reset(mNumThreads); + MatrixcModP1(128, polyBlockSize, AllocType::Uninitialized); - //auto routine = [&](u64 index) - { - //u64 j = 0; + FFTPoly bPoly; + FFTPoly cPoly; - //{ - // std::array tpBuffer; - // auto numBlocks = mM / 128; - // auto begin = index * numBlocks / mNumThreads; - // auto end = (index + 1) * numBlocks / mNumThreads; + AlignedUnVector temp128(2 * polyBlockSize); - // for (u64 i = begin; i < end; ++i) - // { - // u64 j = i * tpBuffer.size(); + FFTPoly::DecodeCache cache; + for (u64 s = 0; s < scalerMinusOne; s += 1) + { + auto a64 = spanCast(temp128).subspan(polyU64Size); + PRNG pubPrng(toBlock(s) ^ CCBlock); + pubPrng.get(a64.data(), a64.size()); + a[s].encode(a64); + } + + for (u64 i = 0; i < rows; i += 1) + { + + for (u64 s = 0; s < scalerMinusOne; ++s) + { + auto& aPoly = a[s]; + auto b64 = spanCast(XT[i]).subspan(s * polyU64Size, polyU64Size); + + bPoly.encode(b64); + + if (s == 0) + { + cPoly.mult(aPoly, bPoly); + } + else + { + bPoly.multEq(aPoly); + cPoly.addEq(bPoly); + } + } + + // decode c[i] and store it at t64Ptr + cPoly.decode(spanCast(temp128), cache, true); + + // reduce s[i] mod (x^p - 1) and store it at cModP1[i] + modp(cModP1[i], temp128, mPrimeModulus); + } + + AlignedArray tpBuffer; + auto numBlocks = divCeil(mMessageSize, 128); + auto begin = 0 * numBlocks; + auto end = (1) * numBlocks; + for (u64 i = begin; i < end; ++i) + { + u64 j = i * tpBuffer.size(); + auto min = std::min(tpBuffer.size(), mMessageSize - j); + + for (u64 k = 0; k < tpBuffer.size(); ++k) + tpBuffer[k] = cModP1(k, i); + + transpose128(tpBuffer.data()); + + auto end2 = i * tpBuffer.size() + min; + for (u64 k = 0; j < end2; ++j, ++k) + X[j] = X[j] ^ tpBuffer[k]; + } + + } - // for (u64 k = 0; k < tpBuffer.size(); ++k) - // tpBuffer[k] = X[j + k]; - // transpose128(tpBuffer); + DenseMtx getMatrix() + { + + DenseMtx mtx(mCodeSize, mMessageSize); - // auto end = i * tpBuffer.size() + 128; - // for (u64 k = 0; j < end; ++j, ++k) - // X[j] = tpBuffer[k]; - // } + for (u64 i = 0; i < mCodeSize; ++i) + { + std::vector in(mCodeSize); + in[i] = oc::AllOneBlock; - // if (index == 0) - // setTimePoint("sender.expand.qc.transposeXor"); - //} - - //brs[j++].decrementWait(); - - FFTPoly bPoly; - FFTPoly cPoly; - - AlignedUnVector temp128(2 * nBlocks); - - FFTPoly::DecodeCache cache; - for (u64 s = 1; s < mScaler; s += 1) - { - auto a64 = spanCast(temp128).subspan(n64); - PRNG pubPrng(toBlock(s)); - pubPrng.get(a64.data(), a64.size()); - //memset(a64.data(), 0, a64.size() * sizeof(u64)); - //a64[0] = 1; - - a[s - 1].encode(a64); - } - - - //auto multAddReduce = [this, nBlocks, n64, &a, &bPoly, &cPoly, &temp128, &cache](span b128, span dest) - //{ - - //}; - - for (u64 i = 0; i < rows; i += 1) - { - - for (u64 s = 0; s < mScaler-1; ++s) - { - auto& aPoly = a[s]; - auto b64 = spanCast(XT[i]).subspan(s * n64, n64); - - bPoly.encode(b64); - - if (s == 0) - { - cPoly.mult(aPoly, bPoly); - } - else - { - bPoly.multEq(aPoly); - cPoly.addEq(bPoly); - } - } - - // decode c[i] and store it at t64Ptr - cPoly.decode(spanCast(temp128), cache, true); - - //for (u64 j = 0; j < nBlocks; ++j) - // temp128[j] = temp128[j] ^ XT[i][j]; - - // reduce s[i] mod (x^p - 1) and store it at cModP1[i] - modp(cModP1[i], temp128, mP); - } - //multAddReduce(rT[i], cModP1[i]); - - //if (index == 0) - // setTimePoint("sender.expand.qc.mulAddReduce"); - - //brs[j++].decrementWait(); - - { - - AlignedArray tpBuffer; - auto numBlocks = (mP + 127) / 128; - auto begin = 0 * numBlocks / mNumThreads; - auto end = (1) * numBlocks / mNumThreads; - for (u64 i = begin; i < end; ++i) - { - u64 j = i * tpBuffer.size(); - auto min = std::min(tpBuffer.size(), mP - j); - - for (u64 k = 0; k < tpBuffer.size(); ++k) - tpBuffer[k] = cModP1(k, i); - - transpose128(tpBuffer.data()); - - auto end = i * tpBuffer.size() + min; - for (u64 k = 0; j < end; ++j, ++k) - X[j] = X[j] ^ tpBuffer[k]; - } - - //if (index == 0) - // setTimePoint("sender.expand.qc.transposeXor"); - } - }; - - //std::vector thrds(mNumThreads - 1); - //for (u64 i = 0; i < thrds.size(); ++i) - // thrds[i] = std::thread(routine, i); - - //routine(thrds.size()); - - //for (u64 i = 0; i < thrds.size(); ++i) - // thrds[i].join(); - } - - - DenseMtx getMatrix() - { - - DenseMtx mtx(mM, mP); - - - - for (u64 i = 0; i < mM; ++i) - { - std::vector in(mM); - in[i] = oc::AllOneBlock; - - dualEncode(in); - - u64 w = 0; - for (u64 j = 0; j < mP; ++j) - { - if (in[j] == oc::AllOneBlock) - { - ++w; - mtx(i, j) = 1; - } - else if (in[j] == oc::ZeroBlock) - { - } - else - throw RTE_LOC; - } - - if (std::abs((long long)(mP - w)) < mP / 2 - std::sqrt(mP)) - throw RTE_LOC; - } - - return mtx; - } - }; + dualEncode(in); + + u64 w = 0; + for (u64 j = 0; j < mMessageSize; ++j) + { + if (in[j] == oc::AllOneBlock) + { + ++w; + mtx(i, j) = 1; + } + else if (in[j] == oc::ZeroBlock) + { + } + else + throw RTE_LOC; + } + + if (std::abs((long long)(mPrimeModulus - w)) < mPrimeModulus / 2 - std::sqrt(mPrimeModulus)) + throw RTE_LOC; + } + + return mtx; + } + }; } diff --git a/libOTe/Tools/SilentPprf.cpp b/libOTe/Tools/SilentPprf.cpp deleted file mode 100644 index 28bd2850..00000000 --- a/libOTe/Tools/SilentPprf.cpp +++ /dev/null @@ -1,995 +0,0 @@ -#include "SilentPprf.h" -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - -#include -#include -#include -#include -#include - -namespace osuCrypto -{ - - void SilentMultiPprfSender::setBase(span> baseMessages) - { - if (baseOtCount() != static_cast(baseMessages.size())) - throw RTE_LOC; - - mBaseOTs.resize(mPntCount, mDepth); - for (u64 i = 0; i < static_cast(mBaseOTs.size()); ++i) - mBaseOTs(i) = baseMessages[i]; - } - - void SilentMultiPprfReceiver::setBase(span baseMessages) - { - if (baseOtCount() != static_cast(baseMessages.size())) - throw RTE_LOC; - - // The OTs are used in blocks of 8, so make sure that there is a whole - // number of blocks. - mBaseOTs.resize(roundUpTo(mPntCount, 8), mDepth); - memcpy(mBaseOTs.data(), baseMessages.data(), baseMessages.size() * sizeof(block)); - } - - // This function copies the leaf values of the GGM tree - // to the output location. There are two modes for this - // function. If interleaved == false, then each tree is - // copied to a different contiguous regions of the output. - // If interleaved == true, then trees are interleaved such that .... - // @lvl - the GGM tree leafs. - // @output - the location that the GGM leafs should be written to. - // @numTrees - How many trees there are in total. - // @tIdx - the index of the first tree. - // @oFormat - do we interleave the output? - // @mal - ... - void copyOut( - span> lvl, - MatrixView output, - u64 totalTrees, - u64 tIdx, - PprfOutputFormat oFormat, - std::function> lvl)>& callback) - { - if (oFormat == PprfOutputFormat::ByLeafIndex) - { - - auto curSize = std::min(totalTrees - tIdx, 8); - if (curSize == 8) - { - - for (u64 i = 0; i < output.rows(); ++i) - { - auto oi = output[i].subspan(tIdx, 8); - auto& ii = lvl[i]; - oi[0] = ii[0]; - oi[1] = ii[1]; - oi[2] = ii[2]; - oi[3] = ii[3]; - oi[4] = ii[4]; - oi[5] = ii[5]; - oi[6] = ii[6]; - oi[7] = ii[7]; - } - } - else - { - for (u64 i = 0; i < output.rows(); ++i) - { - auto oi = output[i].subspan(tIdx, curSize); - auto& ii = lvl[i]; - for (u64 j = 0; j < curSize; ++j) - oi[j] = ii[j]; - } - } - - } - else if (oFormat == PprfOutputFormat::ByTreeIndex) - { - - auto curSize = std::min(totalTrees - tIdx, 8); - if (curSize == 8) - { - for (u64 i = 0; i < output.cols(); ++i) - { - auto& ii = lvl[i]; - output(tIdx + 0, i) = ii[0]; - output(tIdx + 1, i) = ii[1]; - output(tIdx + 2, i) = ii[2]; - output(tIdx + 3, i) = ii[3]; - output(tIdx + 4, i) = ii[4]; - output(tIdx + 5, i) = ii[5]; - output(tIdx + 6, i) = ii[6]; - output(tIdx + 7, i) = ii[7]; - } - } - else - { - for (u64 i = 0; i < output.cols(); ++i) - { - auto& ii = lvl[i]; - for (u64 j = 0; j < curSize; ++j) - output(tIdx + j, i) = ii[j]; - } - } - - } - else if (oFormat == PprfOutputFormat::Callback) - callback(tIdx, lvl); - else - throw RTE_LOC; - } - - u64 interleavedPoint(u64 point, u64 treeIdx, u64 totalTrees, u64 domain, PprfOutputFormat format) - { - switch (format) - { - case osuCrypto::PprfOutputFormat::Interleaved: - case osuCrypto::PprfOutputFormat::Callback: - { - - if (domain <= point) - return ~u64(0); - - auto subTree = treeIdx % 8; - auto forest = treeIdx / 8; - - return (forest * domain + point) * 8 + subTree; - } - break; - default: - throw RTE_LOC; - break; - } - //auto totalTrees = points.size(); - - } - - void interleavedPoints(span points, u64 domain, PprfOutputFormat format) - { - - for (u64 i = 0; i < points.size(); ++i) - { - points[i] = interleavedPoint(points[i], i, points.size(), domain, format); - } - } - - u64 getActivePath(const span& choiceBits) - { - u64 point = 0; - for (u64 i = 0; i < choiceBits.size(); ++i) - { - auto shift = choiceBits.size() - i - 1; - - point |= u64(1 ^ choiceBits[i]) << shift; - } - return point; - } - - void SilentMultiPprfReceiver::getPoints(span points, PprfOutputFormat format) - { - - switch (format) - { - case PprfOutputFormat::ByLeafIndex: - case PprfOutputFormat::ByTreeIndex: - - memset(points.data(), 0, points.size() * sizeof(u64)); - for (u64 j = 0; j < mPntCount; ++j) - { - points[j] = getActivePath(mBaseChoices[j]); - } - - break; - case PprfOutputFormat::Interleaved: - case PprfOutputFormat::Callback: - - if ((u64)points.size() != mPntCount) - throw RTE_LOC; - if (points.size() % 8) - throw RTE_LOC; - - getPoints(points, PprfOutputFormat::ByLeafIndex); - interleavedPoints(points, mDomain, format); - - break; - default: - throw RTE_LOC; - break; - } - } - - - BitVector SilentMultiPprfReceiver::sampleChoiceBits(u64 modulus, PprfOutputFormat format, PRNG& prng) - { - BitVector choices(mPntCount * mDepth); - - // The points are read in blocks of 8, so make sure that there is a - // whole number of blocks. - mBaseChoices.resize(roundUpTo(mPntCount, 8), mDepth); - for (u64 i = 0; i < mPntCount; ++i) - { - u64 idx; - switch (format) - { - case osuCrypto::PprfOutputFormat::ByLeafIndex: - case osuCrypto::PprfOutputFormat::ByTreeIndex: - do { - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = prng.getBit(); - idx = getActivePath(mBaseChoices[i]); - } while (idx >= modulus); - - break; - case osuCrypto::PprfOutputFormat::Interleaved: - case osuCrypto::PprfOutputFormat::Callback: - - if (modulus > mPntCount * mDomain) - throw std::runtime_error("modulus too big. " LOCATION); - if (modulus < mPntCount * mDomain / 2) - throw std::runtime_error("modulus too small. " LOCATION); - - // make sure that at least the first element of this tree - // is within the modulus. - idx = interleavedPoint(0, i, mPntCount, mDomain, format); - if (idx >= modulus) - throw RTE_LOC; - - - do { - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = prng.getBit(); - idx = getActivePath(mBaseChoices[i]); - - idx = interleavedPoint(idx, i, mPntCount, mDomain, format); - } while (idx >= modulus); - - - break; - default: - throw RTE_LOC; - break; - } - - } - - for (u64 i = 0; i < mBaseChoices.size(); ++i) - { - choices[i] = mBaseChoices(i); - } - - return choices; - } - - void SilentMultiPprfReceiver::setChoiceBits(PprfOutputFormat format, BitVector choices) - { - // Make sure we're given the right number of OTs. - if (choices.size() != baseOtCount()) - throw RTE_LOC; - - mBaseChoices.resize(roundUpTo(mPntCount, 8), mDepth); - for (u64 i = 0; i < mPntCount; ++i) - { - for (u64 j = 0; j < mDepth; ++j) - mBaseChoices(i, j) = choices[mDepth * i + j]; - - switch (format) - { - case osuCrypto::PprfOutputFormat::ByLeafIndex: - case osuCrypto::PprfOutputFormat::ByTreeIndex: - if (getActivePath(mBaseChoices[i]) >= mDomain) - throw RTE_LOC; - - break; - case osuCrypto::PprfOutputFormat::Interleaved: - case osuCrypto::PprfOutputFormat::Callback: - { - auto idx = getActivePath(mBaseChoices[i]); - auto idx2 = interleavedPoint(idx, i, mPntCount, mDomain, format); - if(idx2 > mPntCount * mDomain) - throw std::runtime_error("the base ot choice bits index outside of the domain. see sampleChoiceBits(...). " LOCATION); - break; - } - default: - throw RTE_LOC; - break; - } - } - } - - namespace - { - // A public PRF/PRG that we will use for deriving the GGM tree. - const std::array gAes = []() { - std::array aes; - aes[0].setKey(toBlock(3242342)); - aes[1].setKey(toBlock(8993849)); - return aes; - }(); - } - - - void SilentMultiPprfSender::expandOne( - block aesSeed, - u64 treeIdx, - bool programActivePath, - span>> levels, - span, 2>> encSums, - span> lastOts) - { - // The number of real trees for this iteration. - auto min = std::min(8, mPntCount - treeIdx); - - // the first level should be size 1, the root of the tree. - // we will populate it with random seeds using aesSeed in counter mode - // based on the tree index. - assert(levels[0].size() == 1); - mAesFixedKey.ecbEncCounterMode(aesSeed ^ block(treeIdx), levels[0][0]); - - assert(encSums.size() == mDepth - programActivePath); - assert(encSums.size() < 24); - - // space for our sums of each level. Should always be less then - // 24 levels... If not increase the limit or make it a vector. - std::array, 2>, 24> sums; - memset(&sums, 0, sizeof(sums)); - - // For each level perform the following. - for (u64 d = 0; d < mDepth; ++d) - { - // The previous level of the GGM tree. - auto level0 = levels[d]; - - // The next level of theGGM tree that we are populating. - auto level1 = levels[d + 1]; - - // The total number of parents in this level. - auto width = divCeil(mDomain, 1ull << (mDepth - d)); - - // use the optimized approach for intern nodes of the tree - if (d + 1 < mDepth && 0) - { - // For each child, populate the child by expanding the parent. - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx, childIdx += 2) - { - // The value of the parent. - auto& parent = level0.data()[parentIdx]; - - auto& child0 = level1.data()[childIdx]; - auto& child1 = level1.data()[childIdx + 1]; - mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); - - child0[0] = child1[0] ^ parent[0]; - child0[1] = child1[1] ^ parent[1]; - child0[2] = child1[2] ^ parent[2]; - child0[3] = child1[3] ^ parent[3]; - child0[4] = child1[4] ^ parent[4]; - child0[5] = child1[5] ^ parent[5]; - child0[6] = child1[6] ^ parent[6]; - child0[7] = child1[7] ^ parent[7]; - - // Update the running sums for this level. We keep - // a left and right totals for each level. - auto& sum = sums[d]; - sum[0][0] = sum[0][0] ^ child0[0]; - sum[0][1] = sum[0][1] ^ child0[1]; - sum[0][2] = sum[0][2] ^ child0[2]; - sum[0][3] = sum[0][3] ^ child0[3]; - sum[0][4] = sum[0][4] ^ child0[4]; - sum[0][5] = sum[0][5] ^ child0[5]; - sum[0][6] = sum[0][6] ^ child0[6]; - sum[0][7] = sum[0][7] ^ child0[7]; - - child1[0] = child1[0] + parent[0]; - child1[1] = child1[1] + parent[1]; - child1[2] = child1[2] + parent[2]; - child1[3] = child1[3] + parent[3]; - child1[4] = child1[4] + parent[4]; - child1[5] = child1[5] + parent[5]; - child1[6] = child1[6] + parent[6]; - child1[7] = child1[7] + parent[7]; - - sum[1][0] = sum[1][0] ^ child1[0]; - sum[1][1] = sum[1][1] ^ child1[1]; - sum[1][2] = sum[1][2] ^ child1[2]; - sum[1][3] = sum[1][3] ^ child1[3]; - sum[1][4] = sum[1][4] ^ child1[4]; - sum[1][5] = sum[1][5] ^ child1[5]; - sum[1][6] = sum[1][6] ^ child1[6]; - sum[1][7] = sum[1][7] ^ child1[7]; - - } - } - else - { - // for the leaf nodes we need to hash both children. - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) - { - // The value of the parent. - auto& parent = level0.data()[parentIdx]; - - // The bit that indicates if we are on the left child (0) - // or on the right child (1). - for (u64 keep = 0; keep < 2; ++keep, ++childIdx) - { - // The child that we will write in this iteration. - auto& child = level1[childIdx]; - - // The sum that this child node belongs to. - auto& sum = sums[d][keep]; - - // Each parent is expanded into the left and right children - // using a different AES fixed-key. Therefore our OWF is: - // - // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); - // - // where each half defines one of the children. - gAes[keep].hashBlocks<8>(parent.data(), child.data()); - - // Update the running sums for this level. We keep - // a left and right totals for each level. - sum[0] = sum[0] ^ child[0]; - sum[1] = sum[1] ^ child[1]; - sum[2] = sum[2] ^ child[2]; - sum[3] = sum[3] ^ child[3]; - sum[4] = sum[4] ^ child[4]; - sum[5] = sum[5] ^ child[5]; - sum[6] = sum[6] ^ child[6]; - sum[7] = sum[7] ^ child[7]; - } - } - } - } - - // For all but the last level, mask the sums with the - // OT strings and send them over. - for (u64 d = 0; d < mDepth - programActivePath; ++d) - { - for (u64 j = 0; j < min; ++j) - { - encSums[d][0][j] = sums[d][0][j] ^ mBaseOTs[treeIdx + j][d][0]; - encSums[d][1][j] = sums[d][1][j] ^ mBaseOTs[treeIdx + j][d][1]; - } - } - - if (programActivePath) - { - // For the last level, we are going to do something special. - // The other party is currently missing both leaf children of - // the active parent. Since this is the last level, we want - // the inactive child to just be the normal value but the - // active child should be the correct value XOR the delta. - // This will be done by sending the sums and the sums plus - // delta and ensure that they can only decrypt the correct ones. - auto d = mDepth - 1; - assert(lastOts.size() == min); - for (u64 j = 0; j < min; ++j) - { - // Construct the sums where we will allow the delta (mValue) - // to either be on the left child or right child depending - // on which has the active path. - lastOts[j][0] = sums[d][0][j]; - lastOts[j][1] = sums[d][1][j] ^ mValue[treeIdx + j]; - lastOts[j][2] = sums[d][1][j]; - lastOts[j][3] = sums[d][0][j] ^ mValue[treeIdx + j]; - - // We are going to expand the 128 bit OT string - // into a 256 bit OT string using AES. - std::array masks, maskIn; - maskIn[0] = mBaseOTs[treeIdx + j][d][0]; - maskIn[1] = mBaseOTs[treeIdx + j][d][0] ^ AllOneBlock; - maskIn[2] = mBaseOTs[treeIdx + j][d][1]; - maskIn[3] = mBaseOTs[treeIdx + j][d][1] ^ AllOneBlock; - mAesFixedKey.hashBlocks<4>(maskIn.data(), masks.data()); - - // Add the OT masks to the sums and send them over. - lastOts[j][0] = lastOts[j][0] ^ masks[0]; - lastOts[j][1] = lastOts[j][1] ^ masks[1]; - lastOts[j][2] = lastOts[j][2] ^ masks[2]; - lastOts[j][3] = lastOts[j][3] ^ masks[3]; - } - } - } - - void allocateExpandBuffer( - u64 depth, - u64 activeChildXorDelta, - std::vector& buff, - span< std::array, 2>>& sums, - span< std::array>& last) - { - - using SumType = std::array, 2>; - using LastType = std::array; - u64 numSums = depth - activeChildXorDelta; - u64 numLast = activeChildXorDelta * 8; - u64 numBlocks = numSums * 16 + numLast * 4; - buff.resize(numBlocks); - sums = span((SumType*)buff.data(), numSums); - last = span((LastType*)(sums.data() + sums.size()), numLast); - - void* sEnd = sums.data() + sums.size(); - void* lEnd = last.data() + last.size(); - void* end = buff.data() + buff.size(); - if (sEnd > end || lEnd > end) - throw RTE_LOC; - } - - void allocateExpandTree( - u64 dpeth, - TreeAllocator& alloc, - span>& tree, - std::vector>>& levels) - { - tree = alloc.get(); - assert((u64)tree.data() % 32 == 0); - levels[0] = tree.subspan(0, 1); - auto rem = tree.subspan(2); - for (auto i : rng(1ull, dpeth)) - { - levels[i] = rem.subspan(0, levels[i - 1].size() * 2); - assert((u64)levels[i].data() % 32 == 0); - rem = rem.subspan(levels[i].size()); - } - } - - void validateExpandFormat( - PprfOutputFormat oFormat, - MatrixView output, - u64 domain, - u64 pntCount - ) - { - - if (oFormat == PprfOutputFormat::ByLeafIndex) - { - if (output.rows() != domain) - throw RTE_LOC; - - if (output.cols() != pntCount) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::ByTreeIndex) - { - if (output.cols() != domain) - throw RTE_LOC; - - if (output.rows() != pntCount) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::Interleaved) - { - if (output.cols() != 1) - throw RTE_LOC; - if (domain & 1) - throw RTE_LOC; - - auto rows = output.rows(); - if (rows > (domain * pntCount) || - rows / 128 != (domain * pntCount) / 128) - throw RTE_LOC; - if (pntCount & 7) - throw RTE_LOC; - } - else if (oFormat == PprfOutputFormat::Callback) - { - if (domain & 1) - throw RTE_LOC; - if (pntCount & 7) - throw RTE_LOC; - } - else - { - throw RTE_LOC; - } - } - - - task<> SilentMultiPprfSender::expand( - Socket& chl, - span value, - block seed, - MatrixView output, - PprfOutputFormat oFormat, - bool activeChildXorDelta, - u64 numThreads) - { - if (activeChildXorDelta) - setValue(value); - - setTimePoint("SilentMultiPprfSender.start"); - - validateExpandFormat(oFormat, output, mDomain, mPntCount); - - MC_BEGIN(task<>, this, numThreads, oFormat, output, seed, &chl, activeChildXorDelta, - i = u64{}, - mTreeAllocDepth = u64{}, - tree = span>{}, - levels = std::vector>>{}, - buff = std::vector{}, - sums = span, 2>>{}, - last = span>{} - ); - - //if (oFormat == PprfOutputFormat::Callback && numThreads > 1) - // throw RTE_LOC; - - mTreeAllocDepth = mDepth + (oFormat != PprfOutputFormat::Interleaved); - mTreeAlloc.reserve(numThreads, (1ull << mTreeAllocDepth) + 2); - setTimePoint("SilentMultiPprfSender.reserve"); - - levels.resize(mDepth + 1); - allocateExpandTree(mTreeAllocDepth, mTreeAlloc, tree, levels); - - for (i = 0; i < mPntCount; i += 8) - { - // for interleaved format, the last level of the tree - // is simply the output. - if (oFormat == PprfOutputFormat::Interleaved) - { - auto b = (AlignedArray*)output.data(); - auto forest = i / 8; - b += forest * mDomain; - - levels.back() = span>(b, mDomain); - } - - // allocate the send buffer and partition it. - allocateExpandBuffer(mDepth, activeChildXorDelta, buff, sums, last); - - // exapnd the tree - expandOne(seed, i, activeChildXorDelta, levels, sums, last); - - MC_AWAIT(chl.send(std::move(buff))); - - // if we aren't interleaved, we need to copy the - // last layer to the output. - if (oFormat != PprfOutputFormat::Interleaved) - copyOut(levels.back(), output, mPntCount, i, oFormat, mOutputFn); - } - - mBaseOTs = {}; - mTreeAlloc.del(tree); - mTreeAlloc.clear(); - - setTimePoint("SilentMultiPprfSender.de-alloc"); - - MC_END(); - } - - void SilentMultiPprfSender::setValue(span value) - { - mValue.resize(mPntCount); - - if (value.size() == 1) - { - std::fill(mValue.begin(), mValue.end(), value[0]); - } - else - { - if ((u64)value.size() != mPntCount) - throw RTE_LOC; - - std::copy(value.begin(), value.end(), mValue.begin()); - } - } - - void SilentMultiPprfSender::clear() - { - mBaseOTs.resize(0, 0); - mDomain = 0; - mDepth = 0; - mPntCount = 0; - } - - void SilentMultiPprfReceiver::expandOne( - - u64 treeIdx, - bool programActivePath, - span>> levels, - span, 2>> theirSums, - span> lastOts) - { - // This thread will process 8 trees at a time. - - // special case for the first level. - auto l1 = levels[1]; - for (u64 i = 0; i < 8; ++i) - { - - // For the non-active path, set the child of the root node - // as the OT message XOR'ed with the correction sum. - int notAi = mBaseChoices[i + treeIdx][0]; - l1[notAi][i] = mBaseOTs[i + treeIdx][0] ^ theirSums[0][notAi][i]; - l1[notAi ^ 1][i] = ZeroBlock; - } - - // space for our sums of each level. - std::array, 2> mySums; - - // For all other levels, expand the GGM tree and add in - // the correction along the active path. - for (u64 d = 1; d < mDepth; ++d) - { - // The already constructed level. Only missing the - // GGM tree node value along the active path. - auto level0 = levels[d]; - - // The next level that we want to construct. - auto level1 = levels[d + 1]; - - // Zero out the previous sums. - memset(mySums.data(), 0, sizeof(mySums)); - - // We will iterate over each node on this level and - // expand it into it's two children. Note that the - // active node will also be expanded. Later we will just - // overwrite whatever the value was. This is an optimization. - auto width = divCeil(mDomain, 1ull << (mDepth - d)); - - // for internal nodes we the optimized approach. - if (d + 1 < mDepth && 0) - { - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) - { - // The value of the parent. - auto parent = level0[parentIdx]; - - auto& child0 = level1.data()[childIdx]; - auto& child1 = level1.data()[childIdx + 1]; - mAesFixedKey.ecbEncBlocks<8>(parent.data(), child1.data()); - - child0[0] = child1[0] ^ parent[0]; - child0[1] = child1[1] ^ parent[1]; - child0[2] = child1[2] ^ parent[2]; - child0[3] = child1[3] ^ parent[3]; - child0[4] = child1[4] ^ parent[4]; - child0[5] = child1[5] ^ parent[5]; - child0[6] = child1[6] ^ parent[6]; - child0[7] = child1[7] ^ parent[7]; - - // Update the running sums for this level. We keep - // a left and right totals for each level. Note that - // we are actually XOR in the incorrect value of the - // children of the active parent (assuming !DEBUG_PRINT_PPRF). - // This is ok since we will later XOR off these incorrect values. - mySums[0][0] = mySums[0][0] ^ child0[0]; - mySums[0][1] = mySums[0][1] ^ child0[1]; - mySums[0][2] = mySums[0][2] ^ child0[2]; - mySums[0][3] = mySums[0][3] ^ child0[3]; - mySums[0][4] = mySums[0][4] ^ child0[4]; - mySums[0][5] = mySums[0][5] ^ child0[5]; - mySums[0][6] = mySums[0][6] ^ child0[6]; - mySums[0][7] = mySums[0][7] ^ child0[7]; - - child1[0] = child1[0] + parent[0]; - child1[1] = child1[1] + parent[1]; - child1[2] = child1[2] + parent[2]; - child1[3] = child1[3] + parent[3]; - child1[4] = child1[4] + parent[4]; - child1[5] = child1[5] + parent[5]; - child1[6] = child1[6] + parent[6]; - child1[7] = child1[7] + parent[7]; - - mySums[1][0] = mySums[1][0] ^ child1[0]; - mySums[1][1] = mySums[1][1] ^ child1[1]; - mySums[1][2] = mySums[1][2] ^ child1[2]; - mySums[1][3] = mySums[1][3] ^ child1[3]; - mySums[1][4] = mySums[1][4] ^ child1[4]; - mySums[1][5] = mySums[1][5] ^ child1[5]; - mySums[1][6] = mySums[1][6] ^ child1[6]; - mySums[1][7] = mySums[1][7] ^ child1[7]; - } - } - else - { - // for leaf nodes both children should be hashed. - for (u64 parentIdx = 0, childIdx = 0; parentIdx < width; ++parentIdx) - { - // The value of the parent. - auto parent = level0[parentIdx]; - - for (u64 keep = 0; keep < 2; ++keep, ++childIdx) - { - // The child that we will write in this iteration. - auto& child = level1[childIdx]; - - // Each parent is expanded into the left and right children - // using a different AES fixed-key. Therefore our OWF is: - // - // H(x) = (AES(k0, x) + x) || (AES(k1, x) + x); - // - // where each half defines one of the children. - gAes[keep].hashBlocks<8>(parent.data(), child.data()); - - // Update the running sums for this level. We keep - // a left and right totals for each level. Note that - // we are actually XOR in the incorrect value of the - // children of the active parent (assuming !DEBUG_PRINT_PPRF). - // This is ok since we will later XOR off these incorrect values. - auto& sum = mySums[keep]; - sum[0] = sum[0] ^ child[0]; - sum[1] = sum[1] ^ child[1]; - sum[2] = sum[2] ^ child[2]; - sum[3] = sum[3] ^ child[3]; - sum[4] = sum[4] ^ child[4]; - sum[5] = sum[5] ^ child[5]; - sum[6] = sum[6] ^ child[6]; - sum[7] = sum[7] ^ child[7]; - } - } - } - - // For everything but the last level we have to - // 1) fix our sums so they dont include the incorrect - // values that are the children of the active parent - // 2) Update the non-active child of the active parent. - if (!programActivePath || d != mDepth - 1) - { - for (u64 i = 0; i < 8; ++i) - { - // the index of the leaf node that is active. - auto leafIdx = mPoints[i + treeIdx]; - - // The index of the active child node. - auto activeChildIdx = leafIdx >> (mDepth - 1 - d); - - // The index of the active child node sibling. - auto inactiveChildIdx = activeChildIdx ^ 1; - - // The indicator as to the left or right child is inactive - auto notAi = inactiveChildIdx & 1; - - auto& inactiveChild = level1[inactiveChildIdx][i]; - - // correct the sum value by XORing off the incorrect - auto correctSum = - inactiveChild ^ - theirSums[d][notAi][i]; - - inactiveChild = - correctSum ^ - mySums[notAi][i] ^ - mBaseOTs[i + treeIdx][d]; - - } - } - } - - // last level. - auto level = levels[mDepth]; - if (programActivePath) - { - // Now processes the last level. This one is special - // because we must XOR in the correction value as - // before but we must also fixed the child value for - // the active child. To do this, we will receive 4 - // values. Two for each case (left active or right active). - //timer.setTimePoint("recv.recvLast"); - - auto d = mDepth - 1; - for (u64 j = 0; j < 8; ++j) - { - // The index of the child on the active path. - auto activeChildIdx = mPoints[j + treeIdx]; - - // The index of the other (inactive) child. - auto inactiveChildIdx = activeChildIdx ^ 1; - - // The indicator as to the left or right child is inactive - auto notAi = inactiveChildIdx & 1; - - std::array masks, maskIn; - - // We are going to expand the 128 bit OT string - // into a 256 bit OT string using AES. - maskIn[0] = mBaseOTs[j + treeIdx][d]; - maskIn[1] = mBaseOTs[j + treeIdx][d] ^ AllOneBlock; - mAesFixedKey.hashBlocks<2>(maskIn.data(), masks.data()); - - // now get the chosen message OT strings by XORing - // the expended (random) OT strings with the lastOts values. - auto& ot0 = lastOts[j][2 * notAi + 0]; - auto& ot1 = lastOts[j][2 * notAi + 1]; - ot0 = ot0 ^ masks[0]; - ot1 = ot1 ^ masks[1]; - - auto& inactiveChild = level[inactiveChildIdx][j]; - auto& activeChild = level[activeChildIdx][j]; - - // Fix the sums we computed previously to not include the - // incorrect child values. - auto inactiveSum = mySums[notAi][j] ^ inactiveChild; - auto activeSum = mySums[notAi ^ 1][j] ^ activeChild; - - // Update the inactive and active child to have to correct - // value by XORing their full sum with out partial sum, which - // gives us exactly the value we are missing. - inactiveChild = ot0 ^ inactiveSum; - activeChild = ot1 ^ activeSum; - } - // pprf.setTimePoint("SilentMultiPprfReceiver.last " + std::to_string(treeIdx)); - - //timer.setTimePoint("recv.expandLast"); - } - else - { - for (auto j : rng(std::min(8, mPntCount - treeIdx))) - { - // The index of the child on the active path. - auto activeChildIdx = mPoints[j + treeIdx]; - level[activeChildIdx][j] = ZeroBlock; - } - } - } - - task<> SilentMultiPprfReceiver::expand( - Socket& chl, - MatrixView output, - PprfOutputFormat oFormat, - bool activeChildXorDelta, - u64 _) - { - validateExpandFormat(oFormat, output, mDomain, mPntCount); - - MC_BEGIN(task<>, this, oFormat, output, &chl, activeChildXorDelta, - i = u64{}, - mTreeAllocDepth = u64{}, - tree = span>{}, - levels = std::vector>>{}, - buff = std::vector{}, - sums = span, 2>>{}, - last = span>{} - ); - - setTimePoint("SilentMultiPprfReceiver.start"); - mPoints.resize(roundUpTo(mPntCount, 8)); - getPoints(mPoints, PprfOutputFormat::ByLeafIndex); - - mTreeAllocDepth = mDepth + (oFormat != PprfOutputFormat::Interleaved); - mTreeAlloc.reserve(1, (1ull << mTreeAllocDepth) + 2); - setTimePoint("SilentMultiPprfSender.reserve"); - - levels.resize(mDepth + 1); - allocateExpandTree(mTreeAllocDepth, mTreeAlloc, tree, levels); - - for (i = 0; i < mPntCount; i += 8) - { - // for interleaved format, the last level of the tree - // is simply the output. - if (oFormat == PprfOutputFormat::Interleaved) - { - auto b = (AlignedArray*)output.data(); - auto forest = i / 8; - b += forest * mDomain; - - levels.back() = span>(b, mDomain); - } - - // allocate the send buffer and partition it. - allocateExpandBuffer(mDepth, activeChildXorDelta, buff, sums, last); - - MC_AWAIT(chl.recv(buff)); - - // exapnd the tree - expandOne(i, activeChildXorDelta, levels, sums, last); - - // if we aren't interleaved, we need to copy the - // last layer to the output. - if (oFormat != PprfOutputFormat::Interleaved) - copyOut(levels.back(), output, mPntCount, i, oFormat, mOutputFn); - } - - setTimePoint("SilentMultiPprfReceiver.join"); - - mBaseOTs = {}; - mTreeAlloc.del(tree); - mTreeAlloc.clear(); - - setTimePoint("SilentMultiPprfReceiver.de-alloc"); - - MC_END(); - } - -} - -#endif diff --git a/libOTe/Tools/SilentPprf.h b/libOTe/Tools/SilentPprf.h deleted file mode 100644 index b5ffbd4d..00000000 --- a/libOTe/Tools/SilentPprf.h +++ /dev/null @@ -1,314 +0,0 @@ -#pragma once -// © 2020 Peter Rindal. -// © 2022 Visa. -// © 2022 Lawrence Roy. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -#include -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - -#include -#include -#include -#include -#include -#include -#include "libOTe/Tools/Coproto.h" -#include -//#define DEBUG_PRINT_PPRF - -namespace osuCrypto -{ - - // the various formats that the output of the - // Pprf can be generated. - enum class PprfOutputFormat - { - // The i'th row holds the i'th leaf for all trees. - // The j'th tree is in the j'th column. - ByLeafIndex, - - // The i'th row holds the i'th tree. - // The j'th leaf is in the j'th column. - ByTreeIndex, - - // The native output mode. The output will be - // a single row with all leaf values. - // Every 8 trees are mixed together where the - // i'th leaf for each of the 8 tree will be next - // to each other. For example, let tij be the j'th - // leaf of the i'th tree. If we have m leaves, then - // - // t00 t10 ... t70 t01 t11 ... t71 ... t0m t1m ... t7m - // t80 t90 ... t_{15,0} t81 t91 ... t_{15,1} ... t8m t9m ... t_{15,m} - // ... - // - // These are all flattened into a single row. - Interleaved, - - // call the user's callback. The leaves will be in - // Interleaved format. - Callback - }; - - enum class OTType - { - Random, Correlated - }; - - enum class ChoiceBitPacking - { - False, True - }; - - enum class SilentSecType - { - SemiHonest, - Malicious, - //MaliciousFS - }; - - struct TreeAllocator - { - TreeAllocator() = default; - TreeAllocator(const TreeAllocator&) = delete; - TreeAllocator(TreeAllocator&&) = delete; - - using ValueType = AlignedArray; - std::list> mTrees; - std::vector> mFreeTrees; - std::mutex mMutex; - u64 mTreeSize = 0, mNumTrees = 0; - - void reserve(u64 num, u64 size) - { - std::lock_guard lock(mMutex); - mTreeSize = size; - mNumTrees += num; - mTrees.clear(); - mFreeTrees.clear(); - mTrees.emplace_back(num * size); - auto iter = mTrees.back().data(); - for (u64 i = 0; i < num; ++i) - { - mFreeTrees.push_back(span(iter, size)); - assert((u64)mFreeTrees.back().data() % 32 == 0); - iter += size; - } - } - - span get() - { - std::lock_guard lock(mMutex); - if (mFreeTrees.size() == 0) - { - assert(mTreeSize); - mTrees.emplace_back(mTreeSize); - mFreeTrees.push_back(span(mTrees.back().data(), mTreeSize)); - assert((u64)mFreeTrees.back().data() % 32 == 0); - ++mNumTrees; - } - - auto ret = mFreeTrees.back(); - mFreeTrees.pop_back(); - return ret; - } - - void del(span uPtr) - { - std::lock_guard lock(mMutex); - mFreeTrees.push_back(uPtr); - } - - void clear() - { - assert(mNumTrees == mFreeTrees.size()); - mTrees = {}; - mFreeTrees = {}; - mTreeSize = 0; - mNumTrees = 0; - } - }; - - - void allocateExpandBuffer( - u64 depth, - u64 activeChildXorDelta, - std::vector& buff, - span< std::array, 2>>& sums, - span< std::array>& last); - - void allocateExpandTree( - u64 dpeth, - TreeAllocator& alloc, - span>& tree, - std::vector>>& levels); - - class SilentMultiPprfSender : public TimerAdapter - { - public: - u64 mDomain = 0, mDepth = 0, mPntCount = 0; - std::vector mValue; - bool mPrint = false; - TreeAllocator mTreeAlloc; - Matrix> mBaseOTs; - - std::function>)> mOutputFn; - - SilentMultiPprfSender() = default; - SilentMultiPprfSender(const SilentMultiPprfSender&) = delete; - SilentMultiPprfSender(SilentMultiPprfSender&&) = delete; - - SilentMultiPprfSender(u64 domainSize, u64 pointCount) - { - configure(domainSize, pointCount); - } - - void configure(u64 domainSize, u64 pointCount) - { - mDomain = domainSize; - mDepth = log2ceil(mDomain); - mPntCount = pointCount; - //mPntCount8 = roundUpTo(pointCount, 8); - - mBaseOTs.resize(0, 0); - } - - - // the number of base OTs that should be set. - u64 baseOtCount() const - { - return mDepth * mPntCount; - } - - // returns true if the base OTs are currently set. - bool hasBaseOts() const - { - return mBaseOTs.size(); - } - - - void setBase(span> baseMessages); - - task<> expand(Socket& chls, span value, block seed, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) - { - MatrixView o(output.data(), output.size(), 1); - return expand(chls, value, seed, o, oFormat, activeChildXorDelta, numThreads); - } - - task<> expand( - Socket& chl, - span value, - block seed, - MatrixView output, - PprfOutputFormat oFormat, - bool activeChildXorDelta, - u64 numThreads); - - void setValue(span value); - - void clear(); - - void expandOne( - block aesSeed, - u64 treeIdx, - bool activeChildXorDelta, - span>> levels, - span, 2>> sums, - span> lastOts - ); - }; - - - class SilentMultiPprfReceiver : public TimerAdapter - { - public: - u64 mDomain = 0, mDepth = 0, mPntCount = 0; - - std::vector mPoints; - - Matrix mBaseOTs; - Matrix mBaseChoices; - bool mPrint = false; - TreeAllocator mTreeAlloc; - block mDebugValue; - std::function>)> mOutputFn; - - SilentMultiPprfReceiver() = default; - SilentMultiPprfReceiver(const SilentMultiPprfReceiver&) = delete; - SilentMultiPprfReceiver(SilentMultiPprfReceiver&&) = delete; - - void configure(u64 domainSize, u64 pointCount) - { - mDomain = domainSize; - mDepth = log2ceil(mDomain); - mPntCount = pointCount; - - mBaseOTs.resize(0, 0); - } - - // For output format ByLeafIndex or ByTreeIndex, the choice bits it - // samples are in blocks of mDepth, with mPntCount blocks total (one for - // each punctured point). For ByLeafIndex these blocks encode the punctured - // leaf index in big endian, while for ByTreeIndex they are in - // little endian. - BitVector sampleChoiceBits(u64 modulus, PprfOutputFormat format, PRNG& prng); - - // choices is in the same format as the output from sampleChoiceBits. - void setChoiceBits(PprfOutputFormat format, BitVector choices); - - // the number of base OTs that should be set. - u64 baseOtCount() const - { - return mDepth * mPntCount; - } - - // returns true if the base OTs are currently set. - bool hasBaseOts() const - { - return mBaseOTs.size(); - } - - void setBase(span baseMessages); - - std::vector getPoints(PprfOutputFormat format) - { - std::vector pnts(mPntCount); - getPoints(pnts, format); - return pnts; - } - void getPoints(span points, PprfOutputFormat format); - - task<> expand(Socket& chl, span output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads) - { - MatrixView o(output.data(), output.size(), 1); - return expand(chl, o, oFormat, activeChildXorDelta, numThreads); - } - - // activeChildXorDelta says whether the sender is trying to program the - // active child to be its correct value XOR delta. If it is not, the - // active child will just take a random value. - task<> expand(Socket& chl, MatrixView output, PprfOutputFormat oFormat, bool activeChildXorDelta, u64 numThreads); - - void clear() - { - mBaseOTs.resize(0, 0); - mBaseChoices.resize(0, 0); - mDomain = 0; - mDepth = 0; - mPntCount = 0; - } - - void expandOne( - u64 treeIdx, - bool programActivePath, - span>> levels, - span, 2>> encSums, - span> lastOts); - }; -} -#endif diff --git a/libOTe/TwoChooseOne/ConfigureCode.cpp b/libOTe/TwoChooseOne/ConfigureCode.cpp index c85019e7..110227c5 100644 --- a/libOTe/TwoChooseOne/ConfigureCode.cpp +++ b/libOTe/TwoChooseOne/ConfigureCode.cpp @@ -4,37 +4,13 @@ #include "cryptoTools/Common/Range.h" #include "libOTe/TwoChooseOne/TcoOtDefines.h" #include "libOTe/Tools/Tools.h" -#include "libOTe/Tools/LDPC/LdpcEncoder.h" #include "libOTe/Tools/QuasiCyclicCode.h" #include "libOTe/Tools/EACode/EACode.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" +#include "libOTe/Tools/ExConvCode/ExConvCode.h" #include namespace osuCrypto { - //u64 secLevel(u64 scale, u64 n, u64 points) - //{ - // auto x1 = std::log2(scale * n / double(n)); - // auto x2 = std::log2(scale * n) / 2; - // return static_cast(points * x1 + x2); - //} - - //u64 getPartitions(u64 scaler, u64 n, u64 secParam) - //{ - // if (scaler < 2) - // throw std::runtime_error("scaler must be 2 or greater"); - - // u64 ret = 1; - // auto ss = secLevel(scaler, n, ret); - // while (ss < secParam) - // { - // ++ret; - // ss = secLevel(scaler, n, ret); - // if (ret > 1000) - // throw std::runtime_error("failed to find silent OT parameters"); - // } - // return roundUpTo(ret, 8); - //} - // We get e^{-2td} security against linear attacks, // with noise weigh t and minDist d. @@ -142,6 +118,36 @@ namespace osuCrypto mEncoder.config(numOTs, numOTs * mScaler, w, a, true); } + + void ExConvConfigure( + double scaler, + MultType mMultType, + u64& expanderWeight, + u64& accumulatorWeight, + double& minDist) + { + if (scaler != 2) + throw RTE_LOC; + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + accumulatorWeight = 24; + expanderWeight = 7; + minDist = 0.2; // psuedo min dist estimate + break; + case osuCrypto::MultType::ExConv21x24: + accumulatorWeight = 24; + expanderWeight = 21; + minDist = 0.25; // psuedo min dist estimate + break; + default: + throw RTE_LOC; + break; + } + + } + + #ifdef ENABLE_INSECURE_SILVER void SilverConfigure( diff --git a/libOTe/TwoChooseOne/ConfigureCode.h b/libOTe/TwoChooseOne/ConfigureCode.h index 47cbaa53..720e75d5 100644 --- a/libOTe/TwoChooseOne/ConfigureCode.h +++ b/libOTe/TwoChooseOne/ConfigureCode.h @@ -10,11 +10,7 @@ namespace osuCrypto { // https://eprint.iacr.org/2019/1159.pdf QuasiCyclic = 1, -#ifdef ENABLE_INSECURE_SILVER - // https://eprint.iacr.org/2021/1150, see https://eprint.iacr.org/2023/882 for attack. - slv5 = 2, - slv11 = 3, -#endif + // https://eprint.iacr.org/2022/1014 ExAcc7 = 4, // fast ExAcc11 = 5,// fast but more conservative @@ -33,14 +29,7 @@ namespace osuCrypto case osuCrypto::MultType::QuasiCyclic: o << "QuasiCyclic"; break; -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - o << "slv5"; - break; - case osuCrypto::MultType::slv11: - o << "slv11"; - break; -#endif + case osuCrypto::MultType::ExAcc7: o << "ExAcc7"; break; @@ -105,19 +94,14 @@ namespace osuCrypto ExConvCode& mEncoder ); -#ifdef ENABLE_INSECURE_SILVER - struct SilverEncoder; - void SilverConfigure( - u64 numOTs, u64 secParam, + + void ExConvConfigure( + double scaler, MultType mMultType, - u64& mRequestedNumOTs, - u64& mNumPartitions, - u64& mSizePer, - u64& mN2, - u64& mN, - u64& gap, - SilverEncoder& mEncoder); -#endif + u64& expanderWeight, + u64& accumulatorWeight, + double& minDist + ); void QuasiCyclicConfigure( u64 numOTs, u64 secParam, @@ -130,4 +114,15 @@ namespace osuCrypto u64& mN, u64& mP, u64& mScaler); + + + inline void QuasiCyclicConfigure( + double mScaler, + double& minDist) + { + if (mScaler == 2) + minDist = 0.2; // psuedo min dist + else + throw RTE_LOC; // not impl + } } \ No newline at end of file diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp index 3107f0db..c473459f 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp @@ -3,7 +3,6 @@ #include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" #include -#include #include #include @@ -13,6 +12,7 @@ #include #include #include "libOTe/Tools/QuasiCyclicCode.h" +#include "libOTe/Tools/CoeffCtx.h" namespace osuCrypto { @@ -68,11 +68,9 @@ namespace osuCrypto throw std::runtime_error("wrong number of silent base OTs"); auto genOts = recvBaseOts.subspan(0, mGen.baseOtCount()); - auto gapOts = recvBaseOts.subspan(genOts.size(), mGapOts.size()); - auto malOts = recvBaseOts.subspan(genOts.size() + mGapOts.size()); + auto malOts = recvBaseOts.subspan(genOts.size()); mGen.setBase(genOts); - std::copy(gapOts.begin(), gapOts.end(), mGapOts.begin()); std::copy(malOts.begin(), malOts.end(), mMalCheckOts.begin()); } @@ -115,14 +113,7 @@ namespace osuCrypto if (isConfigured() == false) throw std::runtime_error("configure(...) must be called first"); - auto choice = mGen.sampleChoiceBits(mN2, getPprfFormat(), prng); - - if (mGapOts.size()) - { - mGapBaseChoice.resize(mGapOts.size()); - mGapBaseChoice.randomize(prng); - choice.append(mGapBaseChoice); - } + auto choice = mGen.sampleChoiceBits(prng); mS.resize(mNumPartitions); mGen.getPoints(mS, getPprfFormat()); @@ -210,7 +201,6 @@ namespace osuCrypto throw std::runtime_error("configure must be called first"); return mGen.baseOtCount() + - mGapOts.size() + (mMalType == SilentSecType::Malicious) * 128; } @@ -223,7 +213,6 @@ namespace osuCrypto { mMalType = malType; mNumThreads = numThreads; - mGapOts.resize(0); switch (mMultType) { @@ -240,28 +229,6 @@ namespace osuCrypto mScaler); break; -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - { - if (scaler != 2) - throw std::runtime_error("only scaler = 2 is supported for slv. " LOCATION); - - u64 gap; - SilverConfigure(numOTs, 128, - mMultType, - mRequestedNumOts, - mNumPartitions, - mSizePer, - mN2, - mN, - gap, - mEncoder); - - mGapOts.resize(gap); - break; - } -#endif case osuCrypto::MultType::ExAcc7: case osuCrypto::MultType::ExAcc11: case osuCrypto::MultType::ExAcc21: @@ -470,61 +437,24 @@ namespace osuCrypto mA.resize(mN2); mC.resize(0); - //// do the compression to get the final OTs. - //if (mMultType == MultType::QuasiCyclic) - //{ - // rT = MatrixView(mA.data(), 128, mN2 / 128); - // // locally expand the seeds. - // MC_AWAIT(mGen.expand(chl, prng, rT, PprfOutputFormat::InterleavedTransposed, mNumThreads)); - // setTimePoint("recver.expand.pprf_transpose"); + MC_AWAIT(mGen.expand(chl, mA, PprfOutputFormat::Interleaved, true, mNumThreads)); + setTimePoint("recver.expand.pprf_transpose"); + gTimer.setTimePoint("recver.expand.pprf_transpose"); - // if (mDebug) - // { - // MC_AWAIT(checkRT(chl, rT)); - // } - - // randMulQuasiCyclic(type); - - //} - //else - { - - main = mNumPartitions * mSizePer; - if (mGapOts.size()) - { - // derandomize the random OTs for the gap - // to have the desired correlation. - gapVals.resize(mGapOts.size()); - MC_AWAIT(chl.recv(gapVals)); - for (i = main, j = 0; i < mN2; ++i, ++j) - { - if (mGapBaseChoice[j]) - mA[i] = AES(mGapOts[j]).ecbEncBlock(ZeroBlock) ^ gapVals[j]; - else - mA[i] = mGapOts[j]; - } - } - - - MC_AWAIT(mGen.expand(chl, mA.subspan(0, main), PprfOutputFormat::Interleaved, true, mNumThreads)); - setTimePoint("recver.expand.pprf_transpose"); - gTimer.setTimePoint("recver.expand.pprf_transpose"); - - - if (mMalType == SilentSecType::Malicious) - MC_AWAIT(ferretMalCheck(chl, prng)); + if (mMalType == SilentSecType::Malicious) + MC_AWAIT(ferretMalCheck(chl, prng)); - if (mDebug) - { - rT = MatrixView(mA.data(), mN2, 1); - MC_AWAIT(checkRT(chl, rT)); - } - compress(type); + if (mDebug) + { + rT = MatrixView(mA.data(), mN2, 1); + MC_AWAIT(checkRT(chl, rT)); } + compress(type); + mA.resize(mRequestedNumOts); if (mC.size()) @@ -543,9 +473,10 @@ namespace osuCrypto sum0 = block{}, sum1 = block{}, mySum = block{}, - deltaShare = block{}, + b = AlignedUnVector(1), + //deltaShare = block{}, i = u64{}, - sender = NoisyVoleSender{}, + sender = NoisyVoleSender{}, theirHash = std::array{}, myHash = std::array{}, ro = RandomOracle(32) @@ -570,9 +501,9 @@ namespace osuCrypto mySum = sum0.gf128Reduce(sum1); - MC_AWAIT(sender.send(mMalCheckX, { &deltaShare,1 }, prng, mMalCheckOts, chl)); + MC_AWAIT(sender.send(mMalCheckX, b, prng, mMalCheckOts, chl, {})); ; - ro.Update(mySum ^ deltaShare); + ro.Update(mySum ^ b[0]); ro.Final(myHash); MC_AWAIT(chl.recv(theirHash)); @@ -727,7 +658,7 @@ namespace osuCrypto case osuCrypto::MultType::ExAcc40: { AlignedUnVector A2(mEAEncoder.mMessageSize); - mEAEncoder.dualEncode(mA.subspan(0, mEAEncoder.mCodeSize), A2); + mEAEncoder.dualEncode(mA.subspan(0, mEAEncoder.mCodeSize), A2, {}); std::swap(mA, A2); break; } @@ -735,7 +666,7 @@ namespace osuCrypto case osuCrypto::MultType::ExConv21x24: if (mTimer) mExConvEncoder.setTimer(getTimer()); - mExConvEncoder.dualEncode(mA.subspan(0, mExConvEncoder.generatorCols())); + mExConvEncoder.dualEncode(mA.begin(), {}); break; default: throw RTE_LOC; @@ -747,13 +678,6 @@ namespace osuCrypto } else { - // allocate and initialize mC - //if (mChoiceSpanSize < mN2) - //{ - // mChoiceSpanSize = mN2; - // mChoicePtr.reset((new u8[mN2]())); - //} - //else mC.resize(mN2); std::memset(mC.data(), 0, mN2); auto cc = mC.data(); @@ -788,9 +712,10 @@ namespace osuCrypto { AlignedUnVector A2(mEAEncoder.mMessageSize); AlignedUnVector C2(mEAEncoder.mMessageSize); - mEAEncoder.dualEncode2( + mEAEncoder.dualEncode2( mA.subspan(0, mEAEncoder.mCodeSize), A2, - mC.subspan(0, mEAEncoder.mCodeSize), C2); + mC.subspan(0, mEAEncoder.mCodeSize), C2, + {}); std::swap(mA, A2); std::swap(mC, C2); @@ -801,10 +726,10 @@ namespace osuCrypto case osuCrypto::MultType::ExConv21x24: if (mTimer) mExConvEncoder.setTimer(getTimer()); - mExConvEncoder.dualEncode2( - mA.subspan(0, mExConvEncoder.mCodeSize), - mC.subspan(0, mExConvEncoder.mCodeSize) - ); + mExConvEncoder.dualEncode2( + mA.begin(), + mC.begin(), + {}); break; default: throw RTE_LOC; @@ -827,8 +752,6 @@ namespace osuCrypto mGen.clear(); - mGapOts = {}; - mS = {}; } diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h index ec43ff6c..0ac227e5 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h @@ -16,14 +16,14 @@ #include #include #include -#include +#include #include #include -#include #include #include #include "libOTe/Tools/EACode/EACode.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" +#include "SilentOtExtUtil.h" namespace osuCrypto { @@ -70,13 +70,9 @@ namespace osuCrypto macoro::optional mOtExtRecver; #endif - // The OTs recv msgs which will be used to flood the - // last gap bits of the noisy vector for the slv code. - std::vector mGapOts; - // The OTs recv msgs which will be used to create the // secret share of xa * delta as described in ferret. - std::vector mMalCheckOts; + AlignedUnVector mMalCheckOts; // The OTs choice bits which will be used to flood the // last gap bits of the noisy vector for the slv code. @@ -95,7 +91,7 @@ namespace osuCrypto block mMalCheckX = ZeroBlock; // The ggm tree thats used to generate the sparse vectors. - SilentMultiPprfReceiver mGen; + RegularPprfReceiver mGen; // The type of compress we will use to generate the // dense vectors from the sparse vectors. diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp index 7fb8e32d..73a927a5 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp @@ -135,7 +135,7 @@ namespace osuCrypto if (isConfigured() == false) throw std::runtime_error("configure must be called first"); - auto n = mGen.baseOtCount() + mGapOts.size(); + auto n = mGen.baseOtCount(); if (mMalType == SilentSecType::Malicious) n += 128; @@ -151,12 +151,10 @@ namespace osuCrypto throw RTE_LOC; auto genOt = sendBaseOts.subspan(0, mGen.baseOtCount()); - auto gapOt = sendBaseOts.subspan(genOt.size(), mGapOts.size()); - auto malOt = sendBaseOts.subspan(genOt.size() + gapOt.size()); + auto malOt = sendBaseOts.subspan(genOt.size()); mMalCheckOts.resize((mMalType == SilentSecType::Malicious) * 128); mGen.setBase(genOt); - std::copy(gapOt.begin(), gapOt.end(), mGapOts.begin()); std::copy(malOt.begin(), malOt.end(), mMalCheckOts.begin()); } @@ -166,7 +164,6 @@ namespace osuCrypto mMalType = malType; mNumThreads = numThreads; - mGapOts.resize(0); switch (mMultType) { @@ -183,28 +180,6 @@ namespace osuCrypto mScaler); break; -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - { - if (scaler != 2) - throw std::runtime_error("only scaler = 2 is supported for slv. " LOCATION); - - u64 gap; - SilverConfigure(numOTs, 128, - mMultType, - mRequestNumOts, - mNumPartitions, - mSizePer, - mN2, - mN, - gap, - mEncoder); - - mGapOts.resize(gap); - break; - } -#endif case osuCrypto::MultType::ExAcc7: case osuCrypto::MultType::ExAcc11: case osuCrypto::MultType::ExAcc21: @@ -249,8 +224,6 @@ namespace osuCrypto mDelta = block(0,0); - mGapOts = {}; - mGen.clear(); } @@ -392,9 +365,8 @@ namespace osuCrypto Socket& chl) { MC_BEGIN(task<>,this, d, n, &prng, &chl, - rT = MatrixView{}, - gapVals = std::vector {}, - i = u64{}, j = u64{}, main = u64{} + i = u64{}, j = u64{}, + delta = AlignedUnVector{} ); gTimer.setTimePoint("sender.ot.enter"); @@ -420,54 +392,23 @@ namespace osuCrypto // allocate b mB.resize(mN2); + + delta.resize(1); + delta[0] = mDelta; - //if (mMultType == MultType::QuasiCyclic) - //{ - // rT = MatrixView(mB.data(), 128, mN2 / 128); - - // MC_AWAIT(mGen.expand(chl, mDelta, prng, rT, PprfOutputFormat::InterleavedTransposed, mNumThreads)); - // setTimePoint("sender.expand.pprf_transpose"); - // gTimer.setTimePoint("sender.expand.pprf_transpose"); - - // if (mDebug) - // MC_AWAIT(checkRT(chl)); - - // randMulQuasiCyclic(); - //} - //else - { - - main = mNumPartitions * mSizePer; - if (mGapOts.size()) - { - // derandomize the random OTs for the gap - // to have the desired correlation. - gapVals.resize(mGapOts.size()); - for (i = main, j = 0; i < mN2; ++i, ++j) - { - auto v = mGapOts[j][0] ^ mDelta; - gapVals[j] = AES(mGapOts[j][1]).ecbEncBlock(ZeroBlock) ^ v; - mB[i] = mGapOts[j][0]; - //std::cout << "jj " << j << " " <(1), + c = AlignedUnVector(1), + //deltaShare = block{}, i = u64{}, - recver = NoisyVoleReceiver{}, + recver = NoisyVoleReceiver{}, myHash = std::array{}, ro = RandomOracle(32) ); @@ -506,11 +449,11 @@ namespace osuCrypto mySum = sum0.gf128Reduce(sum1); + c[0] = mDelta; + //a[0] = deltaShare; + MC_AWAIT(recver.receive(c, a, prng, mMalCheckOts, chl, {})); - - MC_AWAIT(recver.receive({ &mDelta,1 }, { &deltaShare,1 }, prng, mMalCheckOts, chl)); - - ro.Update(mySum ^ deltaShare); + ro.Update(mySum ^ a[0]); ro.Final(myHash); MC_AWAIT(chl.send(std::move(myHash))); @@ -534,17 +477,6 @@ namespace osuCrypto #endif } break; -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - - if (mTimer) - mEncoder.setTimer(getTimer()); - mEncoder.dualEncode(mB); - setTimePoint("sender.expand.ldpc.dualEncode"); - - break; -#endif case osuCrypto::MultType::ExAcc7: case osuCrypto::MultType::ExAcc11: case osuCrypto::MultType::ExAcc21: @@ -553,7 +485,7 @@ namespace osuCrypto if (mTimer) mEAEncoder.setTimer(getTimer()); AlignedUnVector B2(mEAEncoder.mMessageSize); - mEAEncoder.dualEncode(mB.subspan(0, mEAEncoder.mCodeSize), B2); + mEAEncoder.dualEncode(mB.subspan(0, mEAEncoder.mCodeSize), B2, {}); std::swap(mB, B2); break; } @@ -561,14 +493,14 @@ namespace osuCrypto case osuCrypto::MultType::ExConv21x24: if (mTimer) mExConvEncoder.setTimer(getTimer()); - mExConvEncoder.dualEncode(mB.subspan(0, mExConvEncoder.mCodeSize)); + mExConvEncoder.dualEncode(mB.begin(), {}); break; default: throw RTE_LOC; break; } - + } // // diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h index 36247933..88ee73e5 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.h @@ -15,14 +15,14 @@ #include #include #include -#include +#include #include #include #include -#include #include #include "libOTe/Tools/EACode/EACode.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" +#include "SilentOtExtUtil.h" namespace osuCrypto { @@ -110,7 +110,7 @@ namespace osuCrypto #endif // The ggm tree thats used to generate the sparse vectors. - SilentMultiPprfSender mGen; + RegularPprfSender mGen; // The type of compress we will use to generate the // dense vectors from the sparse vectors. @@ -119,17 +119,9 @@ namespace osuCrypto // The flag which controls whether the malicious check is performed. SilentSecType mMalType = SilentSecType::SemiHonest; - // The Silver encoder for MultType::slv5, MultType::slv11 -#ifdef ENABLE_INSECURE_SILVER - SilverEncoder mEncoder; -#endif ExConvCode mExConvEncoder; EACode mEAEncoder; - // The OTs send msgs which will be used to flood the - // last gap bits of the noisy vector for the slv code. - std::vector> mGapOts; - // The OTs send msgs which will be used to create the // secret share of xa * delta as described in ferret. std::vector> mMalCheckOts; diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtUtil.h b/libOTe/TwoChooseOne/Silent/SilentOtExtUtil.h new file mode 100644 index 00000000..9476bbc6 --- /dev/null +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtUtil.h @@ -0,0 +1,23 @@ +#pragma once + + +namespace osuCrypto +{ + + enum class OTType + { + Random, Correlated + }; + + enum class ChoiceBitPacking + { + False, True + }; + + enum class SilentSecType + { + SemiHonest, + Malicious, + //MaliciousFS + }; +} \ No newline at end of file diff --git a/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp b/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp index 767efa5d..85ed17c6 100644 --- a/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp +++ b/libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenMalOtExt.cpp @@ -16,6 +16,9 @@ namespace osuCrypto if ((u64)messages.data() % 32) throw std::runtime_error("soft spoken requires the messages to by 32 byte aligned. Consider using AlignedUnVector or AlignedVector." LOCATION); + if (messages.size() == 0) + throw std::runtime_error("soft spoken must be called with at least 1 messag." LOCATION); + MC_BEGIN(task<>, this, messages, &prng, &chl, nChunks = u64{}, messagesFullChunks = u64{}, @@ -26,6 +29,7 @@ namespace osuCrypto mHasher = Hasher{} ); + if (!hasBaseOts()) MC_AWAIT(genBaseOts(prng, chl)); @@ -53,7 +57,7 @@ namespace osuCrypto MC_AWAIT(runBatch(chl, scratch.subspan(0, messagesFullChunks * chunkSize()))); - assert(scratch[0] != ZeroBlock); + assert(messagesFullChunks == 0 || scratch[0] != ZeroBlock); // Extra blocks MC_AWAIT(runBatch(chl, mExtraW.subspan(0, numExtra * chunkSize()))); diff --git a/libOTe/Vole/Noisy/NoisyVoleReceiver.cpp b/libOTe/Vole/Noisy/NoisyVoleReceiver.cpp deleted file mode 100644 index 9a9a20cc..00000000 --- a/libOTe/Vole/Noisy/NoisyVoleReceiver.cpp +++ /dev/null @@ -1,88 +0,0 @@ -#include "NoisyVoleReceiver.h" - -#if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) -#include "cryptoTools/Common/BitIterator.h" -#include "cryptoTools/Common/Matrix.h" - - -namespace osuCrypto -{ - - task<> NoisyVoleReceiver::receive(span y, span z, PRNG& prng, - OtSender& ot, Socket& chl) - { - MC_BEGIN(task<>,this, y,z, &prng, &ot, &chl, - otMsg = AlignedUnVector>{ 128 } - ); - - setTimePoint("NoisyVoleReceiver.ot.begin"); - - MC_AWAIT(ot.send(otMsg, prng, chl)); - - setTimePoint("NoisyVoleReceiver.ot.end"); - - MC_AWAIT(receive(y, z, prng, otMsg, chl)); - - MC_END(); - } - - task<> NoisyVoleReceiver::receive( - span y, span z, - PRNG& _, span> otMsg, - Socket& chl) - { - MC_BEGIN(task<>,this, y,z, otMsg, &chl, - msg = Matrix{}, - prng = std::move(PRNG{}) - //buffer = std::vector{} - ); - - setTimePoint("NoisyVoleReceiver.begin"); - if (otMsg.size() != 128) - throw RTE_LOC; - if (y.size() != z.size()) - throw RTE_LOC; - if (z.size() == 0) - throw RTE_LOC; - - memset(z.data(), 0, sizeof(block) * z.size()); - msg.resize(otMsg.size(), y.size()); - - //buffer.resize(z.size()); - - for (u64 ii = 0; ii < (u64)otMsg.size(); ++ii) - { - //PRNG p0(otMsg[ii][0]); - //PRNG p1(otMsg[ii][1]); - prng.SetSeed(otMsg[ii][0], z.size()); - auto& buffer = prng.mBuffer; - - for (u64 j = 0; j < (u64)y.size(); ++j) - { - z[j] = z[j] ^ buffer[j]; - - block twoPowI = ZeroBlock; - *BitIterator((u8*)&twoPowI, ii) = 1; - - auto yy = y[j].gf128Mul(twoPowI); - - msg(ii, j) = yy ^ buffer[j]; - } - - prng.SetSeed(otMsg[ii][1], z.size()); - - for (u64 j = 0; j < (u64)y.size(); ++j) - { - // enc one message under the OT msg. - msg(ii, j) = msg(ii, j) ^ buffer[j]; - } - } - - MC_AWAIT(chl.send(std::move(msg))); - //chl.asyncSend(std::move(msg)); - setTimePoint("NoisyVoleReceiver.done"); - - MC_END(); - } -} -#endif \ No newline at end of file diff --git a/libOTe/Vole/Noisy/NoisyVoleReceiver.h b/libOTe/Vole/Noisy/NoisyVoleReceiver.h index da0182b6..3c83b2ff 100644 --- a/libOTe/Vole/Noisy/NoisyVoleReceiver.h +++ b/libOTe/Vole/Noisy/NoisyVoleReceiver.h @@ -1,12 +1,22 @@ #pragma once // © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. #include #if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) @@ -14,20 +24,127 @@ #include "cryptoTools/Common/Defines.h" #include "cryptoTools/Common/Timer.h" #include "cryptoTools/Crypto/PRNG.h" -#include "libOTe/TwoChooseOne/OTExtInterface.h" #include "libOTe/Tools/Coproto.h" +#include "libOTe/TwoChooseOne/OTExtInterface.h" +#include "libOTe/Tools/CoeffCtx.h" + +namespace osuCrypto { + + template < + typename F, + typename G = F, + typename CoeffCtx = DefaultCoeffCtx + > + class NoisyVoleReceiver : public TimerAdapter + { + public: + using VecF = typename CoeffCtx::template Vec; + + + // for chosen c, compute a such htat + // + // a = b + c * delta + // + template + task<> receive(VecG& c, VecF& a, PRNG& prng, + OtSender& ot, Socket& chl, CoeffCtx ctx) + { + MC_BEGIN(task<>, this, &c, &a, &prng, &ot, &chl, ctx, + otMsg = AlignedUnVector>{}); + + setTimePoint("NoisyVoleReceiver.ot.begin"); + otMsg.resize(ctx.template bitSize()); + MC_AWAIT(ot.send(otMsg, prng, chl)); + + setTimePoint("NoisyVoleReceiver.ot.end"); + + MC_AWAIT(receive(c, a, prng, otMsg, chl,ctx)); + + MC_END(); + } + + // for chosen c, compute a such htat + // + // a = b + c * delta + // + template + task<> receive(VecG& c, VecF& a, PRNG& _, + span> otMsg, + Socket& chl, CoeffCtx ctx) + { + MC_BEGIN(task<>, this, &c, &a, otMsg, &chl, ctx, + buff = std::vector{}, + msg = VecF{}, + temp = VecF{}, + prng = std::move(PRNG{}) + ); + + if (c.size() != a.size()) + throw RTE_LOC; + if (a.size() == 0) + throw RTE_LOC; + + setTimePoint("NoisyVoleReceiver.begin"); + + ctx.zero(a.begin(), a.end()); + ctx.resize(msg, otMsg.size() * a.size()); + ctx.resize(temp, 2); + + for (size_t i = 0, k = 0; i < otMsg.size(); ++i) + { + prng.SetSeed(otMsg[i][0], a.size()); + + // t1 = 2^i + ctx.powerOfTwo(temp[1], i); + //std::cout << "2^i " << ctx.str(temp[1]) << "\n"; + + for (size_t j = 0; j < c.size(); ++j, ++k) + { + // msg[i,j] = otMsg[i,j,0] + ctx.fromBlock(msg[k], prng.get()); + //ctx.zero(msg.begin() + k, msg.begin() + k + 1); + //std::cout << "m" << i << ",0 = " << ctx.str(msg[k]) << std::endl; + + // a[j] += otMsg[i,j,0] + ctx.plus(a[j], a[j], msg[k]); + //std::cout << "z = " << ctx.str(a[j]) << std::endl; + + // temp = 2^i * c[j] + ctx.mul(temp[0], temp[1], c[j]); + //std::cout << "2^i y = " << ctx.str(temp[0]) << std::endl; + + // msg[i,j] = otMsg[i,j,0] + 2^i * c[j] + ctx.minus(msg[k], msg[k], temp[0]); + //std::cout << "m" << i << ",0 + 2^i y = " << ctx.str(msg[k]) << std::endl; + } + + k -= c.size(); + prng.SetSeed(otMsg[i][1], a.size()); + + for (size_t j = 0; j < c.size(); ++j, ++k) + { + // temp = otMsg[i,j,1] + ctx.fromBlock(temp[0], prng.get()); + //ctx.zero(temp.begin(), temp.begin() + 1); + //std::cout << "m" << i << ",1 = " << ctx.str(temp[0]) << std::endl; + + // enc one message under the OT msg. + // msg[i,j] = (otMsg[i,j,0] + 2^i * c[j]) - otMsg[i,j,1] + ctx.minus(msg[k], msg[k], temp[0]); + //std::cout << "m" << i << ",0 + 2^i y - m" << i << ",1 = " << ctx.str(msg[k]) << std::endl << std::endl; + } + } -namespace osuCrypto -{ - class NoisyVoleReceiver : public TimerAdapter - { - public: + buff.resize(msg.size() * ctx.template byteSize()); + ctx.serialize(msg.begin(), msg.end(), buff.begin()); - task<> receive(span y, span z, PRNG& prng, OtSender& ot, Socket& chl); - task<> receive(span y, span z, PRNG& prng, span> otMsg, Socket& chl); + MC_AWAIT(chl.send(std::move(buff))); + setTimePoint("NoisyVoleReceiver.done"); - }; + MC_END(); + } + }; -} +} // namespace osuCrypto #endif diff --git a/libOTe/Vole/Noisy/NoisyVoleSender.cpp b/libOTe/Vole/Noisy/NoisyVoleSender.cpp deleted file mode 100644 index 83cc55fc..00000000 --- a/libOTe/Vole/Noisy/NoisyVoleSender.cpp +++ /dev/null @@ -1,82 +0,0 @@ -#include "NoisyVoleSender.h" - -#if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) -#include "cryptoTools/Common/BitVector.h" -#include "cryptoTools/Common/Matrix.h" - -namespace osuCrypto -{ - task<> NoisyVoleSender::send( - block x, span z, - PRNG& prng, - OtReceiver& ot, - Socket& chl) - { - MC_BEGIN(task<>,this, x, z, &prng, &ot, &chl, - bv = BitVector((u8*)&x, 128), - otMsg = AlignedUnVector{ 128 }); - - setTimePoint("NoisyVoleSender.ot.begin"); - - //BitVector bv((u8*)&x, 128); - //std::array otMsg; - MC_AWAIT(ot.receive(bv, otMsg, prng, chl)); - setTimePoint("NoisyVoleSender.ot.end"); - - MC_AWAIT(send(x, z, prng, otMsg, chl)); - - MC_END(); - } - - task<> NoisyVoleSender::send( - block x, - span z, - PRNG& prng, - span otMsg, - Socket& chl) - { - MC_BEGIN(task<>,this, x, z, &prng, otMsg, &chl, - msg = Matrix{}, - buffer = std::vector{}, - xIter = BitIterator{}); - - if (otMsg.size() != 128) - throw RTE_LOC; - setTimePoint("NoisyVoleSender.main"); - - msg.resize(otMsg.size(), z.size()); - memset(z.data(), 0, sizeof(block) * z.size()); - - - MC_AWAIT(chl.recv(msg)); - - setTimePoint("NoisyVoleSender.recvMsg"); - buffer.resize(z.size()); - - xIter = BitIterator((u8*)&x); - - for (u64 i = 0; i < otMsg.size(); ++i, ++xIter) - { - PRNG pi(otMsg[i]); - pi.get(buffer); - - if (*xIter) - { - for (u64 j = 0; j < z.size(); ++j) - { - buffer[j] = msg(i, j) ^ buffer[j]; - } - } - - for (u64 j = 0; j < (u64)z.size(); ++j) - { - z[j] = z[j] ^ buffer[j]; - } - } - setTimePoint("NoisyVoleSender.done"); - - MC_END(); - } - -} -#endif diff --git a/libOTe/Vole/Noisy/NoisyVoleSender.h b/libOTe/Vole/Noisy/NoisyVoleSender.h index 862d3002..d42e9fcb 100644 --- a/libOTe/Vole/Noisy/NoisyVoleSender.h +++ b/libOTe/Vole/Noisy/NoisyVoleSender.h @@ -1,32 +1,142 @@ #pragma once // © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. -// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). +// This code implements features described in [Silver: Silent VOLE and Oblivious +// Transfer from Hardness of Decoding Structured LDPC Codes, +// https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative +// Commons Attribution 4.0 International Public License +// (https://creativecommons.org/licenses/by/4.0/legalcode). #include #if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) +#include "cryptoTools/Common/BitVector.h" #include "cryptoTools/Common/Defines.h" #include "cryptoTools/Common/Timer.h" #include "cryptoTools/Crypto/PRNG.h" -#include "libOTe/TwoChooseOne/OTExtInterface.h" #include "libOTe/Tools/Coproto.h" +#include "libOTe/TwoChooseOne/OTExtInterface.h" +#include "libOTe/Tools/CoeffCtx.h" - -namespace osuCrypto -{ - class NoisyVoleSender : public TimerAdapter +namespace osuCrypto { + template < + typename F, + typename G, + typename CoeffCtx + > + class NoisyVoleSender : public TimerAdapter { + public: + using VecF = typename CoeffCtx::template Vec; + + // for chosen delta, compute b such htat + // + // a = b + c * delta + // + template + task<> send(F delta, FVec& b, PRNG& prng, + OtReceiver& ot, Socket& chl, CoeffCtx ctx) { + MC_BEGIN(task<>, this, delta, &b, &prng, &ot, &chl, ctx, + bv = ctx.binaryDecomposition(delta), + otMsg = AlignedUnVector{ }); + otMsg.resize(bv.size()); + + setTimePoint("NoisyVoleSender.ot.begin"); + + MC_AWAIT(ot.receive(bv, otMsg, prng, chl)); + setTimePoint("NoisyVoleSender.ot.end"); + + MC_AWAIT(send(delta, b, prng, otMsg, chl, ctx)); + + MC_END(); + } + + // for chosen delta, compute b such htat + // + // a = b + c * delta + // + template + task<> send(F delta, FVec& b, PRNG& _, + span otMsg, Socket& chl, CoeffCtx ctx) { + MC_BEGIN(task<>, this, delta, &b, otMsg, &chl, ctx, + prng = std::move(PRNG{}), + buffer = std::vector{}, + msg = VecF{}, + temp = VecF{}, + xb = BitVector{}); + + xb = ctx.binaryDecomposition(delta); + + if (otMsg.size() != xb.size()) + throw RTE_LOC; + setTimePoint("NoisyVoleSender.main"); + + // b = 0; + ctx.zero(b.begin(), b.end()); + + // receive the the excrypted one shares. + buffer.resize(xb.size() * b.size() * ctx.template byteSize()); + MC_AWAIT(chl.recv(buffer)); + ctx.resize(msg, xb.size() * b.size()); + ctx.deserialize(buffer.begin(), buffer.end(), msg.begin()); + + setTimePoint("NoisyVoleSender.recvMsg"); + + temp.resize(1); + for (size_t i = 0, k = 0; i < xb.size(); ++i) + { + // expand the zero shares or one share masks + prng.SetSeed(otMsg[i], b.size()); + + // otMsg[i,j, bc[i]] + //auto otMsgi = prng.getBufferSpan(b.size()); + + for (u64 j = 0; j < (u64)b.size(); ++j, ++k) + { + // temp = otMsg[i,j, xb[i]] + ctx.fromBlock(temp[0], prng.get()); + //ctx.zero(temp.begin(), temp.begin() + 1); + //std::cout << "m" << i << ","< send(block x, span z, PRNG& prng, OtReceiver& ot, Socket& chl); - task<> send(block x, span z, PRNG& prng, span otMsg, Socket& chl); }; -} +} // namespace osuCrypto #endif \ No newline at end of file diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.cpp b/libOTe/Vole/Silent/SilentVoleReceiver.cpp deleted file mode 100644 index 8d1972b8..00000000 --- a/libOTe/Vole/Silent/SilentVoleReceiver.cpp +++ /dev/null @@ -1,822 +0,0 @@ -#include "libOTe/Vole/Silent/SilentVoleReceiver.h" - -#ifdef ENABLE_SILENT_VOLE -#include "libOTe/Vole/Silent/SilentVoleSender.h" -#include "libOTe/Vole/Noisy/NoisyVoleReceiver.h" -#include -#include - -#include -#include -#include -#include -#include - -namespace osuCrypto -{ - - - //u64 getPartitions(u64 scaler, u64 p, u64 secParam); - - - - // sets the Iknp base OTs that are then used to extend - void SilentVoleReceiver::setBaseOts( - span> baseSendOts) - { -#ifdef ENABLE_SOFTSPOKEN_OT - mOtExtRecver.setBaseOts(baseSendOts); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - - // return the number of base OTs soft spoken needs - u64 SilentVoleReceiver::baseOtCount() const { -#ifdef ENABLE_SOFTSPOKEN_OT - return mOtExtRecver.baseOtCount(); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - - // returns true if the soft spoken base OTs are currently set. - bool SilentVoleReceiver::hasBaseOts() const { -#ifdef ENABLE_SOFTSPOKEN_OT - return mOtExtRecver.hasBaseOts(); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - }; - - - BitVector SilentVoleReceiver::sampleBaseChoiceBits(PRNG& prng) { - - if (isConfigured() == false) - throw std::runtime_error("configure(...) must be called first"); - - auto choice = mGen.sampleChoiceBits(mN2, getPprfFormat(), prng); - - mGapBaseChoice.resize(mGapOts.size()); - mGapBaseChoice.randomize(prng); - choice.append(mGapBaseChoice); - - return choice; - } - - std::vector SilentVoleReceiver::sampleBaseVoleVals(PRNG& prng) - { - if (isConfigured() == false) - throw RTE_LOC; - if (mGapBaseChoice.size() != mGapOts.size()) - throw std::runtime_error("sampleBaseChoiceBits must be called before sampleBaseVoleVals. " LOCATION); - - // sample the values of the noisy coordinate of c - // and perform a noicy vole to get x+y = mD * c - auto w = mNumPartitions + mGapOts.size(); - //std::vector y(w); - mNoiseValues.resize(w); - prng.get(mNoiseValues); - - mS.resize(mNumPartitions); - mGen.getPoints(mS, getPprfFormat()); - - auto j = mNumPartitions * mSizePer; - - for (u64 i = 0; i < (u64)mGapBaseChoice.size(); ++i) - { - if (mGapBaseChoice[i]) - { - mS.push_back(j + i); - } - } - - if (mMalType == SilentSecType::Malicious) - { - - mMalCheckSeed = prng.get(); - mMalCheckX = ZeroBlock; - auto yIter = mNoiseValues.begin(); - - for (u64 i = 0; i < mNumPartitions; ++i) - { - auto s = mS[i]; - auto xs = mMalCheckSeed.gf128Pow(s + 1); - mMalCheckX = mMalCheckX ^ xs.gf128Mul(*yIter); - ++yIter; - } - - auto sIter = mS.begin() + mNumPartitions; - for (u64 i = 0; i < mGapBaseChoice.size(); ++i) - { - if (mGapBaseChoice[i]) - { - auto s = *sIter; - auto xs = mMalCheckSeed.gf128Pow(s + 1); - mMalCheckX = mMalCheckX ^ xs.gf128Mul(*yIter); - ++sIter; - } - ++yIter; - } - - - std::vector y(mNoiseValues.begin(), mNoiseValues.end()); - y.push_back(mMalCheckX); - return y; - } - - return std::vector(mNoiseValues.begin(), mNoiseValues.end()); - } - - task<> SilentVoleReceiver::genBaseOts( - PRNG& prng, - Socket& chl) - { - setTimePoint("SilentVoleReceiver.gen.start"); -#ifdef ENABLE_SOFTSPOKEN_OT - return mOtExtRecver.genBaseOts(prng, chl); - //mIknpSender.genBaseOts(mIknpRecver, prng, chl); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - - void SilentVoleReceiver::configure( - u64 numOTs, - SilentBaseType type, - u64 secParam) - { - mState = State::Configured; - u64 gap = 0; - mBaseType = type; - - switch (mMultType) - { - case osuCrypto::MultType::QuasiCyclic: - { - u64 p, s; - QuasiCyclicConfigure(numOTs, secParam, - 2, - mMultType, - mRequestedNumOTs, - mNumPartitions, - mSizePer, - mN2, - mN, - p, - s - ); -#ifdef ENABLE_BITPOLYMUL - mQuasiCyclicEncoder.init(p, s); -#else - throw std::runtime_error("ENABLE_BITPOLYMUL not defined."); -#endif - break; - } -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - - SilverConfigure(numOTs, secParam, - mMultType, - mRequestedNumOTs, - mNumPartitions, - mSizePer, - mN2, - mN, - gap, - mEncoder); - - break; -#endif - case osuCrypto::MultType::ExAcc7: - case osuCrypto::MultType::ExAcc11: - case osuCrypto::MultType::ExAcc21: - case osuCrypto::MultType::ExAcc40: - EAConfigure(numOTs, secParam, - mMultType, - mRequestedNumOTs, - mNumPartitions, - mSizePer, - mN2, - mN, - mEAEncoder); - - break; - case osuCrypto::MultType::ExConv7x24: - case osuCrypto::MultType::ExConv21x24: - - ExConvConfigure(numOTs, 128, mMultType, mRequestedNumOTs, mNumPartitions, mSizePer, mN2, mN, mExConvEncoder); - break; - default: - throw RTE_LOC; - break; - } - - mGapOts.resize(gap); - mGen.configure(mSizePer, mNumPartitions); - } - - - task<> SilentVoleReceiver::genSilentBaseOts( - PRNG& prng, - Socket& chl) - { -#if defined ENABLE_MRR_TWIST && defined ENABLE_SSE -using BaseOT = McRosRoyTwist; -#elif defined ENABLE_MR -using BaseOT = MasnyRindal; -#elif defined ENABLE_MRR -using BaseOT = McRosRoy; -#else -using BaseOT = DefaultBaseOT; -#endif - - MC_BEGIN(task<>, this, &prng, &chl, - choice = BitVector{}, - bb = BitVector{}, - msg = AlignedUnVector{}, - baseVole = std::vector{}, - baseOt = BaseOT{}, - chl2 = Socket{}, - prng2 = std::move(PRNG{}), - noiseVals = std::vector{}, - noiseDeltaShares = std::vector{}, - nv = NoisyVoleReceiver{} - - ); - - setTimePoint("SilentVoleReceiver.genSilent.begin"); - if (isConfigured() == false) - throw std::runtime_error("configure must be called first"); - - choice = sampleBaseChoiceBits(prng); - msg.resize(choice.size()); - - // sample the noise vector noiseVals such that we will compute - // - // C = (000 noiseVals[0] 0000 ... 000 noiseVals[p] 000) - // - // and then we want secret shares of C * delta. As a first step - // we will compute secret shares of - // - // delta * noiseVals - // - // and store our share in voleDeltaShares. This party will then - // compute their share of delta * C as what comes out of the PPRF - // plus voleDeltaShares[i] added to the appreciate spot. Similarly, the - // other party will program the PPRF to output their share of delta * noiseVals. - // - noiseVals = sampleBaseVoleVals(prng); - noiseDeltaShares.resize(noiseVals.size()); - if (mTimer) - nv.setTimer(*mTimer); - - if (mBaseType == SilentBaseType::BaseExtend) - { -#ifdef ENABLE_SOFTSPOKEN_OT - - if (mOtExtSender.hasBaseOts() == false) - { - msg.resize(msg.size() + mOtExtSender.baseOtCount()); - bb.resize(mOtExtSender.baseOtCount()); - bb.randomize(prng); - choice.append(bb); - - MC_AWAIT(mOtExtRecver.receive(choice, msg, prng, chl)); - - mOtExtSender.setBaseOts( - span(msg).subspan( - msg.size() - mOtExtSender.baseOtCount(), - mOtExtSender.baseOtCount()), - bb); - - msg.resize(msg.size() - mOtExtSender.baseOtCount()); - MC_AWAIT(nv.receive(noiseVals, noiseDeltaShares, prng, mOtExtSender, chl)); - } - else - { - chl2 = chl.fork(); - prng2.SetSeed(prng.get()); - - - MC_AWAIT( - macoro::when_all_ready( - nv.receive(noiseVals, noiseDeltaShares, prng2, mOtExtSender, chl2), - mOtExtRecver.receive(choice, msg, prng, chl) - )); - } -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - else - { - chl2 = chl.fork(); - prng2.SetSeed(prng.get()); - - MC_AWAIT( - macoro::when_all_ready( - nv.receive(noiseVals, noiseDeltaShares, prng2, baseOt, chl2), - baseOt.receive(choice, msg, prng, chl) - )); - } - - - - - setSilentBaseOts(msg, noiseDeltaShares); - - setTimePoint("SilentVoleReceiver.genSilent.done"); - - MC_END(); - }; - - void SilentVoleReceiver::setSilentBaseOts( - span recvBaseOts, - span noiseDeltaShare) - { - if (isConfigured() == false) - throw std::runtime_error("configure(...) must be called first."); - - if (static_cast(recvBaseOts.size()) != silentBaseOtCount()) - throw std::runtime_error("wrong number of silent base OTs"); - - auto genOts = recvBaseOts.subspan(0, mGen.baseOtCount()); - auto gapOts = recvBaseOts.subspan(mGen.baseOtCount(), mGapOts.size()); - - mGen.setBase(genOts); - std::copy(gapOts.begin(), gapOts.end(), mGapOts.begin()); - - if (mMalType == SilentSecType::Malicious) - { - mDeltaShare = noiseDeltaShare.back(); - noiseDeltaShare = noiseDeltaShare.subspan(0, noiseDeltaShare.size() - 1); - } - - mNoiseDeltaShare = AlignedVector(noiseDeltaShare.begin(), noiseDeltaShare.end()); - - mState = State::HasBase; - } - - - - //sigma = 0 Receiver - // - // u_i is the choice bit - // v_i = w_i + u_i * x - // - // ------------------------ - - // u' = 0000001000000000001000000000100000...00000, u_i = 1 iff i \in S - // - // v' = r + (x . u') = DPF(k0) - // = r + (000000x00000000000x000000000x00000...00000) - // - // u = u' * H bit-vector * H. Mapping n'->n bits - // v = v' * H block-vector * H. Mapping n'->n block - // - //sigma = 1 Sender - // - // x is the delta - // w_i is the zero message - // - // m_i0 = w_i - // m_i1 = w_i + x - // - // ------------------------ - // x - // r = DPF(k1) - // - // w = r * H - task<> SilentVoleReceiver::silentReceive( - span c, - span b, - PRNG& prng, - Socket& chl) - { - MC_BEGIN(task<>, this, c, b, &prng, &chl); - if (c.size() != b.size()) - throw RTE_LOC; - - MC_AWAIT(silentReceiveInplace(c.size(), prng, chl)); - - std::memcpy(c.data(), mC.data(), c.size() * sizeof(block)); - std::memcpy(b.data(), mA.data(), b.size() * sizeof(block)); - clear(); - MC_END(); - } - - task<> SilentVoleReceiver::silentReceiveInplace( - u64 n, - PRNG& prng, - Socket& chl) - { - MC_BEGIN(task<>, this, n, &prng, &chl, - gapVals = std::vector{}, - myHash = std::array{}, - theirHash = std::array{} - - ); - gTimer.setTimePoint("SilentVoleReceiver.ot.enter"); - - if (isConfigured() == false) - { - // first generate 128 normal base OTs - configure(n, SilentBaseType::BaseExtend); - } - - if (mRequestedNumOTs != n) - throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); - - if (hasSilentBaseOts() == false) - { - MC_AWAIT(genSilentBaseOts(prng, chl)); - } - - // allocate mA - mA.resize(0); - mA.resize(mN2); - - setTimePoint("SilentVoleReceiver.alloc"); - - // allocate the space for mC - mC.resize(0); - mC.resize(mN2, AllocType::Zeroed); - setTimePoint("SilentVoleReceiver.alloc.zero"); - - // derandomize the random OTs for the gap - // to have the desired correlation. - gapVals.resize(mGapOts.size()); - - if(gapVals.size()) - MC_AWAIT(chl.recv(gapVals)); - - for (auto g : rng(mGapOts.size())) - { - auto aa = mA.subspan(mNumPartitions * mSizePer); - auto cc = mC.subspan(mNumPartitions * mSizePer); - - auto noise = mNoiseValues.subspan(mNumPartitions); - auto noiseShares = mNoiseDeltaShare.subspan(mNumPartitions); - - if (mGapBaseChoice[g]) - { - cc[g] = noise[g]; - aa[g] = AES(mGapOts[g]).ecbEncBlock(ZeroBlock) ^ - gapVals[g] ^ - noiseShares[g]; - } - else - aa[g] = mGapOts[g]; - } - - setTimePoint("SilentVoleReceiver.recvGap"); - - - - if (mTimer) - mGen.setTimer(*mTimer); - // expand the seeds into mA - MC_AWAIT(mGen.expand(chl, mA.subspan(0, mNumPartitions * mSizePer), PprfOutputFormat::Interleaved, true, mNumThreads)); - - setTimePoint("SilentVoleReceiver.expand.pprf_transpose"); - - // populate the noisy coordinates of mC and - // update mA to be a secret share of mC * delta - for (u64 i = 0; i < mNumPartitions; ++i) - { - auto pnt = mS[i]; - mC[pnt] = mNoiseValues[i]; - mA[pnt] = mA[pnt] ^ mNoiseDeltaShare[i]; - } - - - if (mDebug) - { - MC_AWAIT(checkRT(chl)); - setTimePoint("SilentVoleReceiver.expand.checkRT"); - } - - - if (mMalType == SilentSecType::Malicious) - { - MC_AWAIT(chl.send(std::move(mMalCheckSeed))); - - myHash = ferretMalCheck(mDeltaShare, mNoiseValues); - - MC_AWAIT(chl.recv(theirHash)); - - if (theirHash != myHash) - throw RTE_LOC; - } - - switch (mMultType) - { - case osuCrypto::MultType::QuasiCyclic: - -#ifdef ENABLE_BITPOLYMUL - if (mTimer) - mQuasiCyclicEncoder.setTimer(getTimer()); - - // compress both mA and mC in place. - mQuasiCyclicEncoder.dualEncode(mA.subspan(0, mQuasiCyclicEncoder.size())); - mQuasiCyclicEncoder.dualEncode(mC.subspan(0, mQuasiCyclicEncoder.size())); -#else - throw std::runtime_error("ENABLE_BITPOLYMUL not defined."); -#endif - - setTimePoint("SilentVoleReceiver.expand.mQuasiCyclicEncoder.a"); - break; -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - if (mTimer) - mEncoder.setTimer(getTimer()); - - // compress both mA and mC in place. - mEncoder.dualEncode2(mA, mC); - setTimePoint("SilentVoleReceiver.expand.cirTransEncode.a"); - break; -#endif - case osuCrypto::MultType::ExAcc7: - case osuCrypto::MultType::ExAcc11: - case osuCrypto::MultType::ExAcc21: - case osuCrypto::MultType::ExAcc40: - { - if (mTimer) - mEAEncoder.setTimer(getTimer()); - - AlignedUnVector - A2(mEAEncoder.mMessageSize), - C2(mEAEncoder.mMessageSize); - - // compress both mA and mC in place. - mEAEncoder.dualEncode2( - mA.subspan(0, mEAEncoder.mCodeSize), A2, - mC.subspan(0, mEAEncoder.mCodeSize), C2); - - std::swap(mA, A2); - std::swap(mC, C2); - - setTimePoint("SilentVoleReceiver.expand.cirTransEncode.a"); - break; - } - case osuCrypto::MultType::ExConv7x24: - case osuCrypto::MultType::ExConv21x24: - if (mTimer) - mExConvEncoder.setTimer(getTimer()); - mExConvEncoder.dualEncode2( - mA.subspan(0, mExConvEncoder.mCodeSize), - mC.subspan(0, mExConvEncoder.mCodeSize) - ); - break; - default: - throw RTE_LOC; - break; - } - - // resize the buffers down to only contain the real elements. - mA.resize(mRequestedNumOTs); - mC.resize(mRequestedNumOTs); - - mNoiseValues = {}; - mNoiseDeltaShare = {}; - - // make the protocol as done and that - // mA,mC are ready to be consumed. - mState = State::Default; - - MC_END(); - } - - - std::array SilentVoleReceiver::ferretMalCheck( - block deltaShare, - span yy) - { - - block xx = mMalCheckSeed; - block sum0 = ZeroBlock; - block sum1 = ZeroBlock; - - - for (u64 i = 0; i < (u64)mA.size(); ++i) - { - block low, high; - xx.gf128Mul(mA[i], low, high); - sum0 = sum0 ^ low; - sum1 = sum1 ^ high; - //mySum = mySum ^ xx.gf128Mul(mA[i]); - - // xx = mMalCheckSeed^{i+1} - xx = xx.gf128Mul(mMalCheckSeed); - } - block mySum = sum0.gf128Reduce(sum1); - - std::array myHash; - RandomOracle ro(32); - ro.Update(mySum ^ deltaShare); - ro.Final(myHash); - return myHash; - } - - - u64 SilentVoleReceiver::silentBaseOtCount() const - { - if (isConfigured() == false) - throw std::runtime_error("configure must be called first"); - - return mGen.baseOtCount() + mGapOts.size(); - - } - - task<> SilentVoleReceiver::checkRT(Socket& chl) const - { - MC_BEGIN(task<>, this, &chl, - B = AlignedVector(mA.size()), - sparseNoiseDelta = std::vector(mA.size()), - noiseDeltaShare2 = std::vector(), - delta = block{} - ); - //std::vector mB(mA.size()); - MC_AWAIT(chl.recv(delta)); - MC_AWAIT(chl.recv(B)); - MC_AWAIT(chl.recvResize(noiseDeltaShare2)); - - //check that at locations mS[0],...,mS[..] - // that we hold a sharing mA, mB of - // - // delta * mC = delta * (00000 noiseDeltaShare2[0] 0000 .... 0000 noiseDeltaShare2[m] 0000) - // - // where noiseDeltaShare2[i] is at position mS[i] of mC - // - // That is, I hold mA, mC s.t. - // - // delta * mC = mA + mB - // - - if (noiseDeltaShare2.size() != mNoiseDeltaShare.size()) - throw RTE_LOC; - - for (auto i : rng(mNoiseDeltaShare.size())) - { - if ((mNoiseDeltaShare[i] ^ noiseDeltaShare2[i]) != mNoiseValues[i].gf128Mul(delta)) - throw RTE_LOC; - } - - { - - for (auto i : rng(mNumPartitions* mSizePer)) - { - auto iter = std::find(mS.begin(), mS.end(), i); - if (iter != mS.end()) - { - auto d = iter - mS.begin(); - - if (mC[i] != mNoiseValues[d]) - throw RTE_LOC; - - if (mNoiseValues[d].gf128Mul(delta) != (mA[i] ^ B[i])) - { - std::cout << "bad vole base correlation, mA[i] + mB[i] != mC[i] * delta" << std::endl; - std::cout << "i " << i << std::endl; - std::cout << "mA[i] " << mA[i] << std::endl; - std::cout << "mB[i] " << B[i] << std::endl; - std::cout << "mC[i] " << mC[i] << std::endl; - std::cout << "delta " << delta << std::endl; - std::cout << "mA[i] + mB[i] " << (mA[i] ^ B[i]) << std::endl; - std::cout << "mC[i] * delta " << (mC[i].gf128Mul(delta)) << std::endl; - - throw RTE_LOC; - } - } - else - { - if (mA[i] != B[i]) - { - std::cout << mA[i] << " " << B[i] << std::endl; - throw RTE_LOC; - } - - if (mC[i] != oc::ZeroBlock) - throw RTE_LOC; - } - } - - u64 d = mNumPartitions; - for (auto j : rng(mGapBaseChoice.size())) - { - auto idx = j + mNumPartitions * mSizePer; - auto aa = mA.subspan(mNumPartitions * mSizePer); - auto bb = B.subspan(mNumPartitions * mSizePer); - auto cc = mC.subspan(mNumPartitions * mSizePer); - auto noise = mNoiseValues.subspan(mNumPartitions); - //auto noiseShare = mNoiseValues.subspan(mNumPartitions); - if (mGapBaseChoice[j]) - { - if (mS[d++] != idx) - throw RTE_LOC; - - if (cc[j] != noise[j]) - { - std::cout << "sparse noise vector mC is not the expected value" << std::endl; - std::cout << "i j " << idx << " " << j << std::endl; - std::cout << "mC[i] " << cc[j] << std::endl; - std::cout << "noise[j] " << noise[j] << std::endl; - throw RTE_LOC; - } - - if (noise[j].gf128Mul(delta) != (aa[j] ^ bb[j])) - { - - std::cout << "bad vole base GAP correlation, mA[i] + mB[i] != mC[i] * delta" << std::endl; - std::cout << "i " << idx << std::endl; - std::cout << "mA[i] " << aa[j] << std::endl; - std::cout << "mB[i] " << bb[j] << std::endl; - std::cout << "mC[i] " << cc[j] << std::endl; - std::cout << "delta " << delta << std::endl; - std::cout << "mA[i] + mB[i] " << (aa[j] ^ bb[j]) << std::endl; - std::cout << "mC[i] * delta " << (cc[j].gf128Mul(delta)) << std::endl; - std::cout << "noise * delta " << (noise[j].gf128Mul(delta)) << std::endl; - throw RTE_LOC; - } - - } - else - { - if (aa[j] != bb[j]) - throw RTE_LOC; - - if (cc[j] != oc::ZeroBlock) - throw RTE_LOC; - } - } - - if (d != mS.size()) - throw RTE_LOC; - } - - - //{ - - // auto cDelta = B; - // for (u64 i = 0; i < cDelta.size(); ++i) - // cDelta[i] = cDelta[i] ^ mA[i]; - - // std::vector exp(mN2); - // for (u64 i = 0; i < mNumPartitions; ++i) - // { - // auto j = mS[i]; - // exp[j] = noiseDeltaShare2[i]; - // } - - // auto iter = mS.begin() + mNumPartitions; - // for (u64 i = 0, j = mNumPartitions * mSizePer; i < mGapOts.size(); ++i, ++j) - // { - // if (mGapBaseChoice[i]) - // { - // if (*iter != j) - // throw RTE_LOC; - // ++iter; - - // exp[j] = noiseDeltaShare2[mNumPartitions + i]; - // } - // } - - // if (iter != mS.end()) - // throw RTE_LOC; - - // bool failed = false; - // for (u64 i = 0; i < mN2; ++i) - // { - // if (neq(cDelta[i], exp[i])) - // { - // std::cout << i << " / " << mN2 << - // " cd = " << cDelta[i] << - // " exp= " << exp[i] << std::endl; - // failed = true; - // } - // } - - // if (failed) - // throw RTE_LOC; - - // std::cout << "debug check ok" << std::endl; - //} - - MC_END(); - } - - - void SilentVoleReceiver::clear() - { - mS = {}; - mA = {}; - mC = {}; - mGen.clear(); - mGapBaseChoice = {}; - } - - -} -#endif \ No newline at end of file diff --git a/libOTe/Vole/Silent/SilentVoleReceiver.h b/libOTe/Vole/Silent/SilentVoleReceiver.h index be968a6f..13024928 100644 --- a/libOTe/Vole/Silent/SilentVoleReceiver.h +++ b/libOTe/Vole/Silent/SilentVoleReceiver.h @@ -14,26 +14,38 @@ #include #include #include -#include +#include "libOTe/Tools/Pprf/RegularPprf.h" #include #include #include -#include -#include #include -#include #include +#include +#include +#include +#include +#include "libOTe/Tools/QuasiCyclicCode.h" +#include "libOTe/TwoChooseOne/Silent/SilentOtExtUtil.h" +#include namespace osuCrypto { - // For more documentation see SilentOtExtSender. + template< + typename F, + typename G = F, + typename Ctx = DefaultCoeffCtx + > class SilentVoleReceiver : public TimerAdapter { public: static constexpr u64 mScaler = 2; + static constexpr bool MaliciousSupported = + std::is_same_v && std::is_same_v; + + enum class State { Default, @@ -41,22 +53,29 @@ namespace osuCrypto HasBase }; + using VecF = typename Ctx::template Vec; + using VecG = typename Ctx::template Vec; + // The current state of the protocol State mState = State::Default; - // The number of OTs the user requested. - u64 mRequestedNumOTs = 0; + // the context used to perform F, G operations + Ctx mCtx; + + // The number of correlations the user requested. + u64 mRequestSize = 0; + + // the LPN security parameter + u64 mSecParam = 0; + + // The length of the noisy vectors (2 * mN for the most codes). + u64 mNoiseVecSize = 0; - // The number of OTs actually produced (at least the number requested). - u64 mN = 0; - - // The length of the noisy vectors (2 * mN for the silver codes). - u64 mN2 = 0; - // We perform regular LPN, so this is the // size of the each chunk. u64 mSizePer = 0; + // the number of noisy positions u64 mNumPartitions = 0; // The noisy coordinates. @@ -69,116 +88,365 @@ namespace osuCrypto // the sparse vector. MultType mMultType = DefaultMultType; -#ifdef ENABLE_INSECURE_SILVER - // The silver encoder. - SilverEncoder mEncoder; -#endif - - ExConvCode mExConvEncoder; - EACode mEAEncoder; - -#ifdef ENABLE_BITPOLYMUL - QuasiCyclicCode mQuasiCyclicEncoder; -#endif - // The multi-point punctured PRF for generating // the sparse vectors. - SilentMultiPprfReceiver mGen; + RegularPprfReceiver mGen; // The internal buffers for holding the expanded vectors. - // mA + mB = mC * delta - AlignedUnVector mA; + // mA = mB + mC * delta + VecF mA; - // mA + mB = mC * delta - AlignedUnVector mC; - - std::vector mGapOts; + // mA = mB + mC * delta + VecG mC; u64 mNumThreads = 1; bool mDebug = false; - BitVector mIknpSendBaseChoice, mGapBaseChoice; + BitVector mIknpSendBaseChoice; + + SilentSecType mMalType = SilentSecType::SemiHonest; - SilentSecType mMalType = SilentSecType::SemiHonest; + block mMalCheckSeed, mMalCheckX, mMalBaseA; - block mMalCheckSeed, mMalCheckX, mDeltaShare; - - AlignedVector mNoiseDeltaShare, mNoiseValues; + // we + VecF mBaseA; + VecG mBaseC; #ifdef ENABLE_SOFTSPOKEN_OT - SoftSpokenMalOtSender mOtExtSender; - SoftSpokenMalOtReceiver mOtExtRecver; + macoro::optional mOtExtSender; + macoro::optional mOtExtRecver; #endif - // sets the Iknp base OTs that are then used to extend - void setBaseOts( - span> baseSendOts); - - // return the number of base OTs IKNP needs - u64 baseOtCount() const; + // // sets the Iknp base OTs that are then used to extend + // void setBaseOts( + // span> baseSendOts); + // + // // return the number of base OTs IKNP needs + // u64 baseOtCount() const; u64 baseVoleCount() const { - return mNumPartitions + mGapOts.size() + 1 * (mMalType == SilentSecType::Malicious); + return mNumPartitions + 1 * (mMalType == SilentSecType::Malicious); } - // returns true if the IKNP base OTs are currently set. - bool hasBaseOts() const; - - // returns true if the silent base OTs are set. + // // returns true if the IKNP base OTs are currently set. + // bool hasBaseOts() const; + // + // returns true if the silent base OTs are set. bool hasSilentBaseOts() const { return mGen.hasBaseOts(); }; - - // Generate the IKNP base OTs - task<> genBaseOts(PRNG& prng, Socket& chl) ; + // + // // Generate the IKNP base OTs + // task<> genBaseOts(PRNG& prng, Socket& chl) ; // Generate the silent base OTs. If the Iknp // base OTs are set then we do an IKNP extend, // otherwise we perform a base OT protocol to // generate the needed OTs. - task<> genSilentBaseOts(PRNG& prng, Socket& chl); - + task<> genSilentBaseOts(PRNG& prng, Socket& chl) + { +#ifdef LIBOTE_HAS_BASE_OT + +#if defined ENABLE_MRR_TWIST && defined ENABLE_SSE + using BaseOT = McRosRoyTwist; +#elif defined ENABLE_MR + using BaseOT = MasnyRindal; +#elif defined ENABLE_MRR + using BaseOT = McRosRoy; +#elif defined ENABLE_NP_KYBER + using BaseOT = MasnyRindalKyber; +#else + using BaseOT = DefaultBaseOT; +#endif + + MC_BEGIN(task<>, this, &prng, &chl, + choice = BitVector{}, + bb = BitVector{}, + msg = AlignedUnVector{}, + baseVole = std::vector{}, + baseOt = BaseOT{}, + chl2 = Socket{}, + prng2 = std::move(PRNG{}), + noiseVals = VecG{}, + baseAs = VecF{}, + nv = NoisyVoleReceiver{} + + ); + + setTimePoint("SilentVoleReceiver.genSilent.begin"); + if (isConfigured() == false) + throw std::runtime_error("configure must be called first"); + + choice = sampleBaseChoiceBits(prng); + msg.resize(choice.size()); + + // sample the noise vector noiseVals such that we will compute + // + // C = (000 noiseVals[0] 0000 ... 000 noiseVals[p] 000) + // + // and then we want secret shares of C * delta. As a first step + // we will compute secret shares of + // + // delta * noiseVals + // + // and store our share in voleDeltaShares. This party will then + // compute their share of delta * C as what comes out of the PPRF + // plus voleDeltaShares[i] added to the appreciate spot. Similarly, the + // other party will program the PPRF to output their share of delta * noiseVals. + // + noiseVals = sampleBaseVoleVals(prng); + mCtx.resize(baseAs, noiseVals.size()); + + if (mTimer) + nv.setTimer(*mTimer); + + if (mBaseType == SilentBaseType::BaseExtend) + { +#ifdef ENABLE_SOFTSPOKEN_OT + if (!mOtExtRecver) + mOtExtRecver.emplace(); + + if (!mOtExtSender) + mOtExtSender.emplace(); + + if (mOtExtSender->hasBaseOts() == false) + { + msg.resize(msg.size() + mOtExtSender->baseOtCount()); + bb.resize(mOtExtSender->baseOtCount()); + bb.randomize(prng); + choice.append(bb); + + MC_AWAIT(mOtExtRecver->receive(choice, msg, prng, chl)); + + mOtExtSender->setBaseOts( + span(msg).subspan( + msg.size() - mOtExtSender->baseOtCount(), + mOtExtSender->baseOtCount()), + bb); + + msg.resize(msg.size() - mOtExtSender->baseOtCount()); + MC_AWAIT(nv.receive(noiseVals, baseAs, prng, *mOtExtSender, chl, mCtx)); + } + else + { + chl2 = chl.fork(); + prng2.SetSeed(prng.get()); + + + MC_AWAIT( + macoro::when_all_ready( + nv.receive(noiseVals, baseAs, prng2, *mOtExtSender, chl2, mCtx), + mOtExtRecver->receive(choice, msg, prng, chl) + )); + } +#else + throw std::runtime_error("soft spoken must be enabled"); +#endif + } + else + { + chl2 = chl.fork(); + prng2.SetSeed(prng.get()); + + MC_AWAIT( + macoro::when_all_ready( + baseOt.receive(choice, msg, prng, chl), + nv.receive(noiseVals, baseAs, prng2, baseOt, chl2, mCtx)) + ); + } + + setSilentBaseOts(msg, baseAs); + setTimePoint("SilentVoleReceiver.genSilent.done"); + MC_END(); +#else + throw std::runtime_error("LIBOTE_HAS_BASE_OT = false, must enable relic, sodium or simplest ot asm." LOCATION); +#endif + + }; + // configure the silent OT extension. This sets // the parameters and figures out how many base OT // will be needed. These can then be ganerated for // a different OT extension or using a base OT protocol. void configure( - u64 n, - SilentBaseType baseType = SilentBaseType::BaseExtend, - u64 secParam = 128); + u64 requestSize, + SilentBaseType type = SilentBaseType::BaseExtend, + u64 secParam = 128, + Ctx ctx = {}) + { + mCtx = std::move(ctx); + mSecParam = secParam; + mRequestSize = requestSize; + mState = State::Configured; + mBaseType = type; + double minDist = 0; + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + case osuCrypto::MultType::ExConv21x24: + { + u64 _1, _2; + ExConvConfigure(mScaler, mMultType, _1, _2, minDist); + break; + } + case MultType::QuasiCyclic: + QuasiCyclicConfigure(mScaler, minDist); + break; + default: + throw RTE_LOC; + break; + } + + mNumPartitions = getRegNoiseWeight(minDist, secParam); + mSizePer = std::max(4, roundUpTo(divCeil(mRequestSize * mScaler, mNumPartitions), 2)); + mNoiseVecSize = mSizePer * mNumPartitions; + + //std::cout << "n " << mRequestSize << " -> " << mNoiseVecSize << " = " << mSizePer << " * " << mNumPartitions << std::endl; + + mGen.configure(mSizePer, mNumPartitions); + } // return true if this instance has been configured. bool isConfigured() const { return mState != State::Default; } // Returns how many base OTs the silent OT extension // protocol will needs. - u64 silentBaseOtCount() const; + u64 silentBaseOtCount() const + { + if (isConfigured() == false) + throw std::runtime_error("configure must be called first"); + + return mGen.baseOtCount(); + + } // The silent base OTs must have specially set base OTs. // This returns the choice bits that should be used. // Call this is you want to use a specific base OT protocol // and then pass the OT messages back using setSilentBaseOts(...). - BitVector sampleBaseChoiceBits(PRNG& prng); + BitVector sampleBaseChoiceBits(PRNG& prng) { + + if (isConfigured() == false) + throw std::runtime_error("configure(...) must be called first"); - std::vector sampleBaseVoleVals(PRNG& prng); + auto choice = mGen.sampleChoiceBits(prng); + + return choice; + } + + VecG sampleBaseVoleVals(PRNG& prng) + { + if (isConfigured() == false) + throw RTE_LOC; + + // sample the values of the noisy coordinate of c + // and perform a noicy vole to get a = b + mD * c + + + mCtx.resize(mBaseC, mNumPartitions + (mMalType == SilentSecType::Malicious)); + + if (mCtx.template bitSize() == 1) + { + mCtx.one(mBaseC.begin(), mBaseC.begin() + mNumPartitions); + } + else + { + VecG zero, one; + mCtx.resize(zero, 1); + mCtx.resize(one, 1); + mCtx.zero(zero.begin(), zero.end()); + mCtx.one(one.begin(), one.end()); + for (size_t i = 0; i < mNumPartitions; i++) + { + mCtx.fromBlock(mBaseC[i], prng.get()); + + // must not be zero. + while(mCtx.eq(zero[0], mBaseC[i])) + mCtx.fromBlock(mBaseC[i], prng.get()); + + // if we are not a field, then the noise should be odd. + if (mCtx.template isField() == false) + { + u8 odd = mCtx.binaryDecomposition(mBaseC[i])[0]; + if (odd) + mCtx.plus(mBaseC[i], mBaseC[i], one[0]); + } + } + } + + + mS.resize(mNumPartitions); + mGen.getPoints(mS, PprfOutputFormat::Interleaved); + + if (mMalType == SilentSecType::Malicious) + { + if constexpr (MaliciousSupported) + { + mMalCheckSeed = prng.get(); + + auto yIter = mBaseC.begin(); + mCtx.zero(mBaseC.end() - 1, mBaseC.end()); + for (u64 i = 0; i < mNumPartitions; ++i) + { + auto s = mS[i]; + auto xs = mMalCheckSeed.gf128Pow(s + 1); + mBaseC[mNumPartitions] = mBaseC[mNumPartitions] ^ xs.gf128Mul(*yIter); + ++yIter; + } + } + else + { + throw std::runtime_error("malicious is currently only supported for GF128 block. " LOCATION); + } + } + + return mBaseC; + } // Set the externally generated base OTs. This choice // bits must be the one return by sampleBaseChoiceBits(...). - void setSilentBaseOts(span recvBaseOts, - span voleBase); + void setSilentBaseOts( + span recvBaseOts, + VecF& baseA) + { + if (isConfigured() == false) + throw std::runtime_error("configure(...) must be called first."); + + if (static_cast(recvBaseOts.size()) != silentBaseOtCount()) + throw std::runtime_error("wrong number of silent base OTs"); + + mGen.setBase(recvBaseOts); + + mCtx.resize(mBaseA, baseA.size()); + mCtx.copy(baseA.begin(), baseA.end(), mBaseA.begin()); + mState = State::HasBase; + } // Perform the actual OT extension. If silent // base OTs have been generated or set, then // this function is non-interactive. Otherwise // the silent base OTs will automatically be performed. task<> silentReceive( - span c, - span a, - PRNG & prng, - Socket & chl); + VecG& c, + VecF& a, + PRNG& prng, + Socket& chl) + { + MC_BEGIN(task<>, this, &c, &a, &prng, &chl); + if (c.size() != a.size()) + throw RTE_LOC; + + MC_AWAIT(silentReceiveInplace(c.size(), prng, chl)); + + mCtx.copy(mC.begin(), mC.begin() + c.size(), c.begin()); + mCtx.copy(mA.begin(), mA.begin() + a.size(), a.begin()); + + clear(); + MC_END(); + } // Perform the actual OT extension. If silent // base OTs have been generated or set, then @@ -187,23 +455,330 @@ namespace osuCrypto task<> silentReceiveInplace( u64 n, PRNG& prng, - Socket& chl); + Socket& chl) + { + MC_BEGIN(task<>, this, n, &prng, &chl, + myHash = std::array{}, + theirHash = std::array{} + ); + gTimer.setTimePoint("SilentVoleReceiver.ot.enter"); + + if (isConfigured() == false) + { + // first generate 128 normal base OTs + configure(n, SilentBaseType::BaseExtend); + } + + if (mRequestSize != n) + throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); + + if (hasSilentBaseOts() == false) + { + MC_AWAIT(genSilentBaseOts(prng, chl)); + } + + // allocate mA + mCtx.resize(mA, 0); + mCtx.resize(mA, mNoiseVecSize); + + setTimePoint("SilentVoleReceiver.alloc"); + + // allocate the space for mC + mCtx.resize(mC, 0); + mCtx.resize(mC, mNoiseVecSize); + mCtx.zero(mC.begin(), mC.end()); + setTimePoint("SilentVoleReceiver.alloc.zero"); + + if (mTimer) + mGen.setTimer(*mTimer); + + // As part of the setup, we have generated + // + // mBaseA + mBaseB = mBaseC * mDelta + // + // We have mBaseA, mBaseC, + // they have mBaseB, mDelta + // This was done with a small (noisy) vole. + // + // We use the Pprf to expand as + // + // mA' = mB + mS(mBaseB) + // = mB + mS(mBaseC * mDelta - mBaseA) + // = mB + mS(mBaseC * mDelta) - mS(mBaseA) + // + // Therefore if we add mS(mBaseA) to mA' we will get + // + // mA = mB + mS(mBaseC * mDelta) + // + MC_AWAIT(mGen.expand(chl, mA, PprfOutputFormat::Interleaved, true, mNumThreads)); + + setTimePoint("SilentVoleReceiver.expand.pprf_transpose"); + + // populate the noisy coordinates of mC and + // update mA to be a secret share of mC * delta + for (u64 i = 0; i < mNumPartitions; ++i) + { + auto pnt = mS[i]; + mCtx.copy(mC[pnt], mBaseC[i]); + mCtx.plus(mA[pnt], mA[pnt], mBaseA[i]); + } + + if (mDebug) + { + MC_AWAIT(checkRT(chl)); + setTimePoint("SilentVoleReceiver.expand.checkRT"); + } + + + if (mMalType == SilentSecType::Malicious) + { + MC_AWAIT(chl.send(std::move(mMalCheckSeed))); + + if constexpr (MaliciousSupported) + myHash = ferretMalCheck(); + else + throw std::runtime_error("malicious is currently only supported for GF128 block. " LOCATION); + + MC_AWAIT(chl.recv(theirHash)); + + if (theirHash != myHash) + throw RTE_LOC; + } + + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + case osuCrypto::MultType::ExConv21x24: + { + u64 expanderWeight, accumulatorWeight; + double _; + ExConvConfigure(mScaler, mMultType, expanderWeight, accumulatorWeight, _); + ExConvCode encoder; + if (mScaler * mRequestSize > mNoiseVecSize) + throw RTE_LOC; + encoder.config(mRequestSize, mScaler * mRequestSize, expanderWeight, accumulatorWeight); + + if (mTimer) + encoder.setTimer(getTimer()); + + encoder.dualEncode2( + mA.begin(), + mC.begin(), + {} + ); + break; + } + case osuCrypto::MultType::QuasiCyclic: + { +#ifdef ENABLE_BITPOLYMUL + if constexpr ( + std::is_same_v && + std::is_same_v && + std::is_same_v) + { + QuasiCyclicCode encoder; + encoder.init2(mRequestSize, mNoiseVecSize); + encoder.dualEncode(mA); + encoder.dualEncode(mC); + } + else + throw std::runtime_error("QuasiCyclic is only supported for GF128, i.e. block. " LOCATION); +#else + throw std::runtime_error("QuasiCyclic requires ENABLE_BITPOLYMUL = true. " LOCATION); +#endif + break; + } + default: + throw std::runtime_error("Code is not supported. " LOCATION); + break; + } + // resize the buffers down to only contain the real elements. + mCtx.resize(mA, mRequestSize); + mCtx.resize(mC, mRequestSize); + mBaseC = {}; + mBaseA = {}; + + // make the protocol as done and that + // mA,mC are ready to be consumed. + mState = State::Default; + + MC_END(); + } - // internal. - task<> checkRT(Socket& chls) const; - std::array ferretMalCheck( - block deltaShare, - span y); - PprfOutputFormat getPprfFormat() + // internal. + task<> checkRT(Socket& chl) { - return PprfOutputFormat::Interleaved; + MC_BEGIN(task<>, this, &chl, + B = VecF{}, + sparseNoiseDelta = VecF{}, + baseB = VecF{}, + delta = VecF{}, + tempF = VecF{}, + tempG = VecG{}, + buffer = std::vector{} + ); + + // recv delta + buffer.resize(mCtx.template byteSize()); + mCtx.resize(delta, 1); + MC_AWAIT(chl.recv(buffer)); + mCtx.deserialize(buffer.begin(), buffer.end(), delta.begin()); + + // recv B + buffer.resize(mCtx.template byteSize() * mA.size()); + mCtx.resize(B, mA.size()); + MC_AWAIT(chl.recv(buffer)); + mCtx.deserialize(buffer.begin(), buffer.end(), B.begin()); + + // recv the noisy values. + buffer.resize(mCtx.template byteSize() * mBaseA.size()); + mCtx.resize(baseB, mBaseA.size()); + MC_AWAIT(chl.recvResize(buffer)); + mCtx.deserialize(buffer.begin(), buffer.end(), baseB.begin()); + + // it shoudl hold that + // + // mBaseA = baseB + mBaseC * mDelta + // + // and + // + // mA = mB + mC * mDelta + // + { + bool verbose = false; + bool failed = false; + std::vector index(mS.size()); + std::iota(index.begin(), index.end(), 0); + std::sort(index.begin(), index.end(), + [&](std::size_t i, std::size_t j) { return mS[i] < mS[j]; }); + + mCtx.resize(tempF, 2); + mCtx.resize(tempG, 1); + mCtx.zero(tempG.begin(), tempG.end()); + + + // check the correlation that + // + // mBaseA + mBaseB = mBaseC * mDelta + for (auto i : rng(mBaseA.size())) + { + // temp[0] = baseB[i] + mBaseA[i] + mCtx.plus(tempF[0], baseB[i], mBaseA[i]); + + // temp[1] = mBaseC[i] * delta[0] + mCtx.mul(tempF[1], delta[0], mBaseC[i]); + + if (!mCtx.eq(tempF[0], tempF[1])) + throw RTE_LOC; + + if (i < mNumPartitions) + { + //auto idx = index[i]; + auto point = mS[i]; + if (!mCtx.eq(mBaseC[i], mC[point])) + throw RTE_LOC; + + if (i && mS[index[i - 1]] >= mS[index[i]]) + throw RTE_LOC; + } + } + + + auto iIter = index.begin(); + auto leafIdx = mS[*iIter]; + F act = tempF[0]; + G zero = tempG[0]; + mCtx.zero(tempG.begin(), tempG.end()); + + for (u64 j = 0; j < mA.size(); ++j) + { + mCtx.mul(act, delta[0], mC[j]); + mCtx.plus(act, act, B[j]); + + bool active = false; + if (j == leafIdx) + { + active = true; + } + else if (!mCtx.eq(zero, mC[j])) + throw RTE_LOC; + + if (mA[j] != act) + { + failed = true; + if (verbose) + std::cout << Color::Red; + } + + if (verbose) + { + std::cout << j << " act " << mCtx.str(act) + << " a " << mCtx.str(mA[j]) << " b " << mCtx.str(B[j]); + + if (active) + std::cout << " < " << mCtx.str(delta[0]); + + std::cout << std::endl << Color::Default; + } + + if (j == leafIdx) + { + ++iIter; + if (iIter != index.end()) + { + leafIdx = mS[*iIter]; + } + } + } + + if (failed) + throw RTE_LOC; + } + + MC_END(); } - void clear(); + std::array ferretMalCheck() + { + + block xx = mMalCheckSeed; + block sum0 = ZeroBlock; + block sum1 = ZeroBlock; + + + for (u64 i = 0; i < (u64)mA.size(); ++i) + { + block low, high; + xx.gf128Mul(mA[i], low, high); + sum0 = sum0 ^ low; + sum1 = sum1 ^ high; + //mySum = mySum ^ xx.gf128Mul(mA[i]); + + // xx = mMalCheckSeed^{i+1} + xx = xx.gf128Mul(mMalCheckSeed); + } + + // = < + block mySum = sum0.gf128Reduce(sum1); + + std::array myHash; + RandomOracle ro(32); + ro.Update(mySum ^ mBaseA.back()); + ro.Final(myHash); + return myHash; + } + + void clear() + { + mS = {}; + mA = {}; + mC = {}; + mGen.clear(); + } }; } #endif \ No newline at end of file diff --git a/libOTe/Vole/Silent/SilentVoleSender.cpp b/libOTe/Vole/Silent/SilentVoleSender.cpp deleted file mode 100644 index ef107a62..00000000 --- a/libOTe/Vole/Silent/SilentVoleSender.cpp +++ /dev/null @@ -1,471 +0,0 @@ -#include "libOTe/Vole/Silent/SilentVoleSender.h" -#ifdef ENABLE_SILENT_VOLE - -#include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" -#include "libOTe/TwoChooseOne/Iknp/IknpOtExtSender.h" -#include "libOTe/Vole/Noisy/NoisyVoleSender.h" - -#include "libOTe/Base/BaseOT.h" -#include "libOTe/Tools/Tools.h" -#include "cryptoTools/Common/Log.h" -#include "cryptoTools/Crypto/RandomOracle.h" -#include "libOTe/Tools/LDPC/LdpcSampler.h" - - -namespace osuCrypto -{ - u64 SilentVoleSender::baseOtCount() const - { -#ifdef ENABLE_SOFTSPOKEN_OT - return mOtExtSender.baseOtCount(); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - - bool SilentVoleSender::hasBaseOts() const - { -#ifdef ENABLE_SOFTSPOKEN_OT - return mOtExtSender.hasBaseOts(); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - - // sets the soft spoken base OTs that are then used to extend - void SilentVoleSender::setBaseOts( - span baseRecvOts, - const BitVector& choices) - { -#ifdef ENABLE_SOFTSPOKEN_OT - mOtExtSender.setBaseOts(baseRecvOts, choices); -#else - throw std::runtime_error("soft spoken must be enabled"); -#endif - } - - - task<> SilentVoleSender::genSilentBaseOts(PRNG& prng, Socket& chl, cp::optional delta) - { - -#if defined ENABLE_MRR_TWIST && defined ENABLE_SSE - using BaseOT = McRosRoyTwist; -#elif defined ENABLE_MR - using BaseOT = MasnyRindal; -#elif defined ENABLE_MRR - using BaseOT = McRosRoy; -#else - using BaseOT = DefaultBaseOT; -#endif - - MC_BEGIN(task<>,this, delta, &prng, &chl, - msg = AlignedUnVector>(silentBaseOtCount()), - baseOt = BaseOT{}, - prng2 = std::move(PRNG{}), - xx = BitVector{}, - chl2 = Socket{}, - nv = NoisyVoleSender{}, - noiseDeltaShares = std::vector{} - ); - setTimePoint("SilentVoleSender.genSilent.begin"); - - if (isConfigured() == false) - throw std::runtime_error("configure must be called first"); - - - delta = delta.value_or(prng.get()); - xx.append(delta->data(), 128); - - // compute the correlation for the noisy coordinates. - noiseDeltaShares.resize(baseVoleCount()); - - - if (mBaseType == SilentBaseType::BaseExtend) - { -#ifdef ENABLE_SOFTSPOKEN_OT - - if (mOtExtRecver.hasBaseOts() == false) - { - msg.resize(msg.size() + mOtExtRecver.baseOtCount()); - MC_AWAIT(mOtExtSender.send(msg, prng, chl)); - - mOtExtRecver.setBaseOts( - span>(msg).subspan( - msg.size() - mOtExtRecver.baseOtCount(), - mOtExtRecver.baseOtCount())); - msg.resize(msg.size() - mOtExtRecver.baseOtCount()); - - MC_AWAIT(nv.send(*delta, noiseDeltaShares, prng, mOtExtRecver, chl)); - } - else - { - chl2 = chl.fork(); - prng2.SetSeed(prng.get()); - - MC_AWAIT( - macoro::when_all_ready( - nv.send(*delta, noiseDeltaShares, prng2, mOtExtRecver, chl2), - mOtExtSender.send(msg, prng, chl))); - } -#else - -#endif - } - else - { - chl2 = chl.fork(); - prng2.SetSeed(prng.get()); - MC_AWAIT( - macoro::when_all_ready( - nv.send(*delta, noiseDeltaShares, prng2, baseOt, chl2), - baseOt.send(msg, prng, chl))); - } - - - setSilentBaseOts(msg, noiseDeltaShares); - setTimePoint("SilentVoleSender.genSilent.done"); - MC_END(); - } - - u64 SilentVoleSender::silentBaseOtCount() const - { - if (isConfigured() == false) - throw std::runtime_error("configure must be called first"); - - return mGen.baseOtCount() + mGapOts.size(); - } - - void SilentVoleSender::setSilentBaseOts( - span> sendBaseOts, - span noiseDeltaShares) - { - if ((u64)sendBaseOts.size() != silentBaseOtCount()) - throw RTE_LOC; - - if (noiseDeltaShares.size() != baseVoleCount()) - throw RTE_LOC; - - auto genOt = sendBaseOts.subspan(0, mGen.baseOtCount()); - auto gapOt = sendBaseOts.subspan(genOt.size(), mGapOts.size()); - - mGen.setBase(genOt); - std::copy(gapOt.begin(), gapOt.end(), mGapOts.begin()); - mNoiseDeltaShares.resize(noiseDeltaShares.size()); - std::copy(noiseDeltaShares.begin(), noiseDeltaShares.end(), mNoiseDeltaShares.begin()); - } - - void SilentVoleSender::configure( - u64 numOTs, - SilentBaseType type, - u64 secParam) - { - mBaseType = type; - u64 gap = 0; - - switch (mMultType) - { - case osuCrypto::MultType::QuasiCyclic: - { - u64 p, s; - - QuasiCyclicConfigure(numOTs, secParam, - 2, - mMultType, - mRequestedNumOTs, - mNumPartitions, - mSizePer, - mN2, - mN, - p, - s - ); -#ifdef ENABLE_BITPOLYMUL - mQuasiCyclicEncoder.init(p, s); -#else - throw std::runtime_error("ENABLE_BITPOLYMUL not defined."); -#endif - break; - } -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - - SilverConfigure(numOTs, secParam, - mMultType, - mRequestedNumOTs, - mNumPartitions, - mSizePer, - mN2, - mN, - gap, - mEncoder); - break; -#endif - case osuCrypto::MultType::ExAcc7: - case osuCrypto::MultType::ExAcc11: - case osuCrypto::MultType::ExAcc21: - case osuCrypto::MultType::ExAcc40: - - EAConfigure(numOTs, secParam, - mMultType, - mRequestedNumOTs, - mNumPartitions, - mSizePer, - mN2, - mN, - mEAEncoder); - break; - case osuCrypto::MultType::ExConv7x24: - case osuCrypto::MultType::ExConv21x24: - - ExConvConfigure(numOTs, 128, mMultType, mRequestedNumOTs, mNumPartitions, mSizePer, mN2, mN, mExConvEncoder); - break; - default: - throw RTE_LOC; - break; - } - - mGapOts.resize(gap); - mGen.configure(mSizePer, mNumPartitions); - - mState = State::Configured; - } - - //sigma = 0 Receiver - // - // u_i is the choice bit - // v_i = w_i + u_i * x - // - // ------------------------ - - // u' = 0000001000000000001000000000100000...00000, u_i = 1 iff i \in S - // - // v' = r + (x . u') = DPF(k0) - // = r + (000000x00000000000x000000000x00000...00000) - // - // u = u' * H bit-vector * H. Mapping n'->n bits - // v = v' * H block-vector * H. Mapping n'->n block - // - //sigma = 1 Sender - // - // x is the delta - // w_i is the zero message - // - // m_i0 = w_i - // m_i1 = w_i + x - // - // ------------------------ - // x - // r = DPF(k1) - // - // w = r * H - - - task<> SilentVoleSender::checkRT(Socket& chl, block delta) const - { - MC_BEGIN(task<>,this, &chl, delta); - MC_AWAIT(chl.send(delta)); - MC_AWAIT(chl.send(mB)); - MC_AWAIT(chl.send(mNoiseDeltaShares)); - MC_END(); - } - - void SilentVoleSender::clear() - { - mB = {}; - mGen.clear(); - } - - task<> SilentVoleSender::silentSend( - block delta, - span b, - PRNG& prng, - Socket& chl) - { - MC_BEGIN(task<>,this, delta, b, &prng, &chl); - - MC_AWAIT(silentSendInplace(delta, b.size(), prng, chl)); - - std::memcpy(b.data(), mB.data(), b.size() * sizeof(block)); - clear(); - - setTimePoint("SilentVoleSender.expand.ldpc.msgCpy"); - MC_END(); - } - - task<> SilentVoleSender::silentSendInplace( - block delta, - u64 n, - PRNG& prng, - Socket& chl) - { - MC_BEGIN(task<>,this, delta, n, &prng, &chl, - gapVals = std::vector{}, - deltaShare = block{}, - X = block{}, - hash = std::array{}, - noiseShares = span{}, - mbb = span{} - ); - setTimePoint("SilentVoleSender.ot.enter"); - - - if (isConfigured() == false) - { - // first generate 128 normal base OTs - configure(n, SilentBaseType::BaseExtend); - } - - if (mRequestedNumOTs != n) - throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); - - if (mGen.hasBaseOts() == false) - { - // recvs data - MC_AWAIT(genSilentBaseOts(prng, chl, delta)); - } - - //mDelta = delta; - - setTimePoint("SilentVoleSender.start"); - //gTimer.setTimePoint("SilentVoleSender.iknp.base2"); - - if (mMalType == SilentSecType::Malicious) - { - deltaShare = mNoiseDeltaShares.back(); - mNoiseDeltaShares.pop_back(); - } - - // allocate B - mB.resize(0); - mB.resize(mN2); - - // derandomize the random OTs for the gap - // to have the desired correlation. - gapVals.resize(mGapOts.size()); - for (u64 i = mNumPartitions * mSizePer, j = 0; i < mN2; ++i, ++j) - { - auto v = mGapOts[j][0] ^ mNoiseDeltaShares[mNumPartitions + j]; - gapVals[j] = AES(mGapOts[j][1]).ecbEncBlock(ZeroBlock) ^ v; - mB[i] = mGapOts[j][0]; - } - - if(gapVals.size()) - MC_AWAIT(chl.send(std::move(gapVals))); - - - if (mTimer) - mGen.setTimer(*mTimer); - - // program the output the PPRF to be secret shares of - // our secret share of delta * noiseVals. The receiver - // can then manually add their shares of this to the - // output of the PPRF at the correct locations. - noiseShares = span(mNoiseDeltaShares.data(), mNumPartitions); - mbb = mB.subspan(0, mNumPartitions * mSizePer); - MC_AWAIT(mGen.expand(chl, noiseShares, prng.get(), mbb, - PprfOutputFormat::Interleaved, true, mNumThreads)); - - setTimePoint("SilentVoleSender.expand.pprf_transpose"); - if (mDebug) - { - MC_AWAIT(checkRT(chl, delta)); - setTimePoint("SilentVoleSender.expand.checkRT"); - } - - - if (mMalType == SilentSecType::Malicious) - { - MC_AWAIT(chl.recv(X)); - hash = ferretMalCheck(X, deltaShare); - MC_AWAIT(chl.send(std::move(hash))); - } - - switch (mMultType) - { - case osuCrypto::MultType::QuasiCyclic: - -#ifdef ENABLE_BITPOLYMUL - - if (mTimer) - mQuasiCyclicEncoder.setTimer(getTimer()); - - mQuasiCyclicEncoder.dualEncode(mB.subspan(0, mQuasiCyclicEncoder.size())); -#else - throw std::runtime_error("ENABLE_BITPOLYMUL not defined."); -#endif - setTimePoint("SilentVoleSender.expand.QuasiCyclic"); - break; -#ifdef ENABLE_INSECURE_SILVER - case osuCrypto::MultType::slv5: - case osuCrypto::MultType::slv11: - - if (mTimer) - mEncoder.setTimer(getTimer()); - - mEncoder.dualEncode(mB); - setTimePoint("SilentVoleSender.expand.Silver"); - break; -#endif - case osuCrypto::MultType::ExAcc7: - case osuCrypto::MultType::ExAcc11: - case osuCrypto::MultType::ExAcc21: - case osuCrypto::MultType::ExAcc40: - { - if (mTimer) - mEAEncoder.setTimer(getTimer()); - AlignedUnVector B2(mEAEncoder.mMessageSize); - mEAEncoder.dualEncode(mB.subspan(0,mEAEncoder.mCodeSize), B2); - std::swap(mB, B2); - - setTimePoint("SilentVoleSender.expand.Silver"); - break; - } - case osuCrypto::MultType::ExConv7x24: - case osuCrypto::MultType::ExConv21x24: - if (mTimer) - mExConvEncoder.setTimer(getTimer()); - mExConvEncoder.dualEncode(mB.subspan(0, mExConvEncoder.mCodeSize)); - break; - default: - throw RTE_LOC; - break; - } - - - mB.resize(mRequestedNumOTs); - - mState = State::Default; - mNoiseDeltaShares.clear(); - - MC_END(); - } - - std::array SilentVoleSender::ferretMalCheck(block X, block deltaShare) - { - - auto xx = X; - block sum0 = ZeroBlock; - block sum1 = ZeroBlock; - for (u64 i = 0; i < (u64)mB.size(); ++i) - { - block low, high; - xx.gf128Mul(mB[i], low, high); - sum0 = sum0 ^ low; - sum1 = sum1 ^ high; - - xx = xx.gf128Mul(X); - } - - block mySum = sum0.gf128Reduce(sum1); - - std::array myHash; - RandomOracle ro(32); - ro.Update(mySum ^ deltaShare); - ro.Final(myHash); - - return myHash; - //chl.send(myHash); - } -} - -#endif \ No newline at end of file diff --git a/libOTe/Vole/Silent/SilentVoleSender.h b/libOTe/Vole/Silent/SilentVoleSender.h index 99610316..81fcce8f 100644 --- a/libOTe/Vole/Silent/SilentVoleSender.h +++ b/libOTe/Vole/Silent/SilentVoleSender.h @@ -15,24 +15,32 @@ #include #include #include -#include +#include "libOTe/Tools/Pprf/RegularPprf.h" #include #include #include -#include -#include -#include #include -//#define NO_HASH +#include +#include +#include +#include +#include namespace osuCrypto { - + template< + typename F, + typename G = F, + typename Ctx = DefaultCoeffCtx + > class SilentVoleSender : public TimerAdapter { public: static constexpr u64 mScaler = 2; + static constexpr bool MaliciousSupported = + std::is_same_v && std::is_same_v; + enum class State { Default, @@ -40,102 +48,253 @@ namespace osuCrypto HasBase }; + using VecF = typename Ctx::template Vec; + using VecG = typename Ctx::template Vec; State mState = State::Default; - SilentMultiPprfSender mGen; + // the context used to perform F, G operations + Ctx mCtx; + + // the pprf used to generate the noise vector. + RegularPprfSender mGen; - u64 mRequestedNumOTs = 0; - u64 mN2 = 0; - u64 mN = 0; + // the number of correlations requested. + u64 mRequestSize = 0; + + // the length of the noisy vector. + u64 mNoiseVecSize = 0; + + // the weight of the nosy vector u64 mNumPartitions = 0; + + // the size of each regular, weight 1, subvector + // of the noisy vector. mNoiseVecSize = mNumPartions * mSizePer u64 mSizePer = 0; - u64 mNumThreads = 1; - std::vector> mGapOts; - SilentBaseType mBaseType; - //block mDelta; - std::vector mNoiseDeltaShares; - SilentSecType mMalType = SilentSecType::SemiHonest; + // the lpn security parameters + u64 mSecParam = 0; -#ifdef ENABLE_SOFTSPOKEN_OT - SoftSpokenMalOtSender mOtExtSender; - SoftSpokenMalOtReceiver mOtExtRecver; -#endif + // the type of base OT OT that should be performed. + // Base requires more work but less communication. + SilentBaseType mBaseType = SilentBaseType::BaseExtend; - MultType mMultType = DefaultMultType; -#ifdef ENABLE_INSECURE_SILVER - SilverEncoder mEncoder; -#endif - ExConvCode mExConvEncoder; - EACode mEAEncoder; + // the base Vole correlation. To generate the silent vole, + // we must first create a small vole + // mBaseA + mBaseB = mBaseC * mDelta. + // These will be used to initialize the non-zeros of the noisy + // vector. mBaseB is the b in this corrlations. + VecF mBaseB; -#ifdef ENABLE_BITPOLYMUL - QuasiCyclicCode mQuasiCyclicEncoder; -#endif + // the full sized noisy vector. This will initalially be + // sparse with the corrlations + // mA = mB + mC * mDelta + // before it is compressed. + VecF mB; - //span mB; - //u64 mBackingSize = 0; - //std::unique_ptr mBacking; - AlignedUnVector mB; + // determines if the malicious checks are performed. + SilentSecType mMalType = SilentSecType::SemiHonest; - ///////////////////////////////////////////////////// - // The standard OT extension interface - ///////////////////////////////////////////////////// + // A flag to specify the linear code to use + MultType mMultType = DefaultMultType; - // the number of IKNP base OTs that should be set. - u64 baseOtCount() const; - // returns true if the IKNP base OTs are currently set. - bool hasBaseOts() const; + block mDeltaShare; - // sets the IKNP base OTs that are then used to extend - void setBaseOts( - span baseRecvOts, - const BitVector& choices); +#ifdef ENABLE_SOFTSPOKEN_OT + macoro::optional mOtExtSender; + macoro::optional mOtExtRecver; +#endif - // use the default base OT class to generate the - // IKNP base OTs that are required. - task<> genBaseOts(PRNG& prng, Socket& chl) + bool hasSilentBaseOts()const { - return mOtExtSender.genBaseOts(prng, chl); + return mGen.hasBaseOts(); } - ///////////////////////////////////////////////////// - // The native silent OT extension interface - ///////////////////////////////////////////////////// - u64 baseVoleCount() const { - return mNumPartitions + mGapOts.size() + 1 * (mMalType == SilentSecType::Malicious); + u64 baseVoleCount() const + { + return mNumPartitions + 1 * (mMalType == SilentSecType::Malicious); } // Generate the silent base OTs. If the Iknp // base OTs are set then we do an IKNP extend, // otherwise we perform a base OT protocol to // generate the needed OTs. - task<> genSilentBaseOts(PRNG& prng, Socket& chl, cp::optional delta = {}); + task<> genSilentBaseOts(PRNG& prng, Socket& chl, F delta) + { +#ifdef LIBOTE_HAS_BASE_OT + +#if defined ENABLE_MRR_TWIST && defined ENABLE_SSE + using BaseOT = McRosRoyTwist; +#elif defined ENABLE_MR + using BaseOT = MasnyRindal; +#elif defined ENABLE_MRR + using BaseOT = McRosRoy; +#elif defined ENABLE_NP_KYBER + using BaseOT = MasnyRindalKyber; +#else + using BaseOT = DefaultBaseOT; +#endif + + MC_BEGIN(task<>, this, delta, &prng, &chl, + msg = AlignedUnVector>(silentBaseOtCount()), + baseOt = BaseOT{}, + prng2 = std::move(PRNG{}), + xx = BitVector{}, + chl2 = Socket{}, + nv = NoisyVoleSender{}, + b = VecF{} + ); + setTimePoint("SilentVoleSender.genSilent.begin"); + + if (isConfigured() == false) + throw std::runtime_error("configure must be called first"); + + xx = mCtx.template binaryDecomposition(delta); + + // compute the correlation for the noisy coordinates. + b.resize(baseVoleCount()); + + + if (mBaseType == SilentBaseType::BaseExtend) + { +#ifdef ENABLE_SOFTSPOKEN_OT + + if (!mOtExtSender) + mOtExtSender = SoftSpokenMalOtSender{}; + if (!mOtExtRecver) + mOtExtRecver = SoftSpokenMalOtReceiver{}; + + if (mOtExtRecver->hasBaseOts() == false) + { + msg.resize(msg.size() + mOtExtRecver->baseOtCount()); + MC_AWAIT(mOtExtSender->send(msg, prng, chl)); + + mOtExtRecver->setBaseOts( + span>(msg).subspan( + msg.size() - mOtExtRecver->baseOtCount(), + mOtExtRecver->baseOtCount())); + msg.resize(msg.size() - mOtExtRecver->baseOtCount()); + + MC_AWAIT(nv.send(delta, b, prng, *mOtExtRecver, chl, mCtx)); + } + else + { + chl2 = chl.fork(); + prng2.SetSeed(prng.get()); + + MC_AWAIT( + macoro::when_all_ready( + nv.send(delta, b, prng2, *mOtExtRecver, chl2, mCtx), + mOtExtSender->send(msg, prng, chl))); + } +#else + throw RTE_LOC; +#endif + } + else + { + chl2 = chl.fork(); + prng2.SetSeed(prng.get()); + MC_AWAIT( + macoro::when_all_ready( + nv.send(delta, b, prng2, baseOt, chl2, mCtx), + baseOt.send(msg, prng, chl))); + } + + + setSilentBaseOts(msg, b); + setTimePoint("SilentVoleSender.genSilent.done"); + MC_END(); +#else + throw std::runtime_error("LIBOTE_HAS_BASE_OT = false, must enable relic, sodium or simplest ot asm." LOCATION); +#endif + } // configure the silent OT extension. This sets // the parameters and figures out how many base OT // will be needed. These can then be ganerated for // a different OT extension or using a base OT protocol. void configure( - u64 n, - SilentBaseType baseType = SilentBaseType::BaseExtend, - u64 secParam = 128); + u64 requestSize, + SilentBaseType type = SilentBaseType::BaseExtend, + u64 secParam = 128, + Ctx ctx = {}) + { + mCtx = std::move(ctx); + mSecParam = secParam; + mRequestSize = requestSize; + mState = State::Configured; + mBaseType = type; + double minDist = 0; + + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + case osuCrypto::MultType::ExConv21x24: + { + u64 _1, _2; + ExConvConfigure(mScaler, mMultType, _1, _2, minDist); + break; + } + case MultType::QuasiCyclic: + QuasiCyclicConfigure(mScaler, minDist); + break; + default: + throw RTE_LOC; + break; + } + + mNumPartitions = getRegNoiseWeight(minDist, secParam); + mSizePer = std::max(4, roundUpTo(divCeil(mRequestSize * mScaler, mNumPartitions), 2)); + mNoiseVecSize = mSizePer * mNumPartitions; + + mGen.configure(mSizePer, mNumPartitions); + } // return true if this instance has been configured. bool isConfigured() const { return mState != State::Default; } // Returns how many base OTs the silent OT extension // protocol will needs. - u64 silentBaseOtCount() const; + u64 silentBaseOtCount() const + { + if (isConfigured() == false) + throw std::runtime_error("configure must be called first"); + + return mGen.baseOtCount(); + } // Set the externally generated base OTs. This choice // bits must be the one return by sampleBaseChoiceBits(...). void setSilentBaseOts( span> sendBaseOts, - span sendBaseVole); + const VecF& b) + { + if ((u64)sendBaseOts.size() != silentBaseOtCount()) + throw RTE_LOC; + + if (b.size() != baseVoleCount()) + throw RTE_LOC; + + mGen.setBase(sendBaseOts); + + // we store the negative of b. This is because + // we need the correlation + // + // mBaseA + mBaseB = mBaseC * delta + // + // for the pprf to expand correctly but the + // input correlation is a vole: + // + // mBaseA = b + mBaseC * delta + // + mCtx.resize(mBaseB, b.size()); + mCtx.zero(mBaseB.begin(), mBaseB.end()); + for (u64 i = 0; i < mBaseB.size(); ++i) + mCtx.minus(mBaseB[i], mBaseB[i], b[i]); + } // The native OT extension interface of silent // OT. The receiver does not get to specify @@ -143,10 +302,22 @@ namespace osuCrypto // the protocol picks them at random. Use the // send(...) interface for the normal behavior. task<> silentSend( - block delta, - span b, + F delta, + VecF& b, PRNG& prng, - Socket& chls); + Socket& chl) + { + MC_BEGIN(task<>, this, delta, &b, &prng, &chl); + + MC_AWAIT(silentSendInplace(delta, b.size(), prng, chl)); + + mCtx.copy(mB.begin(), mB.begin() + b.size(), b.begin()); + //std::memcpy(b.data(), mB.data(), b.size() * mCtx.bytesF); + clear(); + + setTimePoint("SilentVoleSender.expand.ldpc.msgCpy"); + MC_END(); + } // The native OT extension interface of silent // OT. The receiver does not get to specify @@ -154,18 +325,170 @@ namespace osuCrypto // the protocol picks them at random. Use the // send(...) interface for the normal behavior. task<> silentSendInplace( - block delta, + F delta, u64 n, PRNG& prng, - Socket& chls); + Socket& chl) + { + MC_BEGIN(task<>, this, delta, n, &prng, &chl, + deltaShare = block{}, + X = block{}, + hash = std::array{}, + baseB = VecF{} + ); + setTimePoint("SilentVoleSender.ot.enter"); + + + if (isConfigured() == false) + { + // first generate 128 normal base OTs + configure(n, SilentBaseType::BaseExtend); + } + + if (mRequestSize != n) + throw std::invalid_argument("n does not match the requested number of OTs via configure(...). " LOCATION); + + if (mGen.hasBaseOts() == false) + { + // recvs data + MC_AWAIT(genSilentBaseOts(prng, chl, delta)); + } + + setTimePoint("SilentVoleSender.start"); + //gTimer.setTimePoint("SilentVoleSender.iknp.base2"); + + // allocate B + mCtx.resize(mB, 0); + mCtx.resize(mB, mNoiseVecSize); + + if (mTimer) + mGen.setTimer(*mTimer); + + // extract just the first mNumPartitions value of mBaseB. + // the last is for the malicious check (if present). + mCtx.resize(baseB, mNumPartitions); + mCtx.copy(mBaseB.begin(), mBaseB.begin() + mNumPartitions, baseB.begin()); + + // program the output the PPRF to be secret shares of + // our secret share of delta * noiseVals. The receiver + // can then manually add their shares of this to the + // output of the PPRF at the correct locations. + MC_AWAIT(mGen.expand(chl, baseB, prng.get(), mB, + PprfOutputFormat::Interleaved, true, 1)); + setTimePoint("SilentVoleSender.expand.pprf"); + + if (mDebug) + { + MC_AWAIT(checkRT(chl, delta)); + setTimePoint("SilentVoleSender.expand.checkRT"); + } + + if (mMalType == SilentSecType::Malicious) + { + MC_AWAIT(chl.recv(X)); + + if constexpr (MaliciousSupported) + hash = ferretMalCheck(X); + else + throw std::runtime_error("malicious is currently only supported for GF128 block. " LOCATION); + + MC_AWAIT(chl.send(std::move(hash))); + } + + switch (mMultType) + { + case osuCrypto::MultType::ExConv7x24: + case osuCrypto::MultType::ExConv21x24: + { + ExConvCode encoder; + u64 expanderWeight, accumulatorWeight; + double _1; + ExConvConfigure(mScaler, mMultType, expanderWeight, accumulatorWeight, _1); + if (mScaler * mRequestSize > mNoiseVecSize) + throw RTE_LOC; + encoder.config(mRequestSize, mScaler * mRequestSize, expanderWeight, accumulatorWeight); + if (mTimer) + encoder.setTimer(getTimer()); + encoder.dualEncode(mB.begin(), mCtx); + break; + } + case MultType::QuasiCyclic: + { +#ifdef ENABLE_BITPOLYMUL + if constexpr ( + std::is_same_v && + std::is_same_v && + std::is_same_v) + { + QuasiCyclicCode encoder; + encoder.init2(mRequestSize, mNoiseVecSize); + encoder.dualEncode(mB); + } + else + throw std::runtime_error("QuasiCyclic is only supported for GF128, i.e. block. " LOCATION); +#else + throw std::runtime_error("QuasiCyclic requires ENABLE_BITPOLYMUL = true. " LOCATION); +#endif + + break; + } + default: + throw std::runtime_error("Code is not supported. " LOCATION); + break; + } + + mCtx.resize(mB, mRequestSize); + + + mState = State::Default; + mBaseB.clear(); + + MC_END(); + } bool mDebug = false; - task<> checkRT(Socket& chl, block delta) const; + task<> checkRT(Socket& chl, F delta) const + { + MC_BEGIN(task<>, this, &chl, delta); + MC_AWAIT(chl.send(delta)); + MC_AWAIT(chl.send(mB)); + MC_AWAIT(chl.send(mBaseB)); + MC_END(); + } + + std::array ferretMalCheck(block X) + { + + auto xx = X; + block sum0 = ZeroBlock; + block sum1 = ZeroBlock; + for (u64 i = 0; i < (u64)mB.size(); ++i) + { + block low, high; + xx.gf128Mul(mB[i], low, high); + sum0 = sum0 ^ low; + sum1 = sum1 ^ high; + + xx = xx.gf128Mul(X); + } + + block mySum = sum0.gf128Reduce(sum1); - std::array ferretMalCheck(block X, block deltaShare); + std::array myHash; + RandomOracle ro(32); + ro.Update(mySum ^ mBaseB.back()); + ro.Final(myHash); - void clear(); + return myHash; + //chl.send(myHash); + } + + void clear() + { + mB = {}; + mGen.clear(); + } }; } diff --git a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp index f8ad9f5a..ea4d17ea 100644 --- a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp +++ b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.cpp @@ -315,7 +315,7 @@ namespace osuCrypto //mNumVolesPadded = (computeNumVolesPadded(fieldBits_, numVoles_)); mGenerateFn = (selectGenerateImpl(mFieldBits)); if (!mPprf) - mPprf.reset(new SilentMultiPprfSender); + mPprf.reset(new PprfSender); } void SmallFieldVoleReceiver::init(u64 fieldBits_, u64 numVoles_, bool malicious) @@ -323,7 +323,7 @@ namespace osuCrypto SmallFieldVoleBase::init(fieldBits_, numVoles_, malicious); mGenerateFn = (selectGenerateImpl(mFieldBits)); if (!mPprf) - mPprf.reset(new SilentMultiPprfReceiver); + mPprf.reset(new PprfReceiver); } void SmallFieldVoleReceiver::setDelta(BitVector delta_) @@ -347,26 +347,9 @@ namespace osuCrypto std::copy(seeds_.begin(), seeds_.end(), mSeeds.data()); } - // The choice bits (as expected by SilentMultiPprfReceiver) store the locations of the complements - // of the active paths (because they tell which messages were transferred, not which ones weren't), - // in big endian. We want delta, which is the locations of the active paths, in little endian. - static BitVector choicesToDelta(const BitVector& choices, u64 fieldBits, u64 numVoles) - { - if ((u64)choices.size() != numVoles * fieldBits) - throw RTE_LOC; - - BitVector delta(choices.size()); - for (u64 i = 0; i < numVoles; ++i) - for (u64 j = 0; j < fieldBits; ++j) - delta[i * fieldBits + j] = 1 ^ choices[(i + 1) * fieldBits - j - 1]; - return delta; - } - - void SmallFieldVoleReceiver::setBaseOts(span baseMessages, const BitVector& choices) { - setDelta(choicesToDelta(choices, mFieldBits, mNumVoles)); - + setDelta(choices); if (!mPprf) throw RTE_LOC; @@ -379,7 +362,7 @@ namespace osuCrypto throw RTE_LOC; mPprf->setBase(baseMessages); - mPprf->setChoiceBits(PprfOutputFormat::ByTreeIndex, choices); + mPprf->setChoiceBits(choices); } void SmallFieldVoleSender::setBaseOts(span> msgs) @@ -404,16 +387,19 @@ namespace osuCrypto MC_BEGIN(task<>, this, &chl, &prng, numThreads, corrections = std::vector>{}, hashes = std::vector>{}, - seedView = MatrixView{} + seedView = MatrixView{}, + _ = AlignedUnVector{} ); assert(mSeeds.size() == 0 && mNumVoles && mNumVoles <= mNumVolesPadded); mSeeds.resize(mNumVolesPadded * fieldSize()); std::fill(mSeeds.begin(), mSeeds.end(), block(0, 0)); + mSeeds.resize(mNumVoles * fieldSize()); + MC_AWAIT(mPprf->expand(chl, _, prng.get(), mSeeds, PprfOutputFormat::ByTreeIndex, false, 1)); + mSeeds.resize(mNumVolesPadded * fieldSize()); seedView = MatrixView(mSeeds.data(), mNumVoles, fieldSize()); - MC_AWAIT(mPprf->expand(chl, span(), prng.get(), seedView, PprfOutputFormat::ByTreeIndex, false, 1)); // Prove consistency if (mMalicious) @@ -452,7 +438,8 @@ namespace osuCrypto task<> SmallFieldVoleReceiver::expand(Socket& chl, PRNG& prng, u64 numThreads) { MC_BEGIN(task<>, this, &chl, &prng, numThreads, - seedsFull = Matrix{}, + seeds = AlignedUnVector{}, + seedsFull = MatrixView{}, totals = std::vector>{}, entryHashes = std::vector>{}, corrections = std::vector>{}, @@ -464,9 +451,9 @@ namespace osuCrypto mSeeds.resize(mNumVolesPadded * (fieldSize() - 1)); std::fill(mSeeds.begin(), mSeeds.end(), block(0, 0)); - - seedsFull.resize(mNumVoles, fieldSize()); - MC_AWAIT(mPprf->expand(chl, seedsFull, PprfOutputFormat::ByTreeIndex, false, 1)); + seeds.resize(mNumVoles * fieldSize()); + MC_AWAIT(mPprf->expand(chl, seeds, PprfOutputFormat::ByTreeIndex, false, 1)); + seedsFull = MatrixView(seeds.data(), mNumVoles, fieldSize()); // Check consistency if (mMalicious) diff --git a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.h b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.h index b4e02937..acb12481 100644 --- a/libOTe/Vole/SoftSpokenOT/SmallFieldVole.h +++ b/libOTe/Vole/SoftSpokenOT/SmallFieldVole.h @@ -16,7 +16,7 @@ #include #include "libOTe/TwoChooseOne/TcoOtDefines.h" #include "libOTe/Tools/Coproto.h" -#include "libOTe/Tools/SilentPprf.h" +#include "libOTe/Tools/Pprf/RegularPprf.h" namespace osuCrypto { @@ -121,9 +121,14 @@ namespace osuCrypto class SmallFieldVoleSender : public SmallFieldVoleBase { + + public: + using PprfSender = RegularPprfSender; + + private: SmallFieldVoleSender(const SmallFieldVoleSender& b) : SmallFieldVoleBase(b) - , mPprf(new SilentMultiPprfSender) + , mPprf(new PprfSender) , mGenerateFn(b.mGenerateFn) {} @@ -134,7 +139,7 @@ namespace osuCrypto // wastes a few AES calls, but saving them wouldn't have helped much because you still have to // pay for the AES latency. - std::unique_ptr mPprf; + std::unique_ptr mPprf; SmallFieldVoleSender() = default; SmallFieldVoleSender(SmallFieldVoleSender&&) = default; @@ -203,17 +208,20 @@ namespace osuCrypto class SmallFieldVoleReceiver : public SmallFieldVoleBase { + public: + using PprfReceiver = RegularPprfReceiver; + private: SmallFieldVoleReceiver(const SmallFieldVoleReceiver& b) : SmallFieldVoleBase(b) - , mPprf(new SilentMultiPprfReceiver) + , mPprf(new PprfReceiver) , mDelta(b.mDelta) , mDeltaUnpacked(b.mDeltaUnpacked) , mGenerateFn(b.mGenerateFn) {} public: - std::unique_ptr mPprf; + std::unique_ptr mPprf; BitVector mDelta; AlignedUnVector mDeltaUnpacked; // Each bit of delta becomes a byte, either 0 or 0xff. diff --git a/libOTe/config.h.in b/libOTe/config.h.in index f9c5bb36..40b1bec8 100644 --- a/libOTe/config.h.in +++ b/libOTe/config.h.in @@ -52,6 +52,8 @@ // build the library with silent vole enabled #cmakedefine ENABLE_SILENT_VOLE @ENABLE_SILENT_VOLE@ +#cmakedefine ENABLE_PPRF @ENABLE_PPRF@ + // build the library with silver codes. #cmakedefine ENABLE_INSECURE_SILVER @ENABLE_INSECURE_SILVER@ diff --git a/libOTe_Tests/EACode_Tests.cpp b/libOTe_Tests/EACode_Tests.cpp index 2aa06829..6347d4d3 100644 --- a/libOTe_Tests/EACode_Tests.cpp +++ b/libOTe_Tests/EACode_Tests.cpp @@ -1,6 +1,7 @@ #include "EACode_Tests.h" #include "libOTe/Tools/EACode/EACode.h" #include +#include "libOTe/Tools/CoeffCtx.h" namespace osuCrypto { @@ -12,82 +13,75 @@ namespace osuCrypto auto n = cmd.getOr("n", k * R); auto bw = cmd.getOr("bw", 7); - bool v = cmd.isSet("v"); EACode code; code.config(k, n, bw); - auto A = code.getA(); - auto B = code.getB(); - auto G = B * A; + //auto A = code.getA(); + //auto B = code.getB(); + //auto G = B * A; std::vector m0(k), m1(k), c(n), c0(n), c1(n), a1(n); std::vector c2(n), m2(k); - if (v) - { - std::cout << "B\n" << B << std::endl << std::endl; - std::cout << "A'\n" << code.getAPar() << std::endl << std::endl; - std::cout << "A\n" << A << std::endl << std::endl; - std::cout << "G\n" << G << std::endl; + //if (v) + //{ + // std::cout << "B\n" << B << std::endl << std::endl; + // std::cout << "A'\n" << code.getAPar() << std::endl << std::endl; + // std::cout << "A\n" << A << std::endl << std::endl; + // std::cout << "G\n" << G << std::endl; - } + //} PRNG prng(ZeroBlock); prng.get(c0.data(), c0.size()); auto a0 = c0; - code.accumulate(a0); - A.multAdd(c0, a1); - //A.leftMultAdd(c0, c1); - if (a0 != a1) + code.accumulate(a0, {}); + CoeffCtxGF128 ctx; + block sum = c0[0]; + for (u64 i = 0; i < a0.size(); ++i) { - if (v) - { + if (a0[i] != sum) + throw RTE_LOC; - for (u64 i = 0; i < k; ++i) - std::cout << std::hex << std::setw(2) << std::setfill('0') << (a0[i]) << " "; - std::cout << "\n"; - for (u64 i = 0; i < k; ++i) - std::cout << std::hex << std::setw(2) << std::setfill('0') << (c1[i]) << " "; - std::cout << "\n"; + if (i + 1 < a0.size()) + { + sum ^= c0[i + 1]; + ctx.mulConst(sum, sum); } - - throw RTE_LOC; } - auto cc = c0; - B.multAdd(cc, m0); - code.expand(cc, m1); - - if (m0 != m1) - throw RTE_LOC; - - m0.resize(0); - m0.resize(k); - - G.multAdd(c0, m0); - - - cc = c0; - code.dualEncode(cc, m1); - - if (m0 != m1) - throw RTE_LOC; + u64 i = 0; + detail::ExpanderModd expanderCoeff(code.mSeed, code.mCodeSize); + auto main = k / 8 * 8; + for (; i < main; i += 8) + { + for (u64 j = 0; j < code.mExpanderWeight; ++j) + { + for (u64 p = 0; p < 8; ++p) + { + auto idx = expanderCoeff.get(); + m0[i + p] = m0[i + p] ^ a0[idx]; + } + } + } - cc = c0; - for (u64 i = 0; i < code.mCodeSize; ++i) - c2[i] = c0[i].get(0); + for (; i < k; ++i) + { + for (u64 j = 0; j < code.mExpanderWeight; ++j) + { + auto idx = expanderCoeff.get(); + m0[i] = m0[i] ^ a0[idx]; + } + } - code.dualEncode2(cc, m1, c2, m2); + code.dualEncode(c0, m1, {}); if (m0 != m1) throw RTE_LOC; - for (u64 i = 0; i < code.mMessageSize; ++i) - m2[i] = m0[i].get(0); - } } \ No newline at end of file diff --git a/libOTe_Tests/ExConvCode_Tests.cpp b/libOTe_Tests/ExConvCode_Tests.cpp index 22565f4e..49be53ef 100644 --- a/libOTe_Tests/ExConvCode_Tests.cpp +++ b/libOTe_Tests/ExConvCode_Tests.cpp @@ -2,221 +2,272 @@ #include "libOTe/Tools/ExConvCode/ExConvCode.h" #include "libOTe/Tools/ExConvCode/ExConvCode.h" #include +#include "libOTe/Tools/CoeffCtx.h" namespace osuCrypto { - void ExConvCode_encode_basic_test(const oc::CLP& cmd) + + std::ostream& operator<<(std::ostream& o, const std::array&a) { + o << "{" << a[0] << " " << a[1] << " " << a[2] << "}"; + return o; + } - auto k = cmd.getOr("k", 16ul); - auto R = cmd.getOr("R", 2.0); - auto n = cmd.getOr("n", k * R); - auto bw = cmd.getOr("bw", 7); - auto aw = cmd.getOr("aw", 8); + struct mtxPrint + { + mtxPrint(AlignedUnVector& d, u64 r, u64 c) + :mData(d) + , mRows(r) + , mCols(c) + {} - bool v = cmd.isSet("v"); + AlignedUnVector& mData; + u64 mRows, mCols; + }; - for (auto sys : {/* false,*/ true }) + std::ostream& operator<<(std::ostream& o, const mtxPrint& m) + { + for (u64 i = 0; i < m.mRows; ++i) { + for (u64 j = 0; j < m.mCols; ++j) + { + bool color = (int)m.mData[i * m.mCols + j] && &o == &std::cout; + if (color) + o << Color::Green; + o << (int)m.mData[i * m.mCols + j] << " "; + if (color) + o << Color::Default; + } + o << std::endl; + } + return o; + } - ExConvCode code; - code.config(k, n, bw, aw, sys); + template + void exConvTest(u64 k, u64 n, u64 bw, u64 aw, bool sys) + { - auto A = code.getA(); - auto B = code.getB(); - auto G = B * A; + ExConvCode code; + code.config(k, n, bw, aw, sys); - std::vector m0(k), m1(k), a1(n); + auto accOffset = sys * k; + std::vector x1(n), x2(n), x3(n), x4(n); + PRNG prng(CCBlock); - if (v) - { - std::cout << "B\n" << B << std::endl << std::endl; - std::cout << "A'\n" << code.getAPar() << std::endl << std::endl; - std::cout << "A\n" << A << std::endl << std::endl; - std::cout << "G\n" << G << std::endl; + for (u64 i = 0; i < x1.size(); ++i) + { + x1[i] = x2[i] = x3[i] = prng.get(); + } + CoeffCtx ctx; + std::vector rand(divCeil(aw, 8)); + for (i64 i = 0; i < i64(x1.size() - aw - 1); ++i) + { + prng.get(rand.data(), rand.size()); + code.accOneGen(x1.data() + i, x1.data()+n, rand.data(), ctx); + + if (aw == 16) + code.accOne(x2.data() + i, x2.data()+n, rand.data(), ctx); - } - const auto c0 = [&]() { - std::vector c0(n); - PRNG prng(ZeroBlock); - prng.get(c0.data(), c0.size()); - return c0; - }(); - - auto a0 = c0; - auto aa0 = a0; - std::vector aa1(n); - for (u64 i = 0; i < n; ++i) + ctx.plus(x3[i + 1], x3[i + 1], x3[i]); + //std::cout << "x" << i + 1 << " " << x3[i + 1] << " -> "; + ctx.mulConst(x3[i + 1], x3[i + 1]); + //std::cout << x3[i + 1] << std::endl;; + assert(aw <= 64); + u64 bits = 0; + memcpy(&bits, rand.data(), std::min(rand.size(), 8)); + for (u64 j = 0; j < aw && (i + j + 2) < x3.size(); ++j) { - aa1[i] = aa0[i].get(0); + if (bits & 1) + { + ctx.plus(x3[i + j + 2], x3[i + j + 2], x3[i]); + } + bits >>= 1; } - if (code.mSystematic) + + for (u64 j = i; j < x1.size() && j < i + aw + 2; ++j) { - code.accumulate(span(a0.begin() + k, a0.begin() + n)); - code.accumulate( - span(aa0.begin() + k, aa0.begin() + n), - span(aa1.begin() + k, aa1.begin() + n) - ); + if (aw == 16 && x1[j] != x2[j]) + { + std::cout << j << " " << ctx.str(x1[j]) << " " << ctx.str(x2[j]) << std::endl; + throw RTE_LOC; + } - for (u64 i = 0; i < n; ++i) + if (x1[j] != x3[j]) { - if (aa0[i] != a0[i]) - throw RTE_LOC; - if (aa1[i] != a0[i].get(0)) - throw RTE_LOC; + std::cout << j << " " << ctx.str(x1[j]) << " " << ctx.str(x3[j]) << std::endl; + throw RTE_LOC; } } - else - { - code.accumulate(a0); - } - A.multAdd(c0, a1); - //A.leftMultAdd(c0, a1); - if (a0 != a1) + } + + + x4 = x1; + //std::cout << std::endl; + + code.accumulateFixed(x1.data() + accOffset, ctx); + + if (aw == 16) + { + code.accumulateFixed(x2.data() + accOffset, ctx); + + if (x1 != x2) { - if (v) + for (u64 i = 0; i < x1.size(); ++i) { - - for (u64 i = 0; i < k; ++i) - std::cout << std::hex << std::setw(2) << std::setfill('0') << (a0[i]) << " "; - std::cout << "\n"; - for (u64 i = 0; i < k; ++i) - std::cout << std::hex << std::setw(2) << std::setfill('0') << (a1[i]) << " "; - std::cout << "\n"; + std::cout << i << " " << ctx.str(x1[i]) << " " << ctx.str(x2[i]) << std::endl; } - throw RTE_LOC; } + } + { + PRNG coeffGen(code.mSeed ^ OneBlock); + u8* mtxCoeffIter = (u8*)coeffGen.mBuffer.data(); + auto mtxCoeffEnd = mtxCoeffIter + coeffGen.mBuffer.size() * sizeof(block) - divCeil(aw, 8); - - for (u64 q = 0; q < n; ++q) + auto xi = x3.data() + accOffset; + auto end = x3.data() + n; + while (xi < end) { - std::vector c0(n); - c0[q] = AllOneBlock; - - //auto q = 0; - auto cc = c0; - auto cc1 = c0; - auto mm1 = m1; - - std::vector cc2(cc1.size()), mm2(mm1.size()); - for (u64 i = 0; i < n; ++i) - cc2[i] = cc1[i].get(0); - for (u64 i = 0; i < k; ++i) - mm2[i] = mm1[i].get(0); - //std::vector cc(n); - //cc[q] = AllOneBlock; - std::fill(m0.begin(), m0.end(), ZeroBlock); - B.multAdd(cc, m0); - - - if (code.mSystematic) + if (mtxCoeffIter > mtxCoeffEnd) { - std::copy(cc.begin(), cc.begin() + k, m1.begin()); - code.mExpander.expand( - span(cc.begin() + k, cc.end()), - m1); - //for (u64 i = 0; i < k; ++i) - // m1[i] ^= cc[i]; - std::copy(cc1.begin(), cc1.begin() + k, mm1.begin()); - std::copy(cc2.begin(), cc2.begin() + k, mm2.begin()); - - code.mExpander.expand( - span(cc1.begin() + k, cc1.end()), - span(cc2.begin() + k, cc2.end()), - mm1, mm2); + // generate more mtx coefficients + ExConvCode::refill(coeffGen); + mtxCoeffIter = (u8*)coeffGen.mBuffer.data(); } - else + + // add xi to the next positions + auto xj = xi + 1; + if (xj != end) { - code.mExpander.expand(cc, m1); + ctx.plus(*xj, *xj, *xi); + ctx.mulConst(*xj, *xj); + ++xj; } - if (m0 != m1) + //assert((mtxCoeffEnd - mtxCoeffIter) * 8 >= aw); + u64 bits = 0; + memcpy(&bits, mtxCoeffIter, divCeil(aw,8)); + for (u64 j = 0; j < aw && xj != end; ++j, ++xj) { - - std::cout << "B\n" << B << std::endl << std::endl; - for (u64 i = 0; i < n; ++i) - std::cout << (c0[i].get(0) & 1) << " "; - std::cout << std::endl; - - std::cout << "exp act " << q << "\n"; - for (u64 i = 0; i < k; ++i) + if (bits &1) { - std::cout << (m0[i].get(0) & 1) << " " << (m1[i].get(0) & 1) << std::endl; + ctx.plus(*xj, *xj, *xi); } - throw RTE_LOC; + bits >>= 1; } + ++mtxCoeffIter; - if (code.mSystematic) - { - if (mm1 != m1) - throw RTE_LOC; - - for (u64 i = 0; i < k; ++i) - if (mm2[i] != m1[i].get(0)) - throw RTE_LOC; - } + ++xi; } + } - //for (u64 q = 0; q < n; ++q) + if (x1 != x3) + { + for (u64 i = 0; i < x1.size(); ++i) { - auto q = 0; + std::cout << i << " " << ctx.str(x1[i]) << " " << ctx.str(x3[i]) << std::endl; + } + throw RTE_LOC; + } - //std::fill(c0.begin(), c0.end(), ZeroBlock); - //c0[q] = AllOneBlock; - auto cc = c0; - auto cc1 = c0; - std::vector cc2(cc1.size()); - for (u64 i = 0; i < n; ++i) - cc2[i] = cc1[i].get(0); + detail::ExpanderModd expanderCoeff(code.mExpander.mSeed, code.mExpander.mCodeSize); + std::vector y1(k), y2(k); - std::fill(m0.begin(), m0.end(), ZeroBlock); - G.multAdd(c0, m0); + if (sys) + { + std::copy(x1.data(), x1.data() + k, y1.data()); + y2 = y1; + code.mExpander.expand(x1.data() + accOffset, y1.data()); + //using P = std::pair::const_iterator, typename std::vector::iterator>; + //auto p = P{ x1.cbegin() + accOffset, y1.begin() }; + //code.mExpander.expandMany( + // std::tuple

{ p } + //); + } + else + { + code.mExpander.expand(x1.data() + accOffset, y1.data()); + } - if (code.mSystematic) - { - code.dualEncode(cc); - std::copy(cc.begin(), cc.begin() + k, m1.begin()); - } - else + u64 i = 0; + auto main = k / 8 * 8; + for (; i < main; i += 8) + { + for (u64 j = 0; j < code.mExpander.mExpanderWeight; ++j) + { + for (u64 p = 0; p < 8; ++p) { - code.dualEncode(cc, m1); + auto idx = expanderCoeff.get(); + ctx.plus(y2[i + p], y2[i + p], x1[idx + accOffset]); } + } + } - if (m0 != m1) - { - std::cout << "G\n" << G << std::endl << std::endl; - for (u64 i = 0; i < n; ++i) - std::cout << (c0[i].get(0) & 1) << " "; - std::cout << std::endl; + for (; i < k; ++i) + { + for (u64 j = 0; j < code.mExpander.mExpanderWeight; ++j) + { + auto idx = expanderCoeff.get(); + ctx.plus(y2[i], y2[i], x1[idx + accOffset]); + } + } - std::cout << "exp act " << q << "\n"; - for (u64 i = 0; i < k; ++i) - { - std::cout << (m0[i].get(0) & 1) << " " << (m1[i].get(0) & 1) << std::endl; - } - throw RTE_LOC; - } + if (y1 != y2) + throw RTE_LOC; + code.dualEncode(x4.begin(), {}); - if (code.mSystematic) - { - code.dualEncode2(cc1, cc2); + x4.resize(k); + if (x4 != y1) + throw RTE_LOC; + } - for (u64 i = 0; i < k; ++i) - { - if (cc1[i] != cc[i]) - throw RTE_LOC; - if (cc2[i] != cc[i].get(0)) - throw RTE_LOC; - } - } - } + //block mult2(block x, int imm8) + //{ + // assert(imm8 < 2); + // if (imm8) + // { + // // mult x[1] * 2 + + // } + // else + // { + // // x[0] * 2 + // __m128i carry = _mm_slli_si128(x, 8); + // carry = _mm_srli_epi64(carry, 63); + // x = _mm_slli_epi64(x, 1); + // return _mm_or_si128(x, carry); + + // //return _mm_slli_si128(x, 8); + // } + // //TEMP[i] : = (TEMP1[0] and TEMP2[i]) + // // FOR j : = 1 to i + // // TEMP[i] : = TEMP[i] XOR(TEMP1[j] AND TEMP2[i - j]) + // // ENDFOR + // //dst[i] : = TEMP[i] + //} + + + void ExConvCode_encode_basic_test(const oc::CLP& cmd) + { + auto K = cmd.getManyOr("k", { 32ul, 333 }); + auto R = cmd.getManyOr("R", { 2.0, 3.0 }); + auto Bw = cmd.getManyOr("bw", { 7, 21 }); + auto Aw = cmd.getManyOr("aw", { 16, 24, 29 }); + + //bool v = cmd.isSet("v"); + for (auto k : K) for (auto r : R) for (auto bw : Bw) for (auto aw : Aw) for (auto sys : { false, true }) + { + auto n = k * r; + exConvTest(k, n, bw, aw, sys); + exConvTest(k, n, bw, aw, sys); + exConvTest(k, n, bw, aw, sys); + exConvTest, CoeffCtxArray>(k, n, bw, aw, sys); } - } + } } \ No newline at end of file diff --git a/libOTe_Tests/Pprf_Tests.cpp b/libOTe_Tests/Pprf_Tests.cpp new file mode 100644 index 00000000..68eaf4e7 --- /dev/null +++ b/libOTe_Tests/Pprf_Tests.cpp @@ -0,0 +1,471 @@ +#include "Pprf_Tests.h" + +#include "libOTe/Tools/Pprf/RegularPprf.h" +#include "cryptoTools/Common/TestCollection.h" + +#ifdef ENABLE_PPRF +#include "cryptoTools/Common/Log.h" +#include "Common.h" +#include +using namespace osuCrypto; +using namespace tests_libOTe; + + +template +void Tools_Pprf_expandOne_test_impl(u64 domain, bool program) +{ + + u64 depth = log2ceil(domain); + auto pntCount = 8ull; + PRNG prng(CCBlock); + + RegularPprfSender sender; + RegularPprfReceiver recver; + + sender.configure(domain, pntCount); + recver.configure(domain, pntCount); + + F value = prng.get(); + sender.setValue({ &value, 1 }); + + auto numOTs = sender.baseOtCount(); + std::vector> sendOTs(numOTs); + std::vector recvOTs(numOTs); + BitVector recvBits = recver.sampleChoiceBits(prng); + + + prng.get(sendOTs.data(), sendOTs.size()); + for (u64 i = 0; i < numOTs; ++i) + { + recvOTs[i] = sendOTs[i][recvBits[i]]; + } + sender.setBase(sendOTs); + recver.setBase(recvOTs); + + + block seed = CCBlock; + + auto sLevels = std::vector>>{}; + auto rLevels = std::vector>>{}; + auto sBuff = std::vector{}; + auto sSums = span>{}; + auto sLast = span{}; + + pprf::TreeAllocator mTreeAlloc; + sLevels.resize(depth); + rLevels.resize(depth); + + + mTreeAlloc.reserve(2, (1ull << depth) + 2); + + + pprf::allocateExpandTree(mTreeAlloc, sLevels); + pprf::allocateExpandTree(mTreeAlloc, rLevels); + using VecF = typename Ctx::template Vec; + VecF sLeafLevel(8ull * domain); + VecF rLeafLevel(8ull * domain); + u64 leafOffset = 0; + + Ctx ctx; + pprf::allocateExpandBuffer(depth - 1, pntCount, program, sBuff, sSums, sLast, ctx); + + std::vector points(recver.mPntCount); + recver.getPoints(points, PprfOutputFormat::ByLeafIndex); + + sender.expandOne(seed, 0, program, sLevels, sLeafLevel, leafOffset, sSums, sLast, ctx); + recver.expandOne(0, program, rLevels, rLeafLevel, leafOffset, sSums, sLast, points, ctx); + + bool failed = false; + for (u64 i = 0; i < pntCount; ++i) + { + // the index of the leaf node that is active. + auto leafIdx = points[i]; + //std::cout << "active leaf idx = " << leafIdx << std::endl; + for (u64 d = 1; d < depth; ++d) + { + //u64 width = std::min(domain, 1ull << d); + auto width = divCeil(domain, 1ull << (depth - d)); + + // The index of the active child node. + auto activeChildIdx = leafIdx >> (depth - d); + + // The index of the active child node sibling. + + for (u64 j = 0; j < width; ++j) + { + //std::cout + // << " " << sLevels[d][j][i].get()[0] + // << " " << rLevels[d][j][i].get()[0] + // ; + + if (j == activeChildIdx) + { + //std::cout << "*"; + continue; + } + + + if (sLevels[d][j][i] != rLevels[d][j][i]) + { + //std::cout << " < "; + throw RTE_LOC; + failed = true; + } + + //std::cout << ", "; + } + //std::cout << std::endl; + } + + MatrixView sLeaves(sLeafLevel.data(), sLeafLevel.size() / 8, 8); + MatrixView rLeaves(rLeafLevel.data(), rLeafLevel.size() / 8, 8); + + for (u64 j = 0; j < sLeaves.rows(); ++j) + { + if (j == leafIdx) + { + F exp; + ctx.plus(exp, sLeaves(j, i), value); + if (program && exp != rLeaves(j, i)) + { + std::cout << i << " exp " << ctx.str(exp) << " " << ctx.str(rLeaves(j, i)) << std::endl; + throw RTE_LOC; + } + } + else + { + if (sLeaves(j, i) != rLeaves(j, i)) + { + std::cout << "j " << j << " i " << i << " sender " << ctx.str(sLeaves(j, i)) << " recver " << ctx.str(rLeaves(j, i)) << std::endl; + throw RTE_LOC; + } + } + } + } + + if (failed) + throw RTE_LOC; +} + +void Tools_Pprf_expandOne_test(const oc::CLP& cmd) +{ +#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) + + + for (u64 domain : { 2, 128, 4522}) for (bool program : {true, false}) + { + + Tools_Pprf_expandOne_test_impl(domain, program); + Tools_Pprf_expandOne_test_impl(domain, program); + Tools_Pprf_expandOne_test_impl, u32, CoeffCtxArray>(domain, program); + + } + +#else + throw UnitTestSkipped("ENABLE_SILENTOT not defined."); +#endif +} + + +template +void Tools_Pprf_test_impl( + u64 domain, + u64 numPoints, + bool program, + PprfOutputFormat format, + bool verbose) +{ + + auto threads = 1; + PRNG prng(CCBlock); + using Vec = typename Ctx::template Vec; + + auto sockets = cp::LocalAsyncSocket::makePair(); + + RegularPprfSender sender; + RegularPprfReceiver recver; + Vec delta; + Ctx ctx; + auto seed = prng.get(); + ctx.resize(delta, numPoints * program); + for (u64 i = 0; i < delta.size(); ++i) + ctx.fromBlock(delta[i], seed); + + sender.configure(domain, numPoints); + recver.configure(domain, numPoints); + + auto numOTs = sender.baseOtCount(); + std::vector> sendOTs(numOTs); + std::vector recvOTs(numOTs); + BitVector recvBits = recver.sampleChoiceBits(prng); + + prng.get(sendOTs.data(), sendOTs.size()); + for (u64 i = 0; i < numOTs; ++i) + recvOTs[i] = sendOTs[i][recvBits[i]]; + + sender.setBase(sendOTs); + recver.setBase(recvOTs); + + Vec a(numPoints * domain), a2; + Vec b(numPoints * domain), b2; + if (format == PprfOutputFormat::Callback) + { + a2 = std::move(a); + b2 = std::move(b); + a = {}; + b = {}; + sender.mOutputFn = [&](u64 treeIdx, Vec& data){ + auto offset = treeIdx * domain; + std::copy(data.begin(), data.end(), b2.begin() + offset); + }; + recver.mOutputFn = [&](u64 treeIdx, Vec& data) { + auto offset = treeIdx * domain; + std::copy(data.begin(), data.end(), a2.begin() + offset); + }; + } + + std::vector points(numPoints); + recver.getPoints(points, format); + + // a = b + points * delta + auto p0 = sender.expand(sockets[0], delta, prng.get(), b, format, program, threads); + auto p1 = recver.expand(sockets[1], a, format, program, threads); + + + try + { + eval(p0, p1); + } + catch (std::exception& e) + { + sockets[0].close(); + sockets[1].close(); + macoro::sync_wait(macoro::when_all_ready( + sockets[0].flush(), + sockets[1].flush() + )); + throw; + } + + if (format == PprfOutputFormat::Callback) + { + a = std::move(a2); + b = std::move(b2); + } + + switch (format) + { + case osuCrypto::PprfOutputFormat::ByLeafIndex: + case osuCrypto::PprfOutputFormat::ByTreeIndex: + { + + bool failed = false; + for (u64 j = 0; j < numPoints; ++j) + { + for (u64 i = 0; i < domain; ++i) + { + u64 idx = format == osuCrypto::PprfOutputFormat::ByTreeIndex ? + j * domain + i : + i * numPoints + j; + + F exp; + + if (points[j] == i) + { + if (program) + ctx.plus(exp, b[idx], delta[j]); + else + ctx.zero(&exp, &exp + 1); + } + else + exp = b[idx]; + + if (program && exp != a[idx]) + { + failed = true; + + if (verbose) + std::cout << Color::Red; + } + if (verbose) + { + std::cout << "r[" << j << "][" << i << "] " << exp << " " << ctx.str(a[idx]); + if (points[j] == i) + std::cout << " < "; + + std::cout << std::endl << Color::Default; + } + } + if (verbose) + std::cout << "\n"; + } + + if (failed) + throw RTE_LOC; + + break; + } + case osuCrypto::PprfOutputFormat::Interleaved: + case osuCrypto::PprfOutputFormat::Callback: + { + + bool failed = false; + std::vector index(points.size()); + std::iota(index.begin(), index.end(), 0); + std::sort(index.begin(), index.end(), + [&](std::size_t i, std::size_t j) { return points[i] < points[j]; }); + + auto iIter = index.begin(); + auto leafIdx = points[*iIter]; + F deltaVal; + ctx.zero(&deltaVal, &deltaVal + 1); + if(program) + deltaVal = delta[*iIter]; + + ++iIter; + for (u64 j = 0; j < a.size(); ++j) + { + F exp, act; + + // a = b + points * delta + + // act = a - b + // = point * delta + ctx.minus(act, a[j], b[j]); + ctx.zero(&exp, &exp + 1); + bool active = false; + if (j == leafIdx) + { + active = true; + if (program) + ctx.copy(exp, deltaVal); + else + ctx.minus(exp, exp, b[j]); + } + + if (exp != act) + { + failed = true; + if (verbose) + std::cout << Color::Red; + } + + if (verbose) + { + std::cout << j << " exp " << ctx.str(exp) << " " << ctx.str(act) + << " a " << ctx.str(a[j]) << " b " << ctx.str(b[j]); + + if (active) + std::cout << " < " << deltaVal; + + std::cout << std::endl << Color::Default; + } + + if (j == leafIdx) + { + if (iIter != index.end()) + { + leafIdx = points[*iIter]; + if(program) + deltaVal = delta[*iIter]; + ++iIter; + } + } + } + + if (failed) + throw RTE_LOC; + break; + } + default: + break; + } + + +} + +void Tools_Pprf_inter_test(const CLP& cmd) +{ + auto f = PprfOutputFormat::Interleaved; + auto v = cmd.isSet("v"); + for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true, false }) + { + Tools_Pprf_test_impl(d, n, p, f, v); + Tools_Pprf_test_impl(d, n, p, f, v); + } +} + + + +void Tools_Pprf_ByLeafIndex_test(const CLP& cmd) +{ +#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) + + auto f = PprfOutputFormat::ByLeafIndex; + auto v = cmd.isSet("v"); + for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true/*, false */}) + { + Tools_Pprf_test_impl(d, n, p, f, v); + Tools_Pprf_test_impl(d, n, p, f, v); + } +#else + throw UnitTestSkipped("ENABLE_SILENTOT not defined."); +#endif +} + + + +void Tools_Pprf_ByTreeIndex_test(const oc::CLP& cmd) +{ +#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) + + + auto f = PprfOutputFormat::ByTreeIndex; + auto v = cmd.isSet("v"); + for (auto d : { 32,3242 }) for (auto n : { 8, 19}) for (auto p : { true/*, false*/ }) + { + Tools_Pprf_test_impl(d, n, p, f, v); + Tools_Pprf_test_impl(d, n, p, f, v); + } + +#else + throw UnitTestSkipped("ENABLE_SILENTOT not defined."); +#endif +} + + + +void Tools_Pprf_callback_test(const oc::CLP& cmd) +{ +#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) + + auto f = PprfOutputFormat::Callback; + auto v = cmd.isSet("v"); + for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true/*, false */}) + { + Tools_Pprf_test_impl(d, n, p, f, v); + Tools_Pprf_test_impl(d, n, p, f, v); + } +#else + throw UnitTestSkipped("ENABLE_SILENTOT not defined."); +#endif +} +#else + + +namespace { + void throwDisabled() + { + throw oc::UnitTestSkipped( + "ENABLE_PPRF not defined. " + ); + } +} + +void Tools_Pprf_expandOne_test(const oc::CLP& cmd) { throwDisabled(); } +void Tools_Pprf_inter_test(const oc::CLP& cmd) { throwDisabled(); } +void Tools_Pprf_ByLeafIndex_test(const oc::CLP& cmd) { throwDisabled(); } +void Tools_Pprf_ByTreeIndex_test(const oc::CLP& cmd) { throwDisabled(); } +void Tools_Pprf_callback_test(const oc::CLP& cmd) { throwDisabled(); } + +#endif diff --git a/libOTe/Tools/LDPC/LdpcImpulseDist.h b/libOTe_Tests/Pprf_Tests.h similarity index 52% rename from libOTe/Tools/LDPC/LdpcImpulseDist.h rename to libOTe_Tests/Pprf_Tests.h index e3bcfe34..785d6ce5 100644 --- a/libOTe/Tools/LDPC/LdpcImpulseDist.h +++ b/libOTe_Tests/Pprf_Tests.h @@ -1,40 +1,16 @@ #pragma once -// © 2022 Visa. +// © 2020 Peter Rindal. +// © 2022 Visa. // Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -// This code implements features described in [Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes, https://eprint.iacr.org/2021/1150]; the paper is licensed under Creative Commons Attribution 4.0 International Public License (https://creativecommons.org/licenses/by/4.0/legalcode). -#include "libOTe/config.h" -#ifdef ENABLE_LDPC -#include "Mtx.h" -#include "LdpcDecoder.h" +#include -namespace osuCrypto -{ - - - enum class BPAlgo - { - LogBP = 0, - AltLogBP = 1, - MinSum = 2 - }; - - enum class ListDecoder - { - Chase = 0 , - OSD = 1 - }; - - //extern std::vector minCW; - - void LdpcDecode_impulse(const oc::CLP& cmd); - - //u64 impulseDist(LdpcDecoder& D, u64 i, u64 n, u64 k, u64 Ne, u64 maxIter); - //u64 impulseDist(SparseMtx& mH, u64 Ne, u64 w, u64 maxIter, u64 numThreads, bool randImpulse, u64 trials, BPAlgo algo, bool verbose); -} -#endif \ No newline at end of file +void Tools_Pprf_expandOne_test(const oc::CLP& cmd); +void Tools_Pprf_inter_test(const oc::CLP& cmd); +void Tools_Pprf_ByLeafIndex_test(const oc::CLP& cmd); +void Tools_Pprf_ByTreeIndex_test(const oc::CLP& cmd); +void Tools_Pprf_callback_test(const oc::CLP& cmd); diff --git a/libOTe_Tests/SilentOT_Tests.cpp b/libOTe_Tests/SilentOT_Tests.cpp index e1eb82e1..e95bba79 100644 --- a/libOTe_Tests/SilentOT_Tests.cpp +++ b/libOTe_Tests/SilentOT_Tests.cpp @@ -1,6 +1,5 @@ #include "SilentOT_Tests.h" -#include "libOTe/Tools/SilentPprf.h" #include "libOTe/TwoChooseOne/Silent/SilentOtExtSender.h" #include "libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.h" #include @@ -840,491 +839,3 @@ void OtExt_Silent_mal_Test(const oc::CLP& cmd) throw UnitTestSkipped("ENABLE_SILENTOT not defined."); #endif } - - -void Tools_Pprf_expandOne_test(const oc::CLP& cmd) -{ -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - - u64 depth = cmd.getOr("d", 4);; - u64 domain = (1ull << depth) * 0.75; - auto pntCount = 8ull; - PRNG prng(CCBlock); - - auto format = PprfOutputFormat::Interleaved; - SilentMultiPprfSender sender; - SilentMultiPprfReceiver recver; - - sender.configure(domain, pntCount); - recver.configure(domain, pntCount); - - block value = prng.get(); - sender.setValue({ &value, 1 }); - - auto numOTs = sender.baseOtCount(); - std::vector> sendOTs(numOTs); - std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain * pntCount, format, prng); - - - prng.get(sendOTs.data(), sendOTs.size()); - for (u64 i = 0; i < numOTs; ++i) - { - recvOTs[i] = sendOTs[i][recvBits[i]]; - } - sender.setBase(sendOTs); - recver.setBase(recvOTs); - - std::vector points(8); - recver.getPoints(points, PprfOutputFormat::ByLeafIndex); - - block seed = CCBlock; - bool program = true; - - auto sTree = span>{}; - auto sLevels = std::vector>>{}; - auto rTree = span>{}; - auto rLevels = std::vector>>{}; - //auto rBuff = std::vector{}; - auto sBuff = std::vector{}; - auto sSums = span, 2>>{}; - auto sLast = span>{}; - - TreeAllocator mTreeAlloc; - sLevels.resize(depth + 1); - rLevels.resize(depth + 1); - - - mTreeAlloc.reserve(2, (1ull << (depth + 1)) + 2); - allocateExpandTree(depth + 1, mTreeAlloc, sTree, sLevels); - allocateExpandTree(depth + 1, mTreeAlloc, rTree, rLevels); - - - allocateExpandBuffer(depth, program, sBuff, sSums, sLast); - //allocateExpandBuffer(depth, program, rbuff, sSums, sLast); - - recver.mPoints.resize(roundUpTo(recver.mPntCount, 8)); - recver.getPoints(recver.mPoints, PprfOutputFormat::ByLeafIndex); - - sender.expandOne(seed, 0, program, sLevels, sSums, sLast); - recver.expandOne(0, program, rLevels, sSums, sLast); - - bool failed = false; - for (u64 i = 0; i < pntCount; ++i) - { - // the index of the leaf node that is active. - auto leafIdx = points[i]; - //std::cout << leafIdx << std::endl; - for (u64 d = 1; d <= depth; ++d) - { - //u64 width = std::min(domain, 1ull << d); - auto width = divCeil(domain, 1ull << (depth - d)); - - // The index of the active child node. - auto activeChildIdx = leafIdx >> (depth - d); - - // The index of the active child node sibling. - - for (u64 j = 0; j < width; ++j) - { - //std::cout - // << " " << sLevels[d][j][i].get()[0] - // << " " << rLevels[d][j][i].get()[0] - // << ", "; - - if (j == activeChildIdx) - { - //std::cout << "*"; - continue; - } - - if (sLevels[d][j][i] != rLevels[d][j][i]) - { - //std::cout << " < "; - failed = true; - } - } - //std::cout << std::endl; - } - } - - if (failed) - throw RTE_LOC; -#else - throw UnitTestSkipped("ENABLE_SILENTOT not defined."); -#endif - } - -void Tools_Pprf_test(const CLP& cmd) -{ -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - - u64 depth = cmd.getOr("d", 3);; - u64 domain = 1ull << depth; - auto threads = cmd.getOr("t", 3ull); - u64 numPoints = cmd.getOr("s", 8); - - PRNG prng(ZeroBlock); - - //IOService ios; - //Session s0(ios, "localhost:1212", SessionMode::Server); - //Session s1(ios, "localhost:1212", SessionMode::Client); - //auto sockets[0] = s0.addChannel(); - //auto sockets[1] = s1.addChannel(); - - auto sockets = cp::LocalAsyncSocket::makePair(); - - - auto format = PprfOutputFormat::ByLeafIndex; - SilentMultiPprfSender sender; - SilentMultiPprfReceiver recver; - - sender.configure(domain, numPoints); - recver.configure(domain, numPoints); - - auto numOTs = sender.baseOtCount(); - std::vector> sendOTs(numOTs); - std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain, format, prng); - - prng.get(sendOTs.data(), sendOTs.size()); - //sendOTs[cmd.getOr("i",0)] = prng.get(); - - //recvBits[16] = 1; - for (u64 i = 0; i < numOTs; ++i) - { - //recvBits[i] = 0; - recvOTs[i] = sendOTs[i][recvBits[i]]; - } - sender.setBase(sendOTs); - recver.setBase(recvOTs); - - Matrix sOut(domain, numPoints); - Matrix rOut(domain, numPoints); - std::vector points(numPoints); - recver.getPoints(points, format); - - auto p0 = sender.expand(sockets[0], { &CCBlock,1 }, prng.get(), sOut, format, true, threads); - auto p1 = recver.expand(sockets[1], rOut, format, true, threads); - - eval(p0, p1); - - bool failed = false; - - - for (u64 j = 0; j < numPoints; ++j) - { - - for (u64 i = 0; i < domain; ++i) - { - - auto exp = sOut(i, j); - if (points[j] == i) - exp = exp ^ CCBlock; - - if (neq(exp, rOut(i, j))) - { - failed = true; - - if (cmd.isSet("v")) - std::cout << Color::Red; - } - if (cmd.isSet("v")) - std::cout << "r[" << j << "][" << i << "] " << exp << " " << rOut(i, j) << std::endl << Color::Default; - } - } - - if (failed) - throw RTE_LOC; - -#else - throw UnitTestSkipped("ENABLE_SILENTOT not defined."); -#endif -} - -void Tools_Pprf_inter_test(const CLP& cmd) -{ -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - - //u64 depth = 6; - //u64 domain = 13;// (1ull << depth) - 7; - //u64 numPoints = 40; - - u64 domain = cmd.getOr("d", 334); - auto threads = cmd.getOr("t", 3ull); - u64 numPoints = cmd.getOr("s", 5) * 8; - //bool mal = cmd.isSet("mal"); - - PRNG prng(ZeroBlock); - - auto sockets = cp::LocalAsyncSocket::makePair(); - - - auto format = PprfOutputFormat::Interleaved; - SilentMultiPprfSender sender; - SilentMultiPprfReceiver recver; - - sender.configure(domain, numPoints); - recver.configure(domain, numPoints); - - auto numOTs = sender.baseOtCount(); - std::vector> sendOTs(numOTs); - std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain * numPoints, format, prng); - //recvBits.randomize(prng); - - //recvBits[16] = 1; - prng.get(sendOTs.data(), sendOTs.size()); - for (u64 i = 0; i < numOTs; ++i) - { - //recvBits[i] = 0; - recvOTs[i] = sendOTs[i][recvBits[i]]; - } - sender.setBase(sendOTs); - recver.setBase(recvOTs); - - //auto cols = (numPoints * domain + 127) / 128; - Matrix sOut2(numPoints * domain, 1); - Matrix rOut2(numPoints * domain, 1); - std::vector points(numPoints); - recver.getPoints(points, format); - - - auto p0 = sender.expand(sockets[0], { &AllOneBlock,1 }, prng.get(), sOut2, format, true, threads); - auto p1 = recver.expand(sockets[1], rOut2, format, true, threads); - - try - { - - eval(p0, p1); - } - catch (std::exception& e) - { - sockets[0].close(); - sockets[1].close(); - macoro::sync_wait(macoro::when_all_ready( - sockets[0].flush(), - sockets[1].flush() - )); - throw; - } - for (u64 i = 0; i < rOut2.rows(); ++i) - { - sOut2(i) = (sOut2(i) ^ rOut2(i)); - } - - - bool failed = false; - for (u64 i = 0; i < sOut2.rows(); ++i) - { - - auto f = std::find(points.begin(), points.end(), i) != points.end(); - - auto exp = f ? AllOneBlock : ZeroBlock; - - if (neq(sOut2(i), exp)) - { - failed = true; - - if (cmd.getOr("v", 0) > 1) - std::cout << Color::Red; - } - if (cmd.getOr("v", 0) > 1) - std::cout << i << " " << sOut2(i) << " " << exp << std::endl << Color::Default; - } - - if (failed) - throw RTE_LOC; - - -#else - throw UnitTestSkipped("ENABLE_SILENTOT not defined."); -#endif -} - - - -void Tools_Pprf_blockTrans_test(const oc::CLP& cmd) -{ -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - - - u64 depth = cmd.getOr("d", 2);; - u64 domain = 1ull << depth; - auto threads = cmd.getOr("t", 1ull); - u64 numPoints = cmd.getOr("s", 8); - - PRNG prng(ZeroBlock); - - auto sockets = cp::LocalAsyncSocket::makePair(); - - - auto format = PprfOutputFormat::ByTreeIndex; - SilentMultiPprfSender sender; - SilentMultiPprfReceiver recver; - - sender.configure(domain, numPoints); - recver.configure(domain, numPoints); - - auto numOTs = sender.baseOtCount(); - std::vector> sendOTs(numOTs); - std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain, format, prng); - - prng.get(sendOTs.data(), sendOTs.size()); - //sendOTs[cmd.getOr("i",0)] = prng.get(); - - //recvBits[16] = 1; - for (u64 i = 0; i < numOTs; ++i) - { - //recvBits[i] = 0; - recvOTs[i] = sendOTs[i][recvBits[i]]; - } - sender.setBase(sendOTs); - recver.setBase(recvOTs); - - Matrix sOut(numPoints, domain); - Matrix rOut(numPoints, domain); - std::vector points(numPoints); - recver.getPoints(points, format); - - cp::sync_wait(cp::when_all_ready( - sender.expand(sockets[0], span{}, prng.get(), sOut, format, false, threads), - recver.expand(sockets[1], rOut, format, false, threads) - )); - - bool failed = false; - - for (u64 j = 0; j < numPoints; ++j) - { - - for (u64 i = 0; i < domain; ++i) - { - auto ss = sOut(j, i); - auto rr = rOut(j, i); - - if (points[j] == i) - { - if (ss == rr || rr != ZeroBlock) - { - failed = true; - - if (cmd.isSet("v")) - std::cout << Color::Red; - } - - } - else - { - if (ss != rr || rr == ZeroBlock) - { - failed = true; - - if (cmd.isSet("v")) - std::cout << Color::Red; - } - } - if (cmd.isSet("v")) - std::cout << "r[" << j << "][" << i << "] " << ss << " " << rr << std::endl << Color::Default; - } - } - - if (failed) - throw RTE_LOC; - -#else - throw UnitTestSkipped("ENABLE_SILENTOT not defined."); -#endif -} - - - -void Tools_Pprf_callback_test(const oc::CLP& cmd) -{ -#if defined(ENABLE_SILENTOT) || defined(ENABLE_SILENT_VOLE) - - u64 domain = cmd.getOr("d", 512); - auto threads = cmd.getOr("t", 1ull); - u64 numPoints = cmd.getOr("s", 5) * 8; - - PRNG prng(ZeroBlock); - auto sockets = cp::LocalAsyncSocket::makePair(); - - - auto format = PprfOutputFormat::Callback; - SilentMultiPprfSender sender; - SilentMultiPprfReceiver recver; - - sender.configure(domain, numPoints); - recver.configure(domain, numPoints); - - auto numOTs = sender.baseOtCount(); - std::vector> sendOTs(numOTs); - std::vector recvOTs(numOTs); - BitVector recvBits = recver.sampleChoiceBits(domain * numPoints, format, prng); - - prng.get(sendOTs.data(), sendOTs.size()); - for (u64 i = 0; i < numOTs; ++i) - { - recvOTs[i] = sendOTs[i][recvBits[i]]; - } - sender.setBase(sendOTs); - recver.setBase(recvOTs); - - //auto cols = (numPoints * domain + 127) / 128; - Matrix sOut2(numPoints * domain, 1); - Matrix rOut2(numPoints * domain, 1); - std::vector points(numPoints); - recver.getPoints(points, format); - - sender.mOutputFn = [&](u64 treeIdx, span> data) - { - span d = sOut2; - d = d.subspan(treeIdx * data.size()); - d = d.subspan(0, std::min(d.size(), data.size() * 8)); - memcpy(d.data(), data.data(), d.size_bytes()); - }; - recver.mOutputFn = [&](u64 treeIdx, span> data) - { - span d = rOut2; - d = d.subspan(treeIdx * data.size()); - d = d.subspan(0, std::min(d.size(), data.size() * 8)); - memcpy(d.data(), data.data(), d.size_bytes()); - }; - - - auto p0 = sender.expand(sockets[0], { &AllOneBlock,1 }, prng.get(), span{}, format, true, threads); - auto p1 = recver.expand(sockets[1], span{}, format, true, threads); - - eval(p0, p1); - for (u64 i = 0; i < rOut2.rows(); ++i) - { - sOut2(i) = (sOut2(i) ^ rOut2(i)); - } - - bool failed = false; - for (u64 i = 0; i < sOut2.rows(); ++i) - { - - auto f = std::find(points.begin(), points.end(), i) != points.end(); - - auto exp = f ? AllOneBlock : ZeroBlock; - - if (neq(sOut2(i), exp)) - { - failed = true; - - if (cmd.getOr("v", 0) > 1) - std::cout << Color::Red; - } - if (cmd.getOr("v", 0) > 1) - std::cout << i << " " << sOut2(i) << " " << exp << std::endl << Color::Default; - } - - if (failed) - throw RTE_LOC; - -#else - throw UnitTestSkipped("ENABLE_SILENTOT not defined."); -#endif - } diff --git a/libOTe_Tests/SilentOT_Tests.h b/libOTe_Tests/SilentOT_Tests.h index d93e09be..26ff014d 100644 --- a/libOTe_Tests/SilentOT_Tests.h +++ b/libOTe_Tests/SilentOT_Tests.h @@ -9,12 +9,6 @@ #include -void Tools_Pprf_expandOne_test(const oc::CLP& cmd); -void Tools_Pprf_test(const oc::CLP& cmd); -void Tools_Pprf_trans_test(const oc::CLP& cmd); -void Tools_Pprf_inter_test(const oc::CLP& cmd); -void Tools_Pprf_blockTrans_test(const oc::CLP& cmd); -void Tools_Pprf_callback_test(const oc::CLP& cmd); void OtExt_Silent_random_Test(const oc::CLP& cmd); void OtExt_Silent_correlated_Test(const oc::CLP& cmd); diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 1c1e72a0..f2f1bb50 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -11,11 +11,10 @@ #include "libOTe_Tests/SoftSpoken_Tests.h" #include "libOTe_Tests/bitpolymul_Tests.h" #include "libOTe_Tests/Vole_Tests.h" -#include "libOTe/Tools/LDPC/LdpcDecoder.h" -#include "libOTe/Tools/LDPC/LdpcEncoder.h" #include "libOTe_Tests/ExConvCode_Tests.h" #include "libOTe_Tests/EACode_Tests.h" #include "libOTe/Tools/LDPC/Mtx.h" +#include "libOTe_Tests/Pprf_Tests.h" using namespace osuCrypto; namespace tests_libOTe @@ -41,28 +40,17 @@ namespace tests_libOTe tc.add("Mtx_add_test ", tests::Mtx_add_test); tc.add("Mtx_mult_test ", tests::Mtx_mult_test); tc.add("Mtx_invert_test ", tests::Mtx_invert_test); -#ifdef ENABLE_LDPC - tc.add("ldpc_Mtx_block_test ", tests::Mtx_block_test); - tc.add("LdpcDecode_pb_test ", tests::LdpcDecode_pb_test); - tc.add("LdpcEncoder_diagonalSolver_test ", tests::LdpcEncoder_diagonalSolver_test); - tc.add("LdpcEncoder_encode_test ", tests::LdpcEncoder_encode_test); - tc.add("LdpcEncoder_encode_g0_test ", tests::LdpcEncoder_encode_g0_test); - tc.add("LdpcS1Encoder_encode_test ", tests::LdpcS1Encoder_encode_test); - tc.add("LdpcS1Encoder_encode_Trans_test ", tests::LdpcS1Encoder_encode_Trans_test); - tc.add("LdpcComposit_RegRepDiagBand_encode_test ", tests::LdpcComposit_RegRepDiagBand_encode_test); - tc.add("LdpcComposit_RegRepDiagBand_Trans_test ", tests::LdpcComposit_RegRepDiagBand_Trans_test); -#endif - tc.add("EACode_encode_basic_test ", EACode_encode_basic_test); tc.add("ExConvCode_encode_basic_test ", ExConvCode_encode_basic_test); - + tc.add("Tools_Pprf_expandOne_test ", Tools_Pprf_expandOne_test); - tc.add("Tools_Pprf_test ", Tools_Pprf_test); tc.add("Tools_Pprf_inter_test ", Tools_Pprf_inter_test); - tc.add("Tools_Pprf_blockTrans_test ", Tools_Pprf_blockTrans_test); + tc.add("Tools_Pprf_ByLeafIndex_test ", Tools_Pprf_ByLeafIndex_test); + tc.add("Tools_Pprf_ByTreeIndex_test ", Tools_Pprf_ByTreeIndex_test); tc.add("Tools_Pprf_callback_test ", Tools_Pprf_callback_test); + tc.add("Bot_Simplest_Test ", Bot_Simplest_Test); tc.add("Bot_Simplest_asm_Test ", Bot_Simplest_asm_Test); @@ -103,10 +91,9 @@ namespace tests_libOTe tc.add("OtExt_SoftSpokenMalicious21_Test ", OtExt_SoftSpokenMalicious21_Test); tc.add("OtExt_SoftSpokenMalicious21_Split_Test ", OtExt_SoftSpokenMalicious21_Split_Test); tc.add("DotExt_SoftSpokenMaliciousLeaky_Test ", DotExt_SoftSpokenMaliciousLeaky_Test); - + tc.add("Vole_Noisy_test ", Vole_Noisy_test); tc.add("Vole_Silent_QuasiCyclic_test ", Vole_Silent_QuasiCyclic_test); - tc.add("Vole_Silent_Silver_test ", Vole_Silent_Silver_test); tc.add("Vole_Silent_paramSweep_test ", Vole_Silent_paramSweep_test); tc.add("Vole_Silent_baseOT_test ", Vole_Silent_baseOT_test); tc.add("Vole_Silent_mal_test ", Vole_Silent_mal_test); @@ -116,5 +103,6 @@ namespace tests_libOTe tc.add("NcoOt_Oos_Test ", NcoOt_Oos_Test); tc.add("NcoOt_genBaseOts_Test ", NcoOt_genBaseOts_Test); + }); } diff --git a/libOTe_Tests/Vole_Tests.cpp b/libOTe_Tests/Vole_Tests.cpp index b287607d..52ced20e 100644 --- a/libOTe_Tests/Vole_Tests.cpp +++ b/libOTe_Tests/Vole_Tests.cpp @@ -15,82 +15,78 @@ using namespace oc; #include -using namespace tests_libOTe; - - -#if defined(ENABLE_SILENT_VOLE) || defined(ENABLE_SILENTOT) +#include "libOTe/Tools/CoeffCtx.h" -void Vole_Noisy_test(const oc::CLP& cmd) +using namespace tests_libOTe; +#ifdef ENABLE_SILENT_VOLE +template +void Vole_Noisy_test_impl(u64 n) { - Timer timer; - timer.setTimePoint("start"); - u64 n = cmd.getOr("n", 123); - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); + PRNG prng(CCBlock); - block x = prng.get(); - std::vector y(n), z0(n), z1(n); - prng.get(y); + F delta = prng.get(); + std::vector c(n); + std::vector a(n), b(n); + prng.get(c.data(), c.size()); - NoisyVoleReceiver recv; - NoisyVoleSender send; - - recv.setTimer(timer); - send.setTimer(timer); - - //IOService ios; - //auto chl1 = Session(ios, "localhost:1212", SessionMode::Server).addChannel(); - //auto chl0 = Session(ios, "localhost:1212", SessionMode::Client).addChannel(); + NoisyVoleReceiver recv; + NoisyVoleSender send; auto chls = cp::LocalAsyncSocket::makePair(); - timer.setTimePoint("net"); - - BitVector recvChoice((u8*)&x, 128); - std::vector otRecvMsg(128); - std::vector> otSendMsg(128); + Ctx ctx; + BitVector recvChoice = ctx.binaryDecomposition(delta); + std::vector otRecvMsg(recvChoice.size()); + std::vector> otSendMsg(recvChoice.size()); prng.get>(otSendMsg); - for (u64 i = 0; i < 128; ++i) + for (u64 i = 0; i < recvChoice.size(); ++i) otRecvMsg[i] = otSendMsg[i][recvChoice[i]]; - timer.setTimePoint("ot"); - auto p0 = recv.receive(y, z0, prng, otSendMsg, chls[0]); - auto p1 = send.send(x, z1, prng, otRecvMsg, chls[1]); + // compute a,b such that + // + // a = b + c * delta + // + auto p0 = recv.receive(c, a, prng, otSendMsg, chls[0], ctx); + auto p1 = send.send(delta, b, prng, otRecvMsg, chls[1], ctx); eval(p0, p1); for (u64 i = 0; i < n; ++i) { - if (y[i].gf128Mul(x) != (z0[i] ^ z1[i])) + F prod, sum; + + ctx.mul(prod, delta, c[i]); + ctx.minus(sum, a[i], b[i]); + + if (prod != sum) { throw RTE_LOC; } } - timer.setTimePoint("done"); - - //std::cout << timer << std::endl; - } -#else void Vole_Noisy_test(const oc::CLP& cmd) { - throw UnitTestSkipped( - "ENABLE_SILENT_VOLE not defined. " - ); + for (u64 n : {1, 8, 433}) + { + Vole_Noisy_test_impl(n); + Vole_Noisy_test_impl(n); + Vole_Noisy_test_impl(n); + Vole_Noisy_test_impl, u32, CoeffCtxArray>(n); + } } -#endif -#ifdef ENABLE_SILENT_VOLE +namespace +{ -namespace { + template void fakeBase( u64 n, - u64 threads, PRNG& prng, - block delta, - SilentVoleReceiver& recver, - SilentVoleSender& sender) + F delta, + R& recver, + S& sender, + Ctx ctx) { sender.configure(n, SilentBaseType::Base, 128); recver.configure(n, SilentBaseType::Base, 128); @@ -112,281 +108,117 @@ namespace { for (auto i : rng(msg.size())) msg[i] = msg2[i][choices[i]]; - auto y = recver.sampleBaseVoleVals(prng);; - std::vector c(y.size()), b(y.size()); - prng.get(c.data(), c.size()); - for (auto i : rng(y.size())) + // a = b + c * d + // the sender gets b, d + // the recver gets a, c + auto c = recver.sampleBaseVoleVals(prng); + typename Ctx::template Vec a(c.size()), b(c.size()); + + prng.get(b.data(), b.size()); + for (auto i : rng(c.size())) { - b[i] = delta.gf128Mul(y[i]) ^ c[i]; + ctx.mul(a[i], delta, c[i]); + ctx.plus(a[i], b[i], a[i]); } sender.setSilentBaseOts(msg2, b); - - // fake base OTs. - recver.setSilentBaseOts(msg, c); + recver.setSilentBaseOts(msg, a); } } -void Vole_Silent_QuasiCyclic_test(const oc::CLP& cmd) -{ -#if defined(ENABLE_SILENTOT) && defined(ENABLE_BITPOLYMUL) - Timer timer; - timer.setTimePoint("start"); - u64 n = cmd.getOr("n", 102043); - u64 nt = cmd.getOr("nt", std::thread::hardware_concurrency()); - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); - - block x = prng.get(); - std::vector c(n), z0(n), z1(n); - - SilentVoleReceiver recv; - SilentVoleSender send; - - recv.mMultType = MultType::QuasiCyclic; - send.mMultType = MultType::QuasiCyclic; - recv.setTimer(timer); - send.setTimer(timer); +template +void Vole_Silent_test_impl(u64 n, MultType type, bool debug, bool doFakeBase, bool mal) +{ + using VecF = typename Ctx::template Vec; + using VecG = typename Ctx::template Vec; + Ctx ctx; - recv.mDebug = true; - send.mDebug = true; + block seed = CCBlock; + PRNG prng(seed); auto chls = cp::LocalAsyncSocket::makePair(); - timer.setTimePoint("net"); + SilentVoleReceiver recv; + SilentVoleSender send; + recv.mMultType = type; + send.mMultType = type; + recv.mDebug = debug; + send.mDebug = debug; + if (mal) + { + recv.mMalType = SilentSecType::Malicious; + send.mMalType = SilentSecType::Malicious; + } - timer.setTimePoint("ot"); - fakeBase(n, nt, prng, x, recv, send); + VecF a(n), b(n); + VecG c(n); + F d = prng.get(); - // c * x = z + m + if (doFakeBase) + fakeBase(n, prng, d, recv, send, ctx); - auto p0 = recv.silentReceive(c, z0, prng, chls[0]); - auto p1 = send.silentSend(x, z1, prng, chls[1]); + auto p0 = recv.silentReceive(c, a, prng, chls[0]); + auto p1 = send.silentSend(d, b, prng, chls[1]); eval(p0, p1); - timer.setTimePoint("send"); + for (u64 i = 0; i < n; ++i) { - if (c[i].gf128Mul(x) != (z0[i] ^ z1[i])) + // a = b + c * d + F exp; + ctx.mul(exp, d, c[i]); + ctx.plus(exp, exp, b[i]); + + if (a[i] != exp) { - std::cout << "bad " << i << "\n c[i] " << c[i] << " * x " << x << " = " << c[i].gf128Mul(x) << std::endl; - std::cout << " z0[i] " << z0[i] << " ^ z1 " << z1[i] << " = " << (z0[i] ^ z1[i]) << std::endl; throw RTE_LOC; } } - timer.setTimePoint("done"); -#else - throw UnitTestSkipped("ENABLE_BITPOLYMUL not defined." LOCATION); -#endif } + void Vole_Silent_paramSweep_test(const oc::CLP& cmd) { - - Timer timer; - timer.setTimePoint("start"); - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); - - block x = prng.get(); - u64 threads = 0; - - - auto chls = cp::LocalAsyncSocket::makePair(); - timer.setTimePoint("net"); - - timer.setTimePoint("ot"); - - //recv.mDebug = true; - //send.mDebug = true; - - SilentVoleReceiver recv; - SilentVoleSender send; - // c * x = z + m - - //for (u64 n = 5000; n < 10000; ++n) - for (u64 n : {12,/* 123,465,*/1642,/*4356,34254,*/93425}) + auto debug = cmd.isSet("debug"); + for (u64 n : {128, 45364}) { - std::vector c(n), z0(n), z1(n); - - fakeBase(n, threads, prng, x, recv, send); - - recv.setTimer(timer); - send.setTimer(timer); - - auto p0 = recv.silentReceive(c, z0, prng, chls[0]); - auto p1 = send.silentSend(x, z1, prng, chls[1]); - timer.setTimePoint("send"); - - eval(p0, p1); - - for (u64 i = 0; i < n; ++i) - { - if (c[i].gf128Mul(x) != (z0[i] ^ z1[i])) - { - throw RTE_LOC; - } - } - timer.setTimePoint("done"); + Vole_Silent_test_impl(n, DefaultMultType, debug, false, false); + Vole_Silent_test_impl(n, DefaultMultType, debug, false, false); + Vole_Silent_test_impl(n, DefaultMultType, debug, false, false); + Vole_Silent_test_impl, u32, CoeffCtxArray>(n, DefaultMultType, debug, false, false); } } -void Vole_Silent_Silver_test(const oc::CLP& cmd) +void Vole_Silent_QuasiCyclic_test(const oc::CLP& cmd) { - -#ifdef ENABLE_INSECURE_SILVER - Timer timer; - timer.setTimePoint("start"); - u64 n = cmd.getOr("n", 102043); - u64 nt = cmd.getOr("nt", std::thread::hardware_concurrency()); - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); - - block x = prng.get(); - std::vector c(n), z0(n), z1(n); - - SilentVoleReceiver recv; - SilentVoleSender send; - - recv.mMultType = MultType::slv5; - send.mMultType = MultType::slv5; - - recv.setTimer(timer); - send.setTimer(timer); - - recv.mDebug = false; - send.mDebug = false; - - auto chls = cp::LocalAsyncSocket::makePair(); - - timer.setTimePoint("net"); - - timer.setTimePoint("ot"); - fakeBase(n, nt, prng, x, recv, send); - - // c * x = z + m - - auto p0 = recv.silentReceive(c, z0, prng, chls[0]); - auto p1 = send.silentSend(x, z1, prng, chls[1]); - - eval(p0, p1); - timer.setTimePoint("send"); - for (u64 i = 0; i < n; ++i) - { - if (c[i].gf128Mul(x) != (z0[i] ^ z1[i])) - { - std::cout << "bad " << i << "\n c[i] " << c[i] << " * x " << x << " = " << c[i].gf128Mul(x) << std::endl; - std::cout << " z0[i] " << z0[i] << " ^ z1 " << z1[i] << " = " << (z0[i] ^ z1[i]) << std::endl; - throw RTE_LOC; - } - } - timer.setTimePoint("done"); +#if defined(ENABLE_SILENTOT) && defined(ENABLE_BITPOLYMUL) + auto debug = cmd.isSet("debug"); + for (u64 n : {128, 333}) + Vole_Silent_test_impl(n, MultType::QuasiCyclic, debug, false, false); +#else + throw UnitTestSkipped("ENABLE_BITPOLYMUL not defined." LOCATION); #endif } - void Vole_Silent_baseOT_test(const oc::CLP& cmd) { - - Timer timer; - timer.setTimePoint("start"); - u64 n = 123; - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); - - block x = prng.get(); - - - - auto chls = cp::LocalAsyncSocket::makePair(); - - timer.setTimePoint("net"); - - timer.setTimePoint("ot"); - - //recv.mDebug = true; - //send.mDebug = true; - - SilentVoleReceiver recv; - SilentVoleSender send; - // c * x = z + m - - //for (u64 n = 5000; n < 10000; ++n) - { - std::vector c(n), z0(n), z1(n); - - - recv.setTimer(timer); - send.setTimer(timer); - auto p0 = recv.silentReceive(c, z0, prng, chls[0]); - auto p1 = send.silentSend(x, z1, prng, chls[1]); - - eval(p0, p1); - - for (u64 i = 0; i < n; ++i) - { - if (c[i].gf128Mul(x) != (z0[i] ^ z1[i])) - { - throw RTE_LOC; - } - } - timer.setTimePoint("done"); - } + auto debug = cmd.isSet("debug"); + u64 n = 128; + Vole_Silent_test_impl(n, DefaultMultType, debug, true, false); + Vole_Silent_test_impl(n, DefaultMultType, debug, true, false); + Vole_Silent_test_impl, u32, CoeffCtxArray>(n, DefaultMultType, debug, true, false); } void Vole_Silent_mal_test(const oc::CLP& cmd) { - - Timer timer; - timer.setTimePoint("start"); - u64 n = 12343; - block seed = block(0, cmd.getOr("seed", 0)); - PRNG prng(seed); - - block x = prng.get(); - - - - auto chls = cp::LocalAsyncSocket::makePair(); - timer.setTimePoint("net"); - - timer.setTimePoint("ot"); - - //recv.mDebug = true; - //send.mDebug = true; - - SilentVoleReceiver recv; - SilentVoleSender send; - - send.mMalType = SilentSecType::Malicious; - recv.mMalType = SilentSecType::Malicious; - // c * x = z + m - - //for (u64 n = 5000; n < 10000; ++n) + auto debug = cmd.isSet("debug"); + for (u64 n : {45364}) { - std::vector c(n), z0(n), z1(n); - - - recv.setTimer(timer); - send.setTimer(timer); - auto p0 = recv.silentReceive(c, z0, prng, chls[0]); - auto p1 = send.silentSend(x, z1, prng, chls[1]); - - - eval(p0, p1); - - for (u64 i = 0; i < n; ++i) - { - if (c[i].gf128Mul(x) != (z0[i] ^ z1[i])) - { - throw RTE_LOC; - } - } - timer.setTimePoint("done"); + Vole_Silent_test_impl(n, DefaultMultType, debug, false, true); } } @@ -395,7 +227,6 @@ inline u64 eval( macoro::task<>& t1, macoro::task<>& t0, cp::BufferingSocket& s1, cp::BufferingSocket& s0) { - std::cout << "begin " << std::endl; auto e = macoro::make_eager(macoro::when_all_ready(std::move(t0), std::move(t1))); u64 rounds = 0; @@ -405,8 +236,6 @@ inline u64 eval( { s0.processInbound(*b1); ++rounds; - - std::cout << "round " << rounds << std::endl; } } @@ -429,10 +258,7 @@ inline u64 eval( s0.processInbound(*b1); } - ++rounds; - std::cout << "round " << rounds << std::endl; - ++idx; } @@ -448,7 +274,7 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) Timer timer; timer.setTimePoint("start"); - u64 n = 12343; + u64 n = 1233; block seed = block(0, cmd.getOr("seed", 0)); PRNG prng(seed); @@ -457,8 +283,8 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) cp::BufferingSocket chls[2]; - SilentVoleReceiver recv; - SilentVoleSender send; + SilentVoleReceiver recv; + SilentVoleSender send; send.mMalType = SilentSecType::SemiHonest; recv.mMalType = SilentSecType::SemiHonest; @@ -476,7 +302,7 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) send.setTimer(timer); if (jj) { - std::vector c(n), z0(n), z1(n); + AlignedUnVector c(n), z0(n), z1(n); auto p0 = recv.silentReceive(c, z0, prng, chls[0]); auto p1 = send.silentSend(x, z1, prng, chls[1]); #if (defined ENABLE_MRR_TWIST && defined ENABLE_SSE) || \ @@ -488,7 +314,7 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) auto rounds = eval(p0, p1, chls[1], chls[0]); if (rounds != expRound) - throw std::runtime_error(std::to_string(rounds) + "!="+std::to_string(expRound)+". " +COPROTO_LOCATION); + throw std::runtime_error(std::to_string(rounds) + "!=" + std::to_string(expRound) + ". " + COPROTO_LOCATION); for (u64 i = 0; i < n; ++i) @@ -530,9 +356,9 @@ void Vole_Silent_Rounds_test(const oc::CLP& cmd) timer.setTimePoint("done"); } } - #else + namespace { void throwDisabled() { @@ -541,13 +367,20 @@ namespace { ); } } - - +void Vole_Noisy_test(const oc::CLP& cmd) { throwDisabled(); } void Vole_Silent_QuasiCyclic_test(const oc::CLP& cmd) { throwDisabled(); } -void Vole_Silent_Silver_test(const oc::CLP& cmd) { throwDisabled(); } void Vole_Silent_paramSweep_test(const oc::CLP& cmd) { throwDisabled(); } void Vole_Silent_baseOT_test(const oc::CLP& cmd) { throwDisabled(); } void Vole_Silent_mal_test(const oc::CLP& cmd) { throwDisabled(); } void Vole_Silent_Rounds_test(const oc::CLP& cmd) { throwDisabled(); } + #endif +// +// +//void Vole_Silent_QuasiCyclic_test(const oc::CLP& cmd) { throwDisabled(); } +//void Vole_Silent_Silver_test(const oc::CLP& cmd) { throwDisabled(); } +//void Vole_Silent_paramSweep_test(const oc::CLP& cmd) { throwDisabled(); } +//void Vole_Silent_baseOT_test(const oc::CLP& cmd) { throwDisabled(); } +//void Vole_Silent_mal_test(const oc::CLP& cmd) { throwDisabled(); } +//void Vole_Silent_Rounds_test(const oc::CLP& cmd) { throwDisabled(); } diff --git a/libOTe_Tests/Vole_Tests.h b/libOTe_Tests/Vole_Tests.h index 5ef787ad..664ace30 100644 --- a/libOTe_Tests/Vole_Tests.h +++ b/libOTe_Tests/Vole_Tests.h @@ -10,9 +10,7 @@ void Vole_Noisy_test(const oc::CLP& cmd); void Vole_Silent_QuasiCyclic_test(const oc::CLP& cmd); -void Vole_Silent_Silver_test(const oc::CLP& cmd); void Vole_Silent_paramSweep_test(const oc::CLP& cmd); -void Vole_Noisy_test(const oc::CLP& cmd); void Vole_Silent_baseOT_test(const oc::CLP& cmd); void Vole_Silent_mal_test(const oc::CLP& cmd); void Vole_Silent_Rounds_test(const oc::CLP& cmd);