Skip to content

Commit

Permalink
clarify tests
Browse files Browse the repository at this point in the history
Signed-off-by: bytemare <[email protected]>
  • Loading branch information
bytemare committed Dec 26, 2023
1 parent 51c439f commit 940dedb
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 110 deletions.
39 changes: 24 additions & 15 deletions tests/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package voprf_test
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"testing"

Expand Down Expand Up @@ -47,20 +48,38 @@ func TestEvaluationSerde(t *testing.T) {
t.Fatal(err)
}

errSerDeFailed := errors.New("evaluation serde failed")

if !areArraysOfArraysEqual(evaluation.Elements, deser.Elements) {
t.Fatal("evaluation serde failed")
t.Fatal(errSerDeFailed)
}

if bytes.Compare(evaluation.ProofC, evaluation.ProofC) != 0 {
t.Fatal("evaluation serde failed")
t.Fatal(errSerDeFailed)
}

if bytes.Compare(evaluation.ProofS, evaluation.ProofS) != 0 {
t.Fatal("evaluation serde failed")
t.Fatal(errSerDeFailed)
}
}

func serdeExport(t *testing.T, client *voprf.Client) (*voprf.State, *voprf.State) {
export := client.Export()

serialized, err := json.Marshal(export)
if err != nil {
t.Fatal(err)
}

state := &voprf.State{}
if err := json.Unmarshal(serialized, state); err != nil {
t.Fatal(err)
}

return export, state
}

func TestClient_State(t *testing.T) {
func TestClientState(t *testing.T) {
suite := voprf.Ristretto255Sha512
input := []byte("input")
kp := suite.KeyGen() // only used in VOPRF and POPRF
Expand All @@ -75,17 +94,7 @@ func TestClient_State(t *testing.T) {

client.Blind(input, info)

export := client.Export()

serialized, err := json.Marshal(export)
if err != nil {
t.Fatal(err)
}

state := &voprf.State{}
if err := json.Unmarshal(serialized, state); err != nil {
t.Fatal(err)
}
export, state := serdeExport(t, client)

resumed, err := state.RecoverClient()
if err != nil {
Expand Down
199 changes: 109 additions & 90 deletions tests/vectors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,25 +254,10 @@ func testBlindBatchWithBlinds(t *testing.T, client *voprf.Client, inputs, blinds
}
}

func testOPRF(
t *testing.T,
ciphersuite voprf.Ciphersuite,
mode voprf.Mode,
client *voprf.Client,
server *voprf.Server,
test *test,
) {
func testOPRFServerEvaluation(t *testing.T, server *voprf.Server, test *test) *voprf.Evaluation {
var ev *voprf.Evaluation
var err error

// OPRFClient Blinding
if test.Batch == 1 {
testBlind(t, ciphersuite, client, test.Input[0], test.Blind[0], test.BlindedElement[0], test.Info)
} else {
testBlindBatchWithBlinds(t, client, test.Input, test.Blind, test.BlindedElement, test.Info)
}

// OPRFServer evaluating
var ev *voprf.Evaluation
if test.Batch == 1 {
ev, err = server.EvaluateWithRandom(test.BlindedElement[0], test.NonceR, test.Info)
if err != nil {
Expand All @@ -295,6 +280,51 @@ func testOPRF(
}
}

return ev
}

func testOPRFClientFinalize(t *testing.T, client *voprf.Client, ev *voprf.Evaluation, test *test) {
if test.Batch == 1 {
output, err := client.Finalize(ev, test.Info)
if err != nil {
t.Fatal(err)
}

if !bytes.Equal(test.Output[0], output) {
t.Fatal("finalize() output is not valid.")
}
} else {
output, err := client.FinalizeBatch(ev, test.Info)
if err != nil {
t.Fatal(err)
}

for i, o := range test.Output {
if !bytes.Equal(o, output[i]) {
t.Fatal("finalizeBatch() output is not valid.")
}
}
}
}

func testOPRF(
t *testing.T,
ciphersuite voprf.Ciphersuite,
mode voprf.Mode,
client *voprf.Client,
server *voprf.Server,
test *test,
) {
// OPRFClient Blinding
if test.Batch == 1 {
testBlind(t, ciphersuite, client, test.Input[0], test.Blind[0], test.BlindedElement[0], test.Info)
} else {
testBlindBatchWithBlinds(t, client, test.Input, test.Blind, test.BlindedElement, test.Info)
}

// OPRFServer evaluating
ev := testOPRFServerEvaluation(t, server, test)

// Verify proofs
if mode == voprf.VOPRF || mode == voprf.POPRF {
if !bytes.Equal(test.ProofC, ev.ProofC) {
Expand All @@ -315,27 +345,70 @@ func testOPRF(
}

// OPRFClient finalize
if test.Batch == 1 {
output, err := client.Finalize(ev, test.Info)
if err != nil {
t.Fatal(err)
}
testOPRFClientFinalize(t, client, ev, test)
}

if !bytes.Equal(test.Output[0], output) {
t.Fatal("finalize() output is not valid.")
}
} else {
output, err := client.FinalizeBatch(ev, test.Info)
if err != nil {
t.Fatal(err)
}
func (v vector) testVector(
t *testing.T,
tv *testVector,
suite voprf.Ciphersuite,
mode voprf.Mode,
privKey, serverPublicKey, expectedDST []byte,
) {
test, err := tv.Decode()
if err != nil {
t.Fatal(fmt.Sprintf("batches : %v Failed %v\n", tv.Batch, err))
}

for i, o := range test.Output {
if !bytes.Equal(o, output[i]) {
t.Fatal("finalizeBatch() output is not valid.")
}
}
if err := test.Verify(suite); err != nil {
t.Fatal(err)
}

// Test DeriveKeyPair
seed, err := hex.DecodeString(v.SksSeed)
if err != nil {
t.Fatal(err)
}

keyInfo, err := hex.DecodeString(v.KeyInfo)
if err != nil {
t.Fatal(err)
}

sks, _ := deriveKeyPair(seed, keyInfo, mode, suite)
// log.Printf("sks %v", hex.EncodeToString(serializeScalar(sks, scalarLength(o.id))))
if !bytes.Equal(sks.Encode(), privKey) {
t.Fatalf("DeriveKeyPair yields unexpected output\n\twant: %v\n\tgot : %v", privKey, sks.Encode())
}

// Set up a new server.
server, err := suite.Server(mode, privKey)
if err != nil {
t.Fatalf(
"failed on setting up server %q\nvector value (%d) %v\ndecoded (%d) %v\n",
err,
len(v.SkSm),
v.SkSm,
len(privKey),
privKey,
)
}

if string(expectedDST) != string(dst(hash2groupDSTPrefix, contextString(mode, suite))) {
t.Fatal("GroupDST output is not valid.")
}

client, err := suite.Client(mode, serverPublicKey)
if err != nil {
t.Fatal(err)
}

if string(expectedDST) != string(dst(hash2groupDSTPrefix, contextString(mode, suite))) {
t.Fatal("GroupDST output is not valid.")
}

// test protocol execution
testOPRF(t, v.SuiteID, mode, client, server, test)
}

func (v vector) test(t *testing.T) {
Expand Down Expand Up @@ -365,63 +438,9 @@ func (v vector) test(t *testing.T) {
t.Fatalf("hex decoding errored with %q", err)
}

// Test Multiplicative Mode
for i, tv := range v.TestVectors {
t.Run(fmt.Sprintf("Vector %d", i), func(t *testing.T) {
test, err := tv.Decode()
if err != nil {
t.Fatal(fmt.Sprintf("batches : %v Failed %v\n", tv.Batch, err))
}

if err := test.Verify(suite); err != nil {
t.Fatal(err)
}

// Test DeriveKeyPair
seed, err := hex.DecodeString(v.SksSeed)
if err != nil {
t.Fatal(err)
}

keyInfo, err := hex.DecodeString(v.KeyInfo)
if err != nil {
t.Fatal(err)
}

sks, _ := deriveKeyPair(seed, keyInfo, mode, suite)
// log.Printf("sks %v", hex.EncodeToString(serializeScalar(sks, scalarLength(o.id))))
if !bytes.Equal(sks.Encode(), privKey) {
t.Fatalf("DeriveKeyPair yields unexpected output\n\twant: %v\n\tgot : %v", privKey, sks.Encode())
}

// Set up a new server.
server, err := suite.Server(mode, privKey)
if err != nil {
t.Fatalf(
"failed on setting up server %q\nvector value (%d) %v\ndecoded (%d) %v\n",
err,
len(v.SkSm),
v.SkSm,
len(privKey),
privKey,
)
}

if string(expectedDST) != string(dst(hash2groupDSTPrefix, contextString(mode, suite))) {
t.Fatal("GroupDST output is not valid.")
}

client, err := suite.Client(mode, serverPublicKey)
if err != nil {
t.Fatal(err)
}

if string(expectedDST) != string(dst(hash2groupDSTPrefix, contextString(mode, suite))) {
t.Fatal("GroupDST output is not valid.")
}

// test protocol execution
testOPRF(t, v.SuiteID, mode, client, server, test)
v.testVector(t, &tv, suite, mode, privKey, serverPublicKey, expectedDST)
})
}
}
Expand Down
13 changes: 8 additions & 5 deletions tests/voprf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@ package voprf_test

import (
"encoding/hex"
"errors"
"testing"

"github.com/bytemare/voprf"
)

var errExpectedEquality = errors.New("expected equality")

func makeClientAndServer(t *testing.T, mode voprf.Mode, ciphersuite voprf.Ciphersuite) (*voprf.Client, *voprf.Server) {
server, err := ciphersuite.Server(mode, nil)
if err != nil {
Expand Down Expand Up @@ -115,7 +118,7 @@ func TestAvailability(t *testing.T) {
func TestCiphersuiteGroup(t *testing.T) {
testAll(t, func(c *configuration) {
if c.ciphersuite.Group() != c.group {
t.Fatal("expected equality")
t.Fatal(errExpectedEquality)
}

ciphersuite, err := voprf.FromGroup(c.group)
Expand All @@ -124,15 +127,15 @@ func TestCiphersuiteGroup(t *testing.T) {
}

if ciphersuite != c.ciphersuite {
t.Fatal("expected equality")
t.Fatal(errExpectedEquality)
}
})
}

func TestCiphersuiteHashes(t *testing.T) {
testAll(t, func(c *configuration) {
if c.hash != c.ciphersuite.Hash() {
t.Fatal("expected equality")
t.Fatal(errExpectedEquality)
}
})
}
Expand All @@ -158,7 +161,7 @@ func TestServerKeys(t *testing.T) {

pk := c.ciphersuite.Group().Base().Multiply(private)
if pk.Equal(public) != 1 {
t.Fatal("expected equality")
t.Fatal(errExpectedEquality)
}
})
}
Expand All @@ -184,6 +187,6 @@ func TestDeriveKeyPair(t *testing.T) {
keyPair := ciphersuite.DeriveKeyPair(voprf.OPRF, random, info)

if keyPair.SecretKey.Equal(refSk) != 1 || keyPair.PublicKey.Equal(refPk) != 1 {
t.Fatal("expected equality")
t.Fatal(errExpectedEquality)
}
}

0 comments on commit 940dedb

Please sign in to comment.