Skip to content

Commit

Permalink
Add expansions functions to expand ciphertext array (#83)
Browse files Browse the repository at this point in the history
Add ciphertext array expansion functions and corresponding input indices compression functions.

Also include the helper functions for inner product.
  • Loading branch information
RuiyuZhu authored and GitHub Enterprise committed Apr 5, 2024
1 parent 3387fad commit 30dffc7
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 11 deletions.
116 changes: 114 additions & 2 deletions Sources/Pir/PirUtil.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ enum PirUtil<Scheme: HeScheme> {
/// multiples-of-`2^{logStep-1}` positions. After that, sum/subtraction helps cancel coefficients at
/// `2^{logStep}*i`-th or `(2^{logStep}*i + 2^{logStep-1})`-th positions. As the last step, shifting by multiplying
/// the polynomial with `x^-{2^{logStep-1}}` helps compensate for the offset of `2^{logStep-1})`.
static func expandOneCiphertext(
ciphertext: CanonicalCiphertext,
static func expandCiphertextForOneStep(
_ ciphertext: CanonicalCiphertext,
logStep: Int,
using evaluationKey: EvaluationKey<Scheme>) throws -> (CanonicalCiphertext, CanonicalCiphertext)
{
Expand All @@ -35,4 +35,116 @@ enum PirUtil<Scheme: HeScheme> {
try difference.multiplyInversePowerOfX(power: shiftingPower)
return try (ciphertext + c1, difference.convertToCanonicalFormat())
}

/// expand one ciphertext into given number of encrypted constant polynomials
/// The input ciphertext is expected to have zero-coefficient except at multiple-of-2^{logStep-1} positions
/// Each time, the input ciphertext is expanded to two ciphertexts, containing the even/odd non-zero coefficients,
/// respectively. These two ciphertexts are used to generate ceil(outputCount/2) and floor(outputCount/2)
/// ciphertexts, respectively. When only 1 ciphertext is needed to be generated, no further expansion is needed.
/// If outputCount is a power of two, then every resulting ciphertext will come from same number of expansion where
/// each expansion will multiply the coefficients by 2.
/// However when outputCount is not power of two, some of them may experience one less expansion. To make them have
/// the same blow-up factor, we add the ciphertext to itself when returning.
static func expandCiphertext(
_ ciphertext: CanonicalCiphertext,
outputCount: Int,
logStep: Int,
expectedHeight: Int,
using evaluationKey: EvaluationKey<Scheme>) throws -> [CanonicalCiphertext]
{
precondition(outputCount >= 0 && outputCount <= ciphertext.degree)
if outputCount == 1 {
if logStep > expectedHeight {
return [ciphertext]
}
return try [ciphertext + ciphertext]
}
let secondHalfCount = outputCount >> 1
let firstHalfCount = outputCount - secondHalfCount

let (p0, p1) = try expandCiphertextForOneStep(
ciphertext,
logStep: logStep,
using: evaluationKey)
let firstHalf = try expandCiphertext(
p0,
outputCount: firstHalfCount,
logStep: logStep + 1,
expectedHeight: expectedHeight,
using: evaluationKey)
let secondHalf = try expandCiphertext(
p1,
outputCount: secondHalfCount,
logStep: logStep + 1,
expectedHeight: expectedHeight,
using: evaluationKey)
return zip(firstHalf.prefix(secondHalfCount), secondHalf).flatMap { [$0, $1] } + firstHalf
.suffix(firstHalfCount - secondHalfCount)
}

/// expand a ciphertext array into given number of encrypted constant polynomials
static func expandCiphertexts(
_ ciphertexts: [CanonicalCiphertext],
outputCount: Int,
using evaluationKey: EvaluationKey<Scheme>) throws -> [CanonicalCiphertext]
{
precondition((ciphertexts.count - 1) * ciphertexts[0].degree < outputCount)
precondition(ciphertexts.count * ciphertexts[0].degree >= outputCount)
var remainingOutputs = outputCount
return try ciphertexts.flatMap { ciphertext in
let outputToGenerate = min(remainingOutputs, ciphertext.degree)
remainingOutputs -= outputToGenerate
return try expandCiphertext(
ciphertext,
outputCount: outputToGenerate,
logStep: 1,
expectedHeight: outputToGenerate.ceilLog2,
using: evaluationKey)
}
}

/// convert the MulPir indices (i.e. the index of non-zero results after expansion) into plaintext
static func compressInputsForOneCiphertext(totalInputCount: Int, nonZeroInputs: [Int],
context: Context<Scheme>) throws -> Plaintext<Scheme, Coeff>
{
precondition(totalInputCount <= context.degree)
var rawData: [Scheme.Scalar] = Array(repeating: 0, count: context.degree)

let inputCountCeilLog = totalInputCount.ceilLog2
let inverseInputCountCeilLog = try Scheme.Scalar(2).powMod(
exponent: Scheme.Scalar(inputCountCeilLog),
modulus: context.plaintextModulus,
variableTime: true).inverseMod(modulus: context.plaintextModulus, variableTime: true)

for index in nonZeroInputs {
rawData[index] = inverseInputCountCeilLog
}
return try Scheme.encode(context: context, values: rawData, format: .coefficient)
}

/// generate the ciphertext based on the given non-zero positions
static func compressInputs(
totalInputCount: Int,
nonZeroInputs: [Int],
context: Context<Scheme>,
using secretKey: SecretKey<Scheme>) throws -> [CanonicalCiphertext]
{
var remainingInputs = totalInputCount
var processedInputCount = 0
var plaintexts: [Plaintext<Scheme, Coeff>] = []

while remainingInputs > 0 {
let numberOfInputsToProcess = min(remainingInputs, context.degree)
let inputs = nonZeroInputs.filter { x in
(processedInputCount..<(processedInputCount + numberOfInputsToProcess)).contains(x)
}.map { $0 - processedInputCount }
try plaintexts.append(compressInputsForOneCiphertext(
totalInputCount: numberOfInputsToProcess,
nonZeroInputs: inputs,
context: context))
processedInputCount += numberOfInputsToProcess
remainingInputs -= numberOfInputsToProcess
}
return try plaintexts.map { try Scheme.encrypt($0, using: secretKey) }
}
}
33 changes: 33 additions & 0 deletions Sources/SwiftHe/Ciphertext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -514,3 +514,36 @@ public extension Ciphertext where Format == Scheme.CanonicalCiphertextFormat {
try Scheme.relinearize(&self, using: key)
}
}

extension Array {
func sum<Scheme>() throws -> Element where Element == Ciphertext<Scheme, Eval> {
precondition(!isEmpty)
return try dropFirst().reduce(self[0]) { try $0 + $1 }
}

func sum<Scheme>() throws -> Element where Element == Ciphertext<Scheme, Coeff> {
precondition(!isEmpty)
return try dropFirst().reduce(self[0]) { try $0 + $1 }
}

func sum<Scheme>() throws -> Element where Element == Ciphertext<Scheme, Scheme.CanonicalCiphertextFormat> {
precondition(!isEmpty)
return try dropFirst().reduce(self[0]) { try $0 + $1 }
}
}

public extension ArraySlice {
func innerProduct<Scheme>(plaintexts: ArraySlice<Plaintext<Scheme, Eval>>) throws -> Element
where Element == Ciphertext<Scheme, Eval>
{
// the precondition in sum will fail if self or plaintexts is empty
try (zip(self, plaintexts).map { try $0.0 * $0.1 }).sum()
}

func innerProduct<Scheme>(ciphertexts: Self) throws -> Element
where Element == Ciphertext<Scheme, Scheme.CanonicalCiphertextFormat>
{
// the precondition in sum will fail if self or ciphertexts is empty
try (zip(self, ciphertexts).map { try $0.0 * $0.1 }).sum()
}
}
101 changes: 94 additions & 7 deletions Tests/PirTests/PirUtilTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import TestUtil
import XCTest

class PirUtilTests: XCTestCase {
private func expandOneCiphertextTest<Scheme: HeScheme>(scheme _: Scheme.Type) throws {
private func expandCiphertextForOneStepTest<Scheme: HeScheme>(scheme _: Scheme.Type) throws {
let context: Context<Scheme> = try TestUtils.getTestContext()
let degree = TestUtils.testPolyDegree
let logDegree = degree.log2
Expand All @@ -25,8 +25,8 @@ class PirUtilTests: XCTestCase {
galoisElements: [1 << (logDegree - logStep + 1) + 1],
generateRelinearizationKey: false)
let ciphertext = try Scheme.encrypt(plaintext, using: secretKey)
let expandedCiphertexts = try PirUtil.expandOneCiphertext(
ciphertext: ciphertext,
let expandedCiphertexts = try PirUtil.expandCiphertextForOneStep(
ciphertext,
logStep: logStep,
using: evaluationKey)
let p0: [Scheme.Scalar] = try Scheme.decode(
Expand All @@ -46,9 +46,96 @@ class PirUtilTests: XCTestCase {
}
}

func testExpandOneCiphertext() throws {
try expandOneCiphertextTest(scheme: NoOpScheme.self)
try expandOneCiphertextTest(scheme: Bfv<UInt32>.self)
try expandOneCiphertextTest(scheme: Bfv<UInt64>.self)
private func expandCiphertextTest<Scheme: HeScheme>(scheme _: Scheme.Type) throws {
let context: Context<Scheme> = try TestUtils.getTestContext()
let degree = TestUtils.testPolyDegree
let logDegree = degree.log2
for inputCount in 1...degree {
let data: [Scheme.Scalar] = (0..<inputCount).map { _ in Scheme.Scalar(Int.random(in: 0...1)) }
let nonZeroInputs = data.enumerated().compactMap { $0.element == 0 ? nil : $0.offset }
let plaintext: Plaintext<Scheme, Coeff> = try PirUtil.compressInputsForOneCiphertext(
totalInputCount: inputCount,
nonZeroInputs: nonZeroInputs,
context: context)
let secretKey = try Scheme.generateSecretKey(context: context)
let galoisElements = (1...logDegree).map { (1 << $0) + 1 }
let evaluationKey = try Scheme.generateEvaluationKey(
context: context,
using: secretKey,
galoisElements: galoisElements,
generateRelinearizationKey: false)
let ciphertext = try Scheme.encrypt(plaintext, using: secretKey)
let expandedCiphertexts = try PirUtil.expandCiphertext(
ciphertext,
outputCount: inputCount,
logStep: 1,
expectedHeight: inputCount.ceilLog2,
using: evaluationKey)
XCTAssertEqual(expandedCiphertexts.count, inputCount)
for index in 0..<inputCount {
let decodedData: [Scheme.Scalar] = try Scheme.decode(
plaintext: Scheme.decrypt(expandedCiphertexts[index], using: secretKey),
format: .coefficient)
XCTAssertEqual(decodedData[0], data[index])
for coeff in decodedData.dropFirst() {
XCTAssertEqual(coeff, 0)
}
}
}
}

private func expandCiphertextsTest<Scheme: HeScheme>(scheme _: Scheme.Type) throws {
let context: Context<Scheme> = try TestUtils.getTestContext()
let degree = TestUtils.testPolyDegree
let logDegree = degree.log2
for inputCount in 1...degree * 2 {
let data: [Int] = (0..<inputCount).map { _ in Int.random(in: 0...1) }
let nonZeroInputs = data.enumerated().compactMap { $0.element == 0 ? nil : $0.offset }
let secretKey = try Scheme.generateSecretKey(context: context)
let ciphertexts = try PirUtil.compressInputs(
totalInputCount: inputCount,
nonZeroInputs: nonZeroInputs,
context: context,
using: secretKey)
let galoisElements = (1...logDegree).map { (1 << $0) + 1 }
let evaluationKey = try Scheme.generateEvaluationKey(
context: context,
using: secretKey,
galoisElements: galoisElements,
generateRelinearizationKey: false)

let expandedCiphertexts = try PirUtil.expandCiphertexts(
ciphertexts,
outputCount: inputCount,
using: evaluationKey)
XCTAssertEqual(expandedCiphertexts.count, inputCount)
for index in 0..<inputCount {
let decodedData: [Scheme.Scalar] = try Scheme.decode(
plaintext: Scheme.decrypt(expandedCiphertexts[index], using: secretKey),
format: .coefficient)
XCTAssertEqual(Int(decodedData[0]), data[index])
for coeff in decodedData.dropFirst() {
XCTAssertEqual(coeff, 0)
}
}
}
}

func testExpandCiphertextForOneStep() throws {
try expandCiphertextForOneStepTest(scheme: NoOpScheme.self)
try expandCiphertextForOneStepTest(scheme: Bfv<UInt32>.self)
try expandCiphertextForOneStepTest(scheme: Bfv<UInt64>.self)
}

func testExpandCiphertext() throws {
try expandCiphertextTest(scheme: NoOpScheme.self)
try expandCiphertextTest(scheme: Bfv<UInt32>.self)
try expandCiphertextTest(scheme: Bfv<UInt64>.self)
}

func testExpandCiphertexts() throws {
try expandCiphertextsTest(scheme: NoOpScheme.self)
try expandCiphertextsTest(scheme: Bfv<UInt32>.self)
try expandCiphertextsTest(scheme: Bfv<UInt64>.self)
}
}
10 changes: 8 additions & 2 deletions Tests/SwiftHeTests/HeAPITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,18 @@ class HeAPITests: XCTestCase {
let evalCiphertext: Ciphertext<Scheme, Eval> = try ciphertextSum1.convertToEvalFormat()
let coeffCiphertext: Ciphertext<Scheme, Coeff> = try evalCiphertext.inverseNtt()

let decryptedData0: [Scheme.Scalar] = try context.decode(plaintext: plaintextSum, format: .coefficient)
XCTAssertEqual(decryptedData0, sumData)
let decodedData: [Scheme.Scalar] = try context.decode(plaintext: plaintextSum, format: .coefficient)
XCTAssertEqual(decodedData, sumData)

let ciphertextSum3 = try [testEnv.ciphertext1, testEnv.ciphertext2].sum()
let ciphertextSum4 = try [testEnv.ciphertext1.convertToEvalFormat(), testEnv.ciphertext2.convertToEvalFormat()]
.sum()

try testEnv.checkDecryptsDecodes(ciphertext: coeffCiphertext, format: .coefficient, expected: sumData)
try testEnv.checkDecryptsDecodes(ciphertext: evalCiphertext, format: .coefficient, expected: sumData)
try testEnv.checkDecryptsDecodes(ciphertext: ciphertextSum2, format: .coefficient, expected: sumData)
try testEnv.checkDecryptsDecodes(ciphertext: ciphertextSum3, format: .coefficient, expected: sumData)
try testEnv.checkDecryptsDecodes(ciphertext: ciphertextSum4, format: .coefficient, expected: sumData)
}

private func schemeSameTypeSubtractionTest<Scheme: HeScheme>(context: Context<Scheme>) throws {
Expand Down
7 changes: 7 additions & 0 deletions Tests/SwiftHeTests/UtilTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,11 @@ class UtilTests: XCTestCase {
let hexString = data.hexEncodedString()
XCTAssertEqual(Data(hexEncoded: hexString), data)
}

func testSum() {
XCTAssertEqual([UInt8]().sum(), 0)
XCTAssertEqual([7].sum(), 7)
XCTAssertEqual([1, 2, 3].sum(), 6)
XCTAssertEqual([UInt8(255), UInt8(2)].sum(), UInt16(257))
}
}

0 comments on commit 30dffc7

Please sign in to comment.